forked from m4jidRafiei/Decision-Tree-Python-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMain.py
More file actions
32 lines (22 loc) · 1003 Bytes
/
Main.py
File metadata and controls
32 lines (22 loc) · 1003 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from DecisionTree import DecisionTree
import pandas as pd
# Driver script: load a data set from CSV, build a decision tree with the
# project's DecisionTree class, visualise it and print its impurity values.

# ---- Run configuration (edit these rather than the code below) ----
# Path of the CSV file used as the data set.
CSV_PATH = 'playtennis.csv'
# Split criterion passed to the DecisionTree constructor; "gini" is also accepted.
CRITERION = "entropy"
# Pruning parameters forwarded to DecisionTree.id3 (gain_threshold,
# minimum_samples). 0, 0 reproduces the original call; their exact pruning
# semantics live in DecisionTree — TODO confirm there.
GAIN_THRESHOLD = 0
MINIMUM_SAMPLES = 0

# Read the CSV file as the data set with pandas.
data = pd.read_csv(CSV_PATH)
columns = data.columns
# All columns except the last one are descriptive features by default...
descriptive_features = columns[:-1]
# ...and the last column is taken as the class label.
label = columns[-1]
# Treat every cell as a string (numeric values and NaN included, NaN becomes
# "nan"); one vectorised call instead of a per-column loop.
data = data.astype(str)
data_descriptive = data[descriptive_features].to_numpy()
data_label = data[label].to_numpy()

# Build the tree: rows of feature values, feature names, labels, criterion.
decisionTree = DecisionTree(
    data_descriptive.tolist(),
    descriptive_features.tolist(),
    data_label.tolist(),
    CRITERION,
)
decisionTree.id3(GAIN_THRESHOLD, MINIMUM_SAMPLES)
# Visualise the decision tree with Graphviz.
decisionTree.print_visualTree()
# print() already converts its arguments to text and inserts a single space
# between them; the former format(x) wrapper was a no-op and the trailing
# space in the label produced a double gap.
print("System entropy:", decisionTree.entropy)
print("System gini:", decisionTree.gini)