-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecision_tree.py
More file actions
166 lines (141 loc) · 6.09 KB
/
decision_tree.py
File metadata and controls
166 lines (141 loc) · 6.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import numpy as np
import math
import sys
import matplotlib.pyplot as plt
from visualise import visualise_tree
LABEL_COLUMN = 7
NUMBER_OF_FOLDS = 10
NUMBER_OF_ROOMS = 4
# Convert the txt dataset to a numpy ndarray
# Read a whitespace-delimited text file of numbers into a 2-D numpy array
def load_dataset(filepath):
    """Load the dataset stored at *filepath* and return it as an ndarray."""
    return np.loadtxt(filepath)
# Recursive function to build a decision tree to classify the data
# Recursively grow a decision tree over `dataset`; `depth` is this node's depth.
def decision_tree_learning(dataset, depth):
    """Build a decision tree for `dataset` and return (node, deepest_depth).

    Leaf nodes are {"class": label}; internal nodes carry "attribute",
    "value", "left" and "right" keys.
    """
    labels = dataset[:, LABEL_COLUMN]
    # Pure node: every sample carries the same label, so emit a leaf.
    if np.all(labels[0] == labels):
        return ({"class": labels[0]}, depth)
    # Otherwise pick the highest-information-gain split and recurse each side.
    node = find_split(dataset)
    left_part, right_part = split_data(dataset, node["attribute"], node["value"])
    node["left"], left_depth = decision_tree_learning(left_part, depth + 1)
    node["right"], right_depth = decision_tree_learning(right_part, depth + 1)
    return (node, max(left_depth, right_depth))
# Go through each attribute, and find which attribute gives the optimal information gain. Return
# the attribute and split value as a node
# Search every attribute/threshold pair for the split with the highest
# information gain, returned as {"attribute": ..., "value": ...}.
def find_split(training_data):
    """Return the best split node for `training_data`."""
    base_entropy = entropy(training_data)
    best_gain = -1
    best_attr = 0
    best_value = 0
    # The final column holds the label, so it is excluded from candidates.
    for attr in range(len(training_data[0]) - 1):
        # Sort rows by this attribute so every candidate threshold is the
        # midpoint between two consecutive distinct values.
        ordered = training_data[training_data[:, attr].argsort()]
        for row in range(len(training_data) - 1):
            lo = ordered[row][attr]
            hi = ordered[row + 1][attr]
            if lo == hi:
                continue
            midpoint = (lo + hi) / 2
            gain = information_gain(base_entropy, ordered[:row + 1], ordered[row + 1:])
            if gain > best_gain:
                best_gain, best_attr, best_value = gain, attr, midpoint
    return {"attribute": best_attr, "value": best_value}
# Find information gained during a data split. Takes in starting data entropy, left subset and right subset as parameters
# Information gain of a split: parent entropy minus the size-weighted
# average of the two children's entropies.
def information_gain(data_entropy, left_subset, right_subset):
    """Return the entropy reduction achieved by splitting into the two subsets."""
    total = len(left_subset) + len(right_subset)
    weighted_children = ((len(left_subset) / total) * entropy(left_subset)
                         + (len(right_subset) / total) * entropy(right_subset))
    return data_entropy - weighted_children
# Calculate the entropy of a dataset
# Shannon entropy (base 2) of the label distribution in `dataset`.
def entropy(dataset):
    """Return the entropy of the labels in column LABEL_COLUMN of `dataset`.

    Generalized: the original summed only over hard-coded rooms 1..4; this
    version iterates whatever labels are actually present, so any integer
    label set works. Classes with zero count contribute nothing, which also
    avoids log(0).
    """
    labels = dataset[:, LABEL_COLUMN]
    total = len(labels)
    label_counts = {}
    for label in labels:
        room = str(int(label))
        label_counts[room] = label_counts.get(room, 0) + 1
    ans = 0
    # H = -sum(p * log2(p)) over every observed class.
    for count in label_counts.values():
        probability = count / total
        ans -= probability * math.log(probability, 2)
    return ans
# Split the dataset into two by comparing an attribute with a value
# Partition `data` on one attribute: rows strictly below `value` go left,
# everything else goes right.
def split_data(data, attribute, value):
    """Return (left, right) row subsets of `data` split at `value` on `attribute`."""
    below = data[:, attribute] < value
    left_data = data[below]
    right_data = data[~below]
    return (left_data, right_data)
