DecisionTreeClassifier – Treeオブジェクト・再帰表示など

概要

Scikit-learnの決定木モデル、DecisionTreeClassifierについていろいろ試した際のコードをストック。

Treeオブジェクト内容確認

DecisionTreeClassifierオブジェクトのプロパティーtree_はデータセットに対して生成された決定木の構造が保存されている。以下はその内容を確認するためのコード。

Treeクラスはツリー内の各ノードの情報を1次元の配列でもっていて、子ノードを参照するにはノード番号に対応する配列のインデックスを参照する。Treeクラスが持っている主なプロパティーは以下の通り。

node_count
ツリーが持つ全ノード数。
children_left, children_right
各ノードの左/右の子ノードの番号を格納した1次元配列。
feature
各ノードを分割する際に使われる特徴量の番号を格納した1次元配列。
threshold
各ノードをfeatureで示された特性量で分割する際の閾値を格納した1次元配列。
value
各ノードにおける、各クラスのデータ数。クラス数分のデータを格納した1次元配列1つだけを要素とする2次元配列を、ノード数分だけ集めた3次元配列。

コードの実行結果は以下の通り。

親ノードと子ノードの関係は、たとえばノード0の左右の子ノードはchildren_leftchildren_rightの0番目の要素からノード1とノード4、ノード1の左右の子ノードはノード2とノード3、という風に追っていくことができる。

valueがややこしい。この配列は各ノードにおけるクラスごとのデータ数を格納している。全体配列の中にこのケースだとノード数に対応する7個の配列が要素として格納されているが、その配列が2次元配列になっていて、その要素の配列がクラスごとのデータを格納した配列になっている。例えば3番目の要素のクラス1の要素を取り出す場合にはvalue[3, 0, 1]と言う風に指定することになる。

Treeのコンソール表示

Treeオブジェクトのツリー構造を確認し、決定境界の描画などの準備とするために書いたコード。決定木の構造をコンソールに表示させる。2つの再帰関数を定義していて、本体は決定木学習後にそれらの関数を呼び出すのみ。

関数print_node1()は、ツリー構造をルートノードから階層が下がるごとに段下げして表示していく。このため、まず親ノードを表示してから左右の子ノードを引数として再帰呼び出しをしている。

終了条件はノードが子ノードを持たない葉(leaf)であることを利用するが、リーフの時のパラメータは以下の通りで、ここでは左子ノードの番号が−1となることを利用している。

  • 子ノードの番号が−1
  • 特性量の番号が−2
  • 特性量の閾値が−2.0

関数print_node2は、決定木の構造を枝分かれした木の形で表示する。左側のノードから右側に移るのを、コンソール上で上から下に表示していく。手順としては、

  1. リーフノードならノードの内容を出力してリターン
  2. リーフノードでなければ、
    1. 左子ノードの処理を呼び出す
    2. それが戻ってきたら(左側の全子孫ノードが出力されたら)自身の内容を出力
    3. 右子ノードの処理を呼び出す
    4. それが戻ってきたら(右側の全子孫ノードが出力されたら)リターン

引数に現在のノードの階層を保持する変数があり、その階層に応じた数のスペースでインデントすることで木の構造を表す。

出力は以下の通り。

決定木の構築過程の表示

make_monns()による2特性量のデータについて、順次ノードを分割する過程を図で描画するためのコード。

draw_tree_boundary()関数は再帰関数で、もしそのノードがリーフノードか指定された終了階層の場合はクラスに応じた色で領域を塗りつぶす。リーフノードでなければ、閾値が特性量0の場合と1の場合で境界線の縦横や開始終了位置を変化させて再帰的に関数を呼び出す。引数stop_levelに正の整数を指定することで、その階層までの描画に留めることができる。関数の内容についてはこちらを参照。

本体はデータをクラスごとの色で散布図として描き、ルートノードについてdraw_tree_boundary()を呼び出している。

以下は、実行例。

以下は、stop_levelを順次増やしていって、領域が分割される過程を描いた例。

決定木のツリー表示

DecisionTreeClassificationオブジェクトを可視化する環境によって、決定木を表示する例。

  1. 環境構築
    1. Pythonでpydotplusパッケージを導入
    2. Graphviz環境を構築
  2. 実行
    1. sklearn.tree.export_graphviz()で決定木のdotデータを得る
    2. pydotplus.graph_from_dot_data()Dotオブジェクトを生成
    3. write_png()などのメソッドでグラフを画像として書き出す

このコードはAtom上でコードを実行したため、Atomのディレクトリーに画像ファイルが書き出される。

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です