-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalyze_dt.py
143 lines (115 loc) · 4.32 KB
/
analyze_dt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2015 Christopher C. Strelioff <[email protected]>
#
# Distributed under terms of the MIT license.
"""analyze_dt.py -- probe a decision tree found with scikit-learn."""
from __future__ import print_function
import os
import subprocess
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz
def get_code(tree, feature_names, target_names, spacer_base=" "):
    """Print pseudo-code for a fitted decision tree.

    Walks the tree from the root and prints nested ``if``/``else`` blocks
    for split nodes and one ``return`` line per class present at a leaf.
    Nothing is returned; the output goes to stdout.

    Args
    ----
    tree -- fitted scikit-learn DecisionTree; only its ``tree_`` attribute
        (``children_left``, ``children_right``, ``threshold``, ``feature``,
        ``value``) is read.
    feature_names -- list of feature names, indexed by ``tree_.feature``.
    target_names -- list of target (class) names, indexed by class column.
    spacer_base -- string repeated once per depth level to indent the
        printed code (default: " ").

    Notes
    -----
    based on http://stackoverflow.com/a/30104792.

    Leaves are recognised by ``children_left == -1`` rather than by the
    ``-2`` placeholder stored in ``threshold``: -2.0 is also a perfectly
    valid split threshold, so testing the threshold can mistake a real
    split for a leaf.
    """
    tree_ = tree.tree_
    left = tree_.children_left
    right = tree_.children_right
    threshold = tree_.threshold
    feature = tree_.feature
    value = tree_.value

    def recurse(node, depth):
        spacer = spacer_base * depth

        if left[node] == -1:
            # Leaf: report every class with a non-zero entry in this node.
            # NOTE(review): assumes value[node] is 2-D (outputs x classes),
            # as the original column-index lookup did -- confirm for
            # multi-output trees.
            target = value[node]
            rows, cols = np.nonzero(target)
            for i, v in zip(cols, target[rows, cols]):
                target_name = target_names[i]
                target_count = int(v)
                print(spacer + "return " + str(target_name) + " ( " +
                      str(target_count) + " examples )")
            return

        # Split node: the feature name is looked up only here, because
        # leaves carry a placeholder feature index (-2) that must not be
        # used to index feature_names.
        print(spacer + "if ( " + feature_names[feature[node]] + " <= " +
              str(threshold[node]) + " ) {")
        recurse(left[node], depth + 1)
        print(spacer + "}\n" + spacer + "else {")
        recurse(right[node], depth + 1)
        print(spacer + "}")

    recurse(0, 0)
def visualize_tree(tree, feature_names, fn="dt"):
    """Create tree png using graphviz.

    Writes the tree in dot format to ``<fn>.dot`` and then runs the
    external ``dot`` program to render ``<fn>.png``. Both files are created
    in the current working directory unless ``fn`` contains a path.

    Args
    ----
    tree -- scikit-learn DecisionTree.
    feature_names -- list of feature names.
    fn -- base name (without extension) of the output files
        (default: "dt", giving dt.dot and dt.png).

    Raises
    ------
    SystemExit -- if ``dot`` cannot be started or exits with an error.
    """
    dot_file = fn + ".dot"
    png_file = fn + ".png"

    with open(dot_file, 'w') as f:
        export_graphviz(tree, out_file=f, feature_names=feature_names)

    command = ["dot", "-Tpng", dot_file, "-o", png_file]
    try:
        subprocess.check_call(command)
    except (OSError, subprocess.CalledProcessError):
        # OSError: the dot executable could not be launched;
        # CalledProcessError: dot ran but returned a non-zero status.
        # Anything else (e.g. KeyboardInterrupt) is left to propagate.
        raise SystemExit(
            "Could not run dot, ie graphviz, to produce visualization")
def encode_target(df, target_column):
    """Add column to df with integers for the target.

    Each distinct value of ``target_column`` is assigned an integer code in
    order of first appearance (0, 1, 2, ...). The input frame is not
    modified; a copy carrying the extra "Target" column is returned.

    Args
    ----
    df -- pandas DataFrame.
    target_column -- column to map to int, producing new Target column.

    Returns
    -------
    df -- modified copy of the DataFrame.
    targets -- array of target names as returned by ``Series.unique()``;
        ``targets[code]`` is the name encoded as ``code``.
    """
    df_mod = df.copy()
    targets = df_mod[target_column].unique()
    map_to_int = {name: n for n, name in enumerate(targets)}
    # map() performs one dictionary lookup per row. replace() is meant for
    # value substitution and leaves the result dtype to pandas' downcasting
    # rules, which have changed between pandas releases.
    df_mod["Target"] = df_mod[target_column].map(map_to_int)

    return (df_mod, targets)
def get_iris_data(fn="iris.csv", url=None):
    """Get the iris data, from local csv or pandas repo.

    If ``fn`` exists it is read (its first column is used as the index);
    otherwise the data is downloaded from ``url`` and a copy is written to
    ``fn`` so later runs can read it locally.

    Args
    ----
    fn -- path of the local csv cache (default: "iris.csv").
    url -- location to download from when ``fn`` is missing (default: the
        iris.csv test file in the pandas-dev/pandas GitHub repository).

    Returns
    -------
    df -- pandas DataFrame with the iris data.

    Raises
    ------
    SystemExit -- if the file is not available locally and the download
        fails.
    """
    if url is None:
        # The file used to live under pydata/pandas .../tests/data/; the
        # repository and the path have since moved.
        url = ("https://raw.githubusercontent.com/pandas-dev/pandas/main/"
               "pandas/tests/io/data/csv/iris.csv")

    if os.path.exists(fn):
        print("-- " + fn + " found locally")
        return pd.read_csv(fn, index_col=0)

    print("-- trying to download from github")
    try:
        df = pd.read_csv(url)
    except Exception:
        # Network and parse failures surface as several unrelated exception
        # types; catching Exception still lets KeyboardInterrupt through.
        raise SystemExit("-- Unable to download iris.csv")

    print("-- writing to local " + fn + " file")
    # The index is written as the first column, which is what the local
    # branch above reads back with index_col=0.
    df.to_csv(fn)

    return df
def main():
    """Run the iris demo.

    Loads the iris data, fits a decision tree on the four measurement
    columns, prints the tree as pseudo-code, prints the species found
    below the PetalLength 2.45 cut, and renders the tree with graphviz.
    """
    print("\n-- get data:")
    df = get_iris_data()

    print("\n-- df.head():")
    print(df.head(), end="\n\n")

    features = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]
    df, targets = encode_target(df, "Name")
    y = df["Target"]
    X = df[features]

    # random_state is fixed so repeated runs build the same tree.
    dt = DecisionTreeClassifier(min_samples_split=20, random_state=99)
    dt.fit(X, y)

    print("\n-- get_code:")
    get_code(dt, features, targets)

    print("\n-- look back at original data using pandas")
    # The label mirrors the expression evaluated on the next line.
    print("-- df[df['PetalLength'] <= 2.45]['Name'].unique(): ",
          df[df['PetalLength'] <= 2.45]['Name'].unique(), end="\n\n")

    visualize_tree(dt, features)


if __name__ == '__main__':
    main()