# 10-fold cross validation of data
# k-fold cross validation (k = NUMBER_OF_FOLDS) of the decision tree.
def cross_validate(data):
    """Run k-fold cross validation and return fold-averaged metrics.

    Returns (confusion_matrix, precision, recall, f1, accuracy). NOTE:
    shuffles `data` in place; the RNG is seeded so runs are reproducible.
    `np.split` requires len(data) to be divisible by NUMBER_OF_FOLDS.
    """
    k = NUMBER_OF_FOLDS
    np.random.seed(1)        # fixed seed -> identical folds every run
    np.random.shuffle(data)  # in-place shuffle of the caller's array
    folds = np.split(data, k)
    sum_confusion = np.zeros((NUMBER_OF_ROOMS, NUMBER_OF_ROOMS))
    sum_precision = np.zeros(NUMBER_OF_ROOMS)
    sum_recall = np.zeros(NUMBER_OF_ROOMS)
    sum_f1 = np.zeros(NUMBER_OF_ROOMS)
    for fold_index in range(NUMBER_OF_FOLDS):
        # Hold out one fold for testing, train on the concatenation of the rest.
        remaining = folds.copy()
        test_fold = remaining.pop(fold_index)
        train_rows = np.concatenate(remaining)
        (tree, _depth) = decision_tree_learning(train_rows, 1)
        (confusion, precision, recall, f1) = evaluate(test_fold, tree)
        sum_confusion += confusion
        sum_precision += precision
        sum_recall += recall
        sum_f1 += f1
    avg_confusion_matrix = sum_confusion / NUMBER_OF_FOLDS
    avg_precision = sum_precision / NUMBER_OF_FOLDS
    avg_recall = sum_recall / NUMBER_OF_FOLDS
    avg_f1 = sum_f1 / NUMBER_OF_FOLDS
    # Accuracy: diagonal (correct) mass over all predictions.
    accuracy = np.trace(avg_confusion_matrix) / np.sum(avg_confusion_matrix)
    return (avg_confusion_matrix, avg_precision, avg_recall, avg_f1, accuracy)
# Find the confusion matrix, precision, recall and f1 of a given trained decision tree
# Evaluate `trained_tree` on `test_db` and compute per-class metrics.
def evaluate(test_db, trained_tree):
    """Return (confusion_matrix, precision, recall, f1) for the test set.

    Confusion-matrix rows are true classes, columns are predicted classes.
    Fix: when a class was never predicted (or never present in the test
    fold) the original divided 0/0, producing NaN plus a numpy warning and
    poisoning the fold averages. Zero denominators now yield 0, the usual
    zero-division convention for these metrics.
    """
    confusion_matrix = np.zeros((NUMBER_OF_ROOMS, NUMBER_OF_ROOMS))
    for sample in test_db:
        room = traverse_tree(sample, trained_tree)
        confusion_matrix[int(sample[LABEL_COLUMN]) - 1, int(room) - 1] += 1
    precision_arr = np.zeros(NUMBER_OF_ROOMS)
    recall_arr = np.zeros(NUMBER_OF_ROOMS)
    f1_arr = np.zeros(NUMBER_OF_ROOMS)
    for room in range(NUMBER_OF_ROOMS):
        predicted_total = np.sum(confusion_matrix[:, room])  # column: predictions
        actual_total = np.sum(confusion_matrix[room, :])     # row: true labels
        precision = confusion_matrix[room, room] / predicted_total if predicted_total else 0.0
        recall = confusion_matrix[room, room] / actual_total if actual_total else 0.0
        precision_arr[room] = precision
        recall_arr[room] = recall
        # Harmonic mean of precision and recall; 0 when both are 0.
        denom = precision + recall
        f1_arr[room] = (2 * precision * recall) / denom if denom else 0.0
    return (confusion_matrix, precision_arr, recall_arr, f1_arr)
# Traverse a trained decision tree to find which class a given sample is predicted as
# Walk the trained tree for one sample until a leaf is reached.
def traverse_tree(sample, trained_tree):
    """Return the class label the tree predicts for `sample`."""
    node = trained_tree
    while "class" not in node.keys():
        # Go left when the sample's attribute falls below the split value.
        side = "left" if sample[node["attribute"]] < node["value"] else "right"
        node = node[side]
    return node["class"]
if __name__ == "__main__":
    # Fix: `sys.argv[1] == None` could never be true, and merely accessing
    # sys.argv[1] raised IndexError when no argument was given. Check the
    # argument count instead.
    if len(sys.argv) < 2:
        print("dataset filepath needed")
    else:
        dataset = load_dataset(sys.argv[1])
        # Train and display a tree built from the full dataset.
        (trained_tree, depth) = decision_tree_learning(dataset, 1)
        plt.figure(figsize=(17,8))
        visualise_tree(trained_tree)
        plt.axis('off')
        plt.show()
        # Report 10-fold cross-validated metrics.
        (avg_confusion_matrix, avg_precision, avg_recall, avg_f1, accuracy) = cross_validate(dataset)
        print("Confusion Matrix:")
        print(avg_confusion_matrix)
        print("Precision:")
        print(avg_precision)
        print("Recall:")
        print(avg_recall)
        print("f1:")
        print(avg_f1)
        print("Accuracy:")
        print(accuracy)