import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patch
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier
def draw_tree_boundary(tree, ax, left, right, bottom, top,
                       i_node=0, stop_level=None, n_level=0,
                       colors=('tab:blue', 'tab:orange'), alpha=0.2):
    """Recursively shade the axis-aligned regions of a fitted 2-feature tree.

    Starting at node ``i_node``, every internal node draws its split threshold
    as a line across the current rectangle and recurses into both halves.
    Every leaf, or every node at depth ``stop_level``, is filled with a
    translucent rectangle colored by the node's majority class.

    Parameters
    ----------
    tree : sklearn.tree._tree.Tree
        The ``tree_`` attribute of a fitted ``DecisionTreeClassifier`` trained
        on exactly two features. Feature 0 maps to the horizontal axis and
        feature 1 to the vertical axis.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    left, right, bottom, top : float
        Data-coordinate bounds of the region covered by node ``i_node``.
    i_node : int, default 0
        Node to start from (0 is the root).
    stop_level : int or None, default None
        If given, nodes at this depth are drawn as filled regions even when
        they are internal, which truncates the drawing. ``None`` draws the
        whole tree.
    n_level : int, default 0
        Depth of ``i_node``. Used internally by the recursion.
    colors : sequence of matplotlib color specs, default ('tab:blue', 'tab:orange')
        Fill color per class index; class ``k`` is filled with
        ``colors[k % len(colors)]``. The default is matplotlib's first two
        cycle colors (C0, C1). These are the colors that the first two
        un-colored ``scatter`` calls on a fresh Axes receive, so regions
        match the points of the same class.
    alpha : float, default 0.2
        Opacity of the filled regions.

    Raises
    ------
    ValueError
        If an internal node splits on a feature other than 0 or 1.
    """
    # sklearn marks leaf nodes with children_left == -1.
    is_leaf = tree.children_left[i_node] == -1
    if is_leaf or stop_level == n_level:
        # value[i_node][0] holds the per-class weights for the first output;
        # argmax is the majority class index at this node.
        majority = int(np.argmax(tree.value[i_node][0]))
        rect = patch.Rectangle(xy=(left, bottom),
                               width=right - left, height=top - bottom,
                               fc=colors[majority % len(colors)], alpha=alpha)
        ax.add_patch(rect)
        return

    feature = tree.feature[i_node]
    threshold = tree.threshold[i_node]
    left_child = tree.children_left[i_node]
    right_child = tree.children_right[i_node]
    # Arguments shared by both recursive calls.
    common = dict(tree=tree, ax=ax, stop_level=stop_level,
                  n_level=n_level + 1, colors=colors, alpha=alpha)

    # sklearn sends samples with X[:, feature] <= threshold to the left child,
    # so the left child gets the lower/left part of the current rectangle.
    if feature == 0:
        ax.plot([threshold, threshold], [top, bottom])
        draw_tree_boundary(left=left, right=threshold, bottom=bottom, top=top,
                           i_node=left_child, **common)
        draw_tree_boundary(left=threshold, right=right, bottom=bottom, top=top,
                           i_node=right_child, **common)
    elif feature == 1:
        ax.plot([left, right], [threshold, threshold])
        draw_tree_boundary(left=left, right=right, bottom=bottom, top=threshold,
                           i_node=left_child, **common)
        draw_tree_boundary(left=left, right=right, bottom=threshold, top=top,
                           i_node=right_child, **common)
    else:
        raise ValueError(
            f"node {i_node} splits on feature {feature}; "
            "draw_tree_boundary only supports trees fit on 2 features")
# Toy data: 20 noisy two-moons samples, and a tree fit to them with default
# hyperparameters (no depth limit set here).
X, y = make_moons(n_samples=20, noise=0.25, random_state=5)
treeclf = DecisionTreeClassifier(random_state=0).fit(X, y)
tree = treeclf.tree_  # low-level node arrays walked by draw_tree_boundary

fig, ax = plt.subplots()

# One scatter call per class, class 0 first. No color is passed, so each call
# takes the next color from the Axes' color cycle.
for class_value, class_marker, class_label in ((0, 'o', "Class 0"),
                                               (1, '^', "Class 1")):
    in_class = X[y == class_value]
    ax.scatter(in_class[:, 0], in_class[:, 1],
               ec='k', s=60, marker=class_marker, label=class_label)

# Data-coordinate window: the root region for the recursion and the view limits.
x0_min, x0_max = -2, 2.5
x1_min, x1_max = -1, 1.5
draw_tree_boundary(tree=tree, ax=ax, i_node=0, stop_level=None,
                   left=x0_min, right=x0_max, bottom=x1_min, top=x1_max)
ax.set(xlim=(x0_min, x0_max), ylim=(x1_min, x1_max))
ax.legend()
plt.show()