-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecisiontree.py
147 lines (127 loc) · 5.13 KB
/
decisiontree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from collections import Counter, defaultdict

import numpy as np
"""
age hasjob hashouse credit label
young n n general n
"""
class Decisiontree():
    """A single node of an ID3 decision tree.

    Attributes:
        feature_dim: index of the feature this node splits on
            (None until set, and for leaf nodes).
        child: list of sub-trees, one per observed value of the
            split feature; empty for leaves.
        label: class label associated with this node (-1 when unset).
    """

    def __init__(self):
        self.label = -1
        self.child = []
        self.feature_dim = None
def compute(D, dim):
    """Return the information gain H(D) - H(D|A) of splitting D on feature ``dim``.

    Args:
        D: non-empty sequence of tuples; the last element of each tuple is
            the class label, the rest are feature values.
        dim: index of the feature (attribute A) to evaluate.

    Returns:
        The entropy gain as a float (also printed, matching the original
        diagnostic output).
    """
    numpoints = len(D)
    # H(D): entropy of the class-label distribution over the whole set.
    label_counts = Counter(data[-1] for data in D)
    HD = -sum((1.0 * count / numpoints) * np.log2(1.0 * count / numpoints)
              for count in label_counts.values())
    # H(D|A) = sum_i p(A=ai) * H(D | A=ai); group class counts by the
    # value this feature takes.
    value_class_counts = defaultdict(Counter)
    for data in D:
        value_class_counts[data[dim]][data[-1]] += 1
    HDA = 0.0
    for class_counts in value_class_counts.values():
        n_value = sum(class_counts.values())
        p_value = 1.0 * n_value / numpoints
        # Conditional entropy of the labels within this attribute value.
        h_value = -sum((1.0 * count / n_value) * np.log2(1.0 * count / n_value)
                       for count in class_counts.values())
        HDA += p_value * h_value
    gain = HD - HDA
    print('for feature %d, entropy gain %s' % (dim, gain))
    return gain
def most_frequent_label(D):
    """Return the class label (last tuple element) occurring most often in D.

    Ties are broken in favour of the label encountered first (Counter's
    most_common uses a stable sort over insertion order). Returns 0 for an
    empty dataset, preserving the original fallback value.
    """
    counts = Counter(data[-1] for data in D)
    if not counts:
        return 0
    return counts.most_common(1)[0][0]
def build_decisiontree(D, dim_list):
    """Recursively build an ID3 decision tree using the entropy-gain criterion.

    Args:
        D: sequence of tuples; the last element of each tuple is the class
            label, the rest are feature values.
        dim_list: iterable of feature indices still available for splitting
            (no longer mutated; the original removed entries from the
            caller's list and shared it across sibling recursions).

    Returns:
        The root Decisiontree node for D.
    """
    tree = Decisiontree()
    if not D:
        # Nothing to learn from; return a bare node (label stays -1).
        return tree
    # Copy so we work on a list regardless of the input type (a plain
    # range() on Python 3 is not a list) and never mutate the caller's data.
    dims = list(dim_list)
    # Select the feature with the largest information gain.
    max_gain = 0
    best_dim = -1
    for dim in dims:
        current_gain = compute(D, dim)
        if current_gain > max_gain:
            max_gain = current_gain
            best_dim = dim
    print('dim_list=%s best_dim=%s max_gain=%s' % (dims, best_dim, max_gain))
    if max_gain == 0:
        # Zero gain on every feature: all points belong to the same class
        # (or no feature separates them) -- make a leaf.
        print('entropy gain is zero, all points belong to the same class')
        print('')
        tree.label = D[0][-1]
        return tree
    print('select feature %d, entropy gain %s' % (best_dim, max_gain))
    print('')
    tree.feature_dim = best_dim
    tree.label = most_frequent_label(D)
    remaining = [d for d in dims if d != best_dim]
    if not remaining:
        # No features left below this node: stop as a majority-label leaf.
        # (The original returned an unlabelled node here, losing the
        # prediction.)
        return tree
    # Partition D by the value of the chosen feature and recurse. Each
    # child receives the same `remaining` list, which is safe because the
    # recursive call copies it before use -- siblings no longer see
    # features consumed by one another's subtrees.
    partitions = {}
    for data in D:
        partitions.setdefault(data[best_dim], []).append(data)
    for subset in partitions.values():
        tree.child.append(build_decisiontree(subset, remaining))
    return tree
def print_decisiontree(decisiontree):
if not decisiontree.child:
print 'I am a leaf node', decisiontree.feature_dim, decisiontree.label
else:
print 'I am a internal node', decisiontree.feature_dim, decisiontree.label
for one_child in decisiontree.child:
print_decisiontree(one_child)
if __name__ == "__main__":
    # Toy loan-approval dataset: (age, has_job, has_house, credit, label).
    D = [('young', 'not_has_job', 'not_has_house', 'general_credit', 'not_approved'),
         ('young', 'not_has_job', 'not_has_house', 'good_credit', 'not_approved'),
         ('young', 'has_job', 'not_has_house', 'good_credit', 'approved'),
         ('young', 'has_job', 'has_house', 'general_credit', 'approved'),
         ('young', 'not_has_job', 'not_has_house', 'general_credit', 'not_approved'),
         ('middle_age', 'not_has_job', 'not_has_house', 'general_credit', 'not_approved'),
         ('middle_age', 'not_has_job', 'not_has_house', 'good_credit', 'not_approved'),
         ('middle_age', 'has_job', 'has_house', 'good_credit', 'approved'),
         ('middle_age', 'not_has_job', 'has_house', 'very_good_credit', 'approved'),
         ('middle_age', 'not_has_job', 'has_house', 'very_good_credit', 'approved'),
         ('old', 'not_has_job', 'has_house', 'very_good_credit', 'approved'),
         ('old', 'not_has_job', 'has_house', 'good_credit', 'approved'),
         ('old', 'has_job', 'not_has_house', 'good_credit', 'approved'),
         ('old', 'has_job', 'not_has_house', 'very_good_credit', 'approved'),
         ('old', 'not_has_job', 'not_has_house', 'general_credit', 'not_approved')]
    # Show the distinct values each of the 4 features takes on.
    for i in range(4):
        print(set(item[i] for item in D))
    # list(...) because plain range() is not a list on Python 3 (the
    # original passed a bare range, which build_decisiontree then mutated).
    decisiontree = build_decisiontree(D, list(range(len(D[0]) - 1)))
    print_decisiontree(decisiontree)