From 2f72f7d2626a0f38abd52692c8cf4b247a9f9dff Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Wed, 8 Feb 2017 14:05:11 -0800 Subject: [PATCH 01/19] Initial commit for decision tree. --- src/decision_tree.py | 0 src/test_decision_tree.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/decision_tree.py create mode 100644 src/test_decision_tree.py diff --git a/src/decision_tree.py b/src/decision_tree.py new file mode 100644 index 0000000..e69de29 diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py new file mode 100644 index 0000000..e69de29 From 2ad032de48b13b1d3a8fb442873792b0e2ebe13e Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Wed, 8 Feb 2017 14:06:31 -0800 Subject: [PATCH 02/19] including data file. --- src/flowers_data.csv | 101 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 src/flowers_data.csv diff --git a/src/flowers_data.csv b/src/flowers_data.csv new file mode 100644 index 0000000..63fed67 --- /dev/null +++ b/src/flowers_data.csv @@ -0,0 +1,101 @@ +petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names +1.4,0.2,5.1,3.5,0,setosa +1.4,0.2,4.9,3.0,0,setosa +1.3,0.2,4.7,3.2,0,setosa +1.5,0.2,4.6,3.1,0,setosa +1.4,0.2,5.0,3.6,0,setosa +1.7,0.4,5.4,3.9,0,setosa +1.4,0.3,4.6,3.4,0,setosa +1.5,0.2,5.0,3.4,0,setosa +1.4,0.2,4.4,2.9,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.5,0.2,5.4,3.7,0,setosa +1.6,0.2,4.8,3.4,0,setosa +1.4,0.1,4.8,3.0,0,setosa +1.1,0.1,4.3,3.0,0,setosa +1.2,0.2,5.8,4.0,0,setosa +1.5,0.4,5.7,4.4,0,setosa +1.3,0.4,5.4,3.9,0,setosa +1.4,0.3,5.1,3.5,0,setosa +1.7,0.3,5.7,3.8,0,setosa +1.5,0.3,5.1,3.8,0,setosa +1.7,0.2,5.4,3.4,0,setosa +1.5,0.4,5.1,3.7,0,setosa +1.0,0.2,4.6,3.6,0,setosa +1.7,0.5,5.1,3.3,0,setosa +1.9,0.2,4.8,3.4,0,setosa +1.6,0.2,5.0,3.0,0,setosa +1.6,0.4,5.0,3.4,0,setosa +1.5,0.2,5.2,3.5,0,setosa +1.4,0.2,5.2,3.4,0,setosa +1.6,0.2,4.7,3.2,0,setosa +1.6,0.2,4.8,3.1,0,setosa +1.5,0.4,5.4,3.4,0,setosa +1.5,0.1,5.2,4.1,0,setosa +1.4,0.2,5.5,4.2,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.2,0.2,5.0,3.2,0,setosa +1.3,0.2,5.5,3.5,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.3,0.2,4.4,3.0,0,setosa +1.5,0.2,5.1,3.4,0,setosa +1.3,0.3,5.0,3.5,0,setosa +1.3,0.3,4.5,2.3,0,setosa +1.3,0.2,4.4,3.2,0,setosa +1.6,0.6,5.0,3.5,0,setosa +1.9,0.4,5.1,3.8,0,setosa +1.4,0.3,4.8,3.0,0,setosa +1.6,0.2,5.1,3.8,0,setosa +1.4,0.2,4.6,3.2,0,setosa +1.5,0.2,5.3,3.7,0,setosa +1.4,0.2,5.0,3.3,0,setosa +4.7,1.4,7.0,3.2,1,versicolor +4.5,1.5,6.4,3.2,1,versicolor +4.9,1.5,6.9,3.1,1,versicolor +4.0,1.3,5.5,2.3,1,versicolor +4.6,1.5,6.5,2.8,1,versicolor +4.5,1.3,5.7,2.8,1,versicolor +4.7,1.6,6.3,3.3,1,versicolor +3.3,1.0,4.9,2.4,1,versicolor +4.6,1.3,6.6,2.9,1,versicolor +3.9,1.4,5.2,2.7,1,versicolor +3.5,1.0,5.0,2.0,1,versicolor +4.2,1.5,5.9,3.0,1,versicolor +4.0,1.0,6.0,2.2,1,versicolor +4.7,1.4,6.1,2.9,1,versicolor +3.6,1.3,5.6,2.9,1,versicolor +4.4,1.4,6.7,3.1,1,versicolor +4.5,1.5,5.6,3.0,1,versicolor +4.1,1.0,5.8,2.7,1,versicolor +4.5,1.5,6.2,2.2,1,versicolor +3.9,1.1,5.6,2.5,1,versicolor +4.8,1.8,5.9,3.2,1,versicolor +4.0,1.3,6.1,2.8,1,versicolor +4.9,1.5,6.3,2.5,1,versicolor +4.7,1.2,6.1,2.8,1,versicolor +4.3,1.3,6.4,2.9,1,versicolor +4.4,1.4,6.6,3.0,1,versicolor +4.8,1.4,6.8,2.8,1,versicolor +5.0,1.7,6.7,3.0,1,versicolor +4.5,1.5,6.0,2.9,1,versicolor +3.5,1.0,5.7,2.6,1,versicolor +3.8,1.1,5.5,2.4,1,versicolor +3.7,1.0,5.5,2.4,1,versicolor +3.9,1.2,5.8,2.7,1,versicolor +5.1,1.6,6.0,2.7,1,versicolor +4.5,1.5,5.4,3.0,1,versicolor +4.5,1.6,6.0,3.4,1,versicolor +4.7,1.5,6.7,3.1,1,versicolor +4.4,1.3,6.3,2.3,1,versicolor +4.1,1.3,5.6,3.0,1,versicolor +4.0,1.3,5.5,2.5,1,versicolor +4.4,1.2,5.5,2.6,1,versicolor +4.6,1.4,6.1,3.0,1,versicolor +4.0,1.2,5.8,2.6,1,versicolor +3.3,1.0,5.0,2.3,1,versicolor +4.2,1.3,5.6,2.7,1,versicolor +4.2,1.2,5.7,3.0,1,versicolor +4.2,1.3,5.7,2.9,1,versicolor +4.3,1.3,6.2,2.9,1,versicolor +3.0,1.1,5.1,2.5,1,versicolor +4.1,1.3,5.7,2.8,1,versicolor From a944f660a7f619d28ccf88662f1b62b19273d5d8 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Wed, 8 Feb 2017 14:13:34 -0800 Subject: [PATCH 03/19] Wrote out initial tree class. --- src/decision_tree.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/decision_tree.py b/src/decision_tree.py index e69de29..5505c7c 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -0,0 +1,17 @@ + +class DecisionTree(object): + """Define a Decision Tree class object.""" + + def __init__(self, max_depth, min_leaf_size): + """Initialize a Decision Tree object.""" + self.max_depth = max_depth + self.min_leaf_size = min_leaf_size + self.root = None + + def fit(self, data): + """Create a tree to fit the data.""" + pass + + def predict(self, data): + """Given data, return labels for that data.""" + pass From cb95f603ca1dade94be7d54417784a1b5ceffc4a Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Thu, 9 Feb 2017 11:17:07 -0800 Subject: [PATCH 04/19] added some stuff. --- src/decision_tree.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/decision_tree.py b/src/decision_tree.py index 5505c7c..392e5e9 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -1,4 +1,17 @@ +class TreeNode(object): + """Define a Node object for use in a decision tree classifier.""" + + def __init__(self, split_value, data_set, label=None, left=None, right=None, parent=None): + """Initialize a node object for a decision tree classifier.""" + self.left = left + self.right = right + self.parent = parent + self.data_set = data_set + self.split_value = split_value + self.label = label + + class DecisionTree(object): """Define a Decision Tree class object.""" @@ -7,11 +20,17 @@ def __init__(self, max_depth, min_leaf_size): self.max_depth = max_depth self.min_leaf_size = min_leaf_size self.root = None + self.class_values = [] def fit(self, data): """Create a tree to fit the data.""" pass + def _calculate_gini(self, data): + """Calculate gini for a given data_set.""" + pass + def predict(self, data): """Given data, return labels for that data.""" pass + From 38e6948004c6cdc774aa5796f3556fe01d437b7d Mon Sep 17 00:00:00 2001 From: pasaunders Date: Thu, 9 Feb 2017 13:14:06 -0800 Subject: [PATCH 05/19] gini and split functions added --- src/decision_tree.py | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index 392e5e9..fb67629 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -1,3 +1,5 @@ +import pandas as pd + class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" @@ -26,11 +28,45 @@ def fit(self, data): """Create a tree to fit the data.""" pass - def _calculate_gini(self, data): + def _calculate_gini(self, groups, class_values): """Calculate gini for a given data_set.""" - pass + gini = 0.0 + for class_value in class_values: + for group in groups: + size = len(group) + if size == 0: + continue + proportion = [row[-1] for row in group].count(class_value) / float(size) + gini += (proportion * (1.0 - proportion)) + return gini + + def _get_split(self, data): + """Choose a split point with lowest gini index.""" + classes = data["class_names"].unique() + split_col_index, split_value, split_gini, split_groups =\ + float('inf'), float('inf'), float('inf'), None + for col_index in range(len(data.columns) - 2): + for row in data: + groups = self._test_split(col_index, row[col_index], data) + gini = self._calculate_gini(groups, classes) + if gini < split_gini: + split_col_index, split_value, split_gini, split_groups =\ + col_index, row[col_index], gini, groups + return split_col_index, split_value, split_groups + + def _calculate_split(self, data): + lowest_gini = 1.0 + lowest_row = None + lowest_col = None + for row in data: + for col in data: + gini = self._calculate_gini(row, col) + if gini < lowest_gini: + lowest_gini = gini + lowest_row = row + lowest_col = col + return lowest_row, lowest_col def predict(self, data): """Given data, return labels for that data.""" pass - From 7b6964de7d790b88621fc98eafcda2d4560125bc Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Thu, 9 Feb 2017 13:17:16 -0800 Subject: [PATCH 06/19] test database. --- src/decision_tree.py | 12 +++++++++++- src/test_dataset2.csv | 11 +++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 src/test_dataset2.csv diff --git a/src/decision_tree.py b/src/decision_tree.py index 392e5e9..b1c1708 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -1,3 +1,5 @@ +import pandas as pd + class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" @@ -28,7 +30,15 @@ def fit(self, data): def _calculate_gini(self, data): """Calculate gini for a given data_set.""" - pass + gini = 0.0 + for class_name in data["class_names"].unique(): + for col in data.columns[:-2]: + total_size = len(data) + if total_size == 0: + continue + proportion = [row[-1] for row in col].count(class_name) / float(total_size) + gini += (proportion * (1.0 - proportion)) + return gini def predict(self, data): """Given data, return labels for that data.""" diff --git a/src/test_dataset2.csv b/src/test_dataset2.csv new file mode 100644 index 0000000..e2a182d --- /dev/null +++ b/src/test_dataset2.csv @@ -0,0 +1,11 @@ +x,y,class_names +2.771244718,1.784783929,0 +1.728571309,1.169761413,0 +3.678319846,2.81281357,0 +3.961043357,2.61995032,0 +2.999208922,2.209014212,0 +7.497545867,3.162953546,1 +9.00220326,3.339047188,1 +7.444542326,0.476683375,1 +10.12493903,3.234550982,1 +6.642287351,3.319983761,1 \ No newline at end of file From e8e6f4cc84187470d4b4729354eaed22d521ed04 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Thu, 9 Feb 2017 18:37:19 -0800 Subject: [PATCH 07/19] Wrote out draft of fit, predict, and helper functions. --- src/decision_tree.py | 78 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index 25a4761..f2a6795 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -4,14 +4,27 @@ class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" - def __init__(self, split_value, data_set, label=None, left=None, right=None, parent=None): + def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, parent=None): """Initialize a node object for a decision tree classifier.""" self.left = left self.right = right - self.parent = parent - self.data_set = data_set + self.data = data self.split_value = split_value - self.label = label + self.split_gini = split_gini + self.split_col = split_col + + def _has_children(self): + """Return True or False if Node has children.""" + if self.right or self.left: + return True + return False + + def _return_children(self): + """Return all children of a Node.""" + if self.left and self.right: + return [self.left, self.right] + elif self.left or self.right: + return [self.left] if self.left else [self.right] class DecisionTree(object): @@ -22,11 +35,53 @@ def __init__(self, max_depth, min_leaf_size): self.max_depth = max_depth self.min_leaf_size = min_leaf_size self.root = None - self.class_values = [] def fit(self, data): """Create a tree to fit the data.""" - pass + split_col, split_value, split_gini, split_groups = self._get_split(data) + new_node = TreeNode(data, split_value, split_gini, split_col) + if not self.root: + self.root = new_node + if self._can_split(split_gini, len(data)): + self.root.left = self.fit(split_groups[0]) + self.root.right = self.fit(split_groups[1]) + else: + return new_node + + # def _build_tree(self, data): + # """Given a node, build the tree.""" + # split_col, split_value, split_gini, split_groups = self._get_split(data) + # new_node = TreeNode(split_value, data, label=split_col) + # if self._can_split(split_gini, len(data)): + # self.root.left = self._build_tree(split_groups[0], new_node) + # self.root.right = self._build_tree(split_groups[1], new_node) + # else: + # return new_node + + def _can_split(self, gini, data_size): + """Given a gini value, determine whether or not tree can split.""" + if gini == 0.0: + return False + elif self._depth() >= self.max_depth: + return False + elif data_size <= self.min_leaf_size: + return False + else: + return True + + def _depth(self, start=''): + """Return the integer depth of the BST.""" + def depth_wrapped(start): + if start is None: + return 0 + else: + right_depth = depth_wrapped(start.right) + left_depth = depth_wrapped(start.left) + return max(right_depth, left_depth) + 1 + if start is '': + return depth_wrapped(self.root) + else: + return depth_wrapped(start) def _calculate_gini(self, groups, class_values): """Calculate gini for a given data_set.""" @@ -53,7 +108,7 @@ def _get_split(self, data): if gini < split_gini: split_col, split_value, split_gini, split_groups =\ col, row[col], gini, groups - return split_col, split_value, split_groups + return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): """Given a dataset, column index, and value, split the dataset.""" @@ -68,4 +123,11 @@ def _test_split(self, col, value, data): def predict(self, data): """Given data, return labels for that data.""" - pass + curr_node = self.root + while curr_node._has_children(): + if data[curr_node.label]: + curr_node = curr_node.right + else: + curr_node = curr_node.left + return curr_node.label + From 772a3e6792be4c351b39992a5cf87974447cd9f3 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Thu, 9 Feb 2017 18:37:21 -0800 Subject: [PATCH 08/19] testing _get_split --- setup.py | 2 +- src/test_decision_tree.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 24d582d..27ad944 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ author_email="", license="MIT", package_dir={'': 'src'}, - py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst"], + py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "pandas"], install_requires=[], extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]}, entry_points={} diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py index e65c19b..1c83f42 100644 --- a/src/test_decision_tree.py +++ b/src/test_decision_tree.py @@ -90,3 +90,14 @@ def test_calculate_gini(): for i in range(len(DATASET2_VALUES)): left, right = dtree._test_split(DATASET2_VALUES[i][1], DATASET2_VALUES[i][0], data) assert round(dtree._calculate_gini([left, right], [0.0, 1.0]), 3) == DATASET2_GINI[i] + + +def test__get_split(): + """Test get optimal split point.""" + from decision_tree import DecisionTree + data_table = pd.DataFrame(DATASET2) + dtree = DecisionTree(1, 1) + split = dtree._get_split(data_table) + # import pdb; pdb.set_trace() + for i in range(len(split[2])): + assert split[2][i].to_dict() == dtree._test_split(0, 5, data_table)[i].to_dict() \ No newline at end of file From 5646a18552bad0a1ee3390985b7a648a806d6f70 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Sat, 11 Feb 2017 11:15:36 -0800 Subject: [PATCH 09/19] Pushing up tinkering changes and troubleshooting gini splits. --- src/decision_tree.py | 56 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index f2a6795..ea67541 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -4,7 +4,7 @@ class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" - def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, parent=None): + def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None): """Initialize a node object for a decision tree classifier.""" self.left = left self.right = right @@ -12,6 +12,7 @@ def __init__(self, data, split_value, split_gini, split_col, left=None, right=No self.split_value = split_value self.split_gini = split_gini self.split_col = split_col + self.label = label def _has_children(self): """Return True or False if Node has children.""" @@ -19,13 +20,6 @@ def _has_children(self): return True return False - def _return_children(self): - """Return all children of a Node.""" - if self.left and self.right: - return [self.left, self.right] - elif self.left or self.right: - return [self.left] if self.left else [self.right] - class DecisionTree(object): """Define a Decision Tree class object.""" @@ -40,31 +34,37 @@ def fit(self, data): """Create a tree to fit the data.""" split_col, split_value, split_gini, split_groups = self._get_split(data) new_node = TreeNode(data, split_value, split_gini, split_col) - if not self.root: - self.root = new_node - if self._can_split(split_gini, len(data)): - self.root.left = self.fit(split_groups[0]) - self.root.right = self.fit(split_groups[1]) + self.root = new_node + self.root.left = self._build_tree(split_groups[0]) + self.root.right = self._build_tree(split_groups[1]) + + def _build_tree(self, data, depth_count=0): + """Given a node, build the tree.""" + split_col, split_value, split_gini, split_groups = self._get_split(data) + new_node = TreeNode(data, split_value, split_gini, split_col) + try: + new_node.label = data[data.columns[-1]].mode()[0] + except: + pass + if self._can_split(split_gini, depth_count, len(data)): + print("splitting") + new_node.left = self._build_tree(split_groups[0], depth_count + 1) + new_node.right = self._build_tree(split_groups[1], depth_count + 1) else: + print("terminating") return new_node + return new_node - # def _build_tree(self, data): - # """Given a node, build the tree.""" - # split_col, split_value, split_gini, split_groups = self._get_split(data) - # new_node = TreeNode(split_value, data, label=split_col) - # if self._can_split(split_gini, len(data)): - # self.root.left = self._build_tree(split_groups[0], new_node) - # self.root.right = self._build_tree(split_groups[1], new_node) - # else: - # return new_node - - def _can_split(self, gini, data_size): + def _can_split(self, gini, depth_count, data_size): """Given a gini value, determine whether or not tree can split.""" if gini == 0.0: + print("Bad gini") return False - elif self._depth() >= self.max_depth: + elif depth_count >= self.max_depth: + print("bad depth") return False elif data_size <= self.min_leaf_size: + print("bad data size") return False else: return True @@ -105,9 +105,9 @@ def _get_split(self, data): row = data.iloc[i] groups = self._test_split(col, row[col], data) gini = self._calculate_gini(groups, classes) - if gini < split_gini: - split_col, split_value, split_gini, split_groups =\ - col, row[col], gini, groups + if gini < split_gini: + split_col, split_value, split_gini, split_groups =\ + col, row[col], gini, groups return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): From 36f502598ab6a4c9e35d5c87c8b53ba81fead961 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Sat, 11 Feb 2017 12:37:00 -0800 Subject: [PATCH 10/19] Troubleshooting weird behavior. Passing the torch. --- src/decision_tree.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index ea67541..88f940c 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -38,7 +38,7 @@ def fit(self, data): self.root.left = self._build_tree(split_groups[0]) self.root.right = self._build_tree(split_groups[1]) - def _build_tree(self, data, depth_count=0): + def _build_tree(self, data, depth_count=1): """Given a node, build the tree.""" split_col, split_value, split_gini, split_groups = self._get_split(data) new_node = TreeNode(data, split_value, split_gini, split_col) @@ -101,24 +101,23 @@ def _get_split(self, data): split_col, split_value, split_gini, split_groups =\ float('inf'), float('inf'), float('inf'), None for col in data.columns.values[:-2]: - for i in range(len(data)): - row = data.iloc[i] - groups = self._test_split(col, row[col], data) + for row in data.iterrows(): + groups = self._test_split(col, row[1][col], data) gini = self._calculate_gini(groups, classes) - if gini < split_gini: + if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: split_col, split_value, split_gini, split_groups =\ - col, row[col], gini, groups + col, row[1][col], gini, groups + # print("Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): """Given a dataset, column index, and value, split the dataset.""" left, right = pd.DataFrame(columns=data.columns), pd.DataFrame(columns=data.columns) - for i in range(len(data)): - row = data.iloc[i] - if row[col] < value: - left = left.append(row) + for row in data.iterrows(): + if row[1][col] < value: + left = left.append(row[1]) else: - right = right.append(row) + right = right.append(row[1]) return left, right def predict(self, data): From 149a4054259c2e95ee9002c5d6a8307718cbad12 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Sat, 11 Feb 2017 16:03:41 -0800 Subject: [PATCH 11/19] stopping fiddling for now. --- src/decision_tree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index 88f940c..410d235 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -91,6 +91,7 @@ def _calculate_gini(self, groups, class_values): size = len(group) if size == 0: continue + import pdb;pdb.set_trace() proportion = len(group[group[group.columns[-1]] == class_value]) / float(size) gini += (proportion * (1.0 - proportion)) return gini @@ -100,14 +101,13 @@ def _get_split(self, data): classes = data[data.columns[-1]].unique() split_col, split_value, split_gini, split_groups =\ float('inf'), float('inf'), float('inf'), None - for col in data.columns.values[:-2]: + for col in data.columns.values[:-1]: for row in data.iterrows(): groups = self._test_split(col, row[1][col], data) gini = self._calculate_gini(groups, classes) if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: split_col, split_value, split_gini, split_groups =\ col, row[1][col], gini, groups - # print("Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): From 6fbc80a2c1927014ccc4e9c96929db533a3e24f1 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Sat, 11 Feb 2017 17:35:26 -0800 Subject: [PATCH 12/19] Wrote out first draft of knn. --- src/knn.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++ src/test_knn.py | 23 +++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 src/knn.py create mode 100644 src/test_knn.py diff --git a/src/knn.py b/src/knn.py new file mode 100644 index 0000000..299a97b --- /dev/null +++ b/src/knn.py @@ -0,0 +1,59 @@ +"""Implement a K-Nearest Neighbors aglorithm.""" +import pandas as pd +from math import sqrt + + +class KNearestNeighbors(object): + """Define a K-Nearest Neighbors object.""" + + def __init__(self, data, k=5): + """Initialize a k nearest neighbors object.""" + if type(k) is not int or k <= 0: + raise ValueError("Please initalize with positive integer.") + else: + self.k = k + if type(data) is not pd.DataFrame: + try: + self.data = pd.DataFrame(data) + except pd.PandasError: + pass + else: + self.data = data + + def predict(self, test_data, tk=None): + """Given data, categorize the data by its k nearest neighbors.""" + if tk is None: + tk = self.k + if type(test_data) is not pd.DataFrame: + try: + test_data = pd.DataFrame(test_data) + except pd.PandasError: + raise ValueError("BAD DATA YA TURKEY") + distances = [] + for row in self.data.iterrorws(): + distances.append(row[-1], self._distance(row, test_data)) + distances.sort(key=lambda x: x[1]) + my_class = self._classify(distances[:tk]) + if my_class: + return my_class + else: + self.predict(test_data, tk - 1) + + def _classify(self, res_list): + """Classify an object given a set of data about its classes.""" + classes = (item[0] for item in res_list) + class_counts = [] + for a_class in classes: + class_counts.append((a_class, len([item for item in res_list if item[0] == a_class]))) + class_counts.sort(key=lambda x: x[1], desc=True) + if class_counts[0][1] == class_counts[1][1]: + return + else: + return class_counts[0][0] + + def _distance(self, row1, row2): + """Calcute the distance between two rows.""" + dist = 0.0 + for col in row1.columns.values[:-1]: + dist += (row1[col] - row2[col]) ** 2 + return sqrt(dist) diff --git a/src/test_knn.py b/src/test_knn.py new file mode 100644 index 0000000..4f9802b --- /dev/null +++ b/src/test_knn.py @@ -0,0 +1,23 @@ +"""Test the K-NN algorithm.""" +import pytest +import os +import pandas as pd + + +BAD_Ks = [-1, "whoops", 0] +DATA = pd.read_csv(os.path.abspath('src/flowers_data.csv')) + + +def test_initialize_k_nearest_bad_k(): + """Test initializing with bad k value raises error.""" + from knn import KNearestNeighbors + for test_item in BAD_Ks: + with pytest.raises(ValueError): + KNearestNeighbors(DATA, test_item) + + +def test_initialize_k_nearest_good_k(): + """Test initializing with good k value.""" + from knn import KNearestNeighbors + k = KNearestNeighbors(DATA, 2) + assert type(k) is KNearestNeighbors From 132c87e29f1c9a6b84fdc1d152cdf417296aa6f6 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Tue, 21 Feb 2017 10:33:26 -0800 Subject: [PATCH 13/19] testing distance function, debugging distance fucntion. --- src/knn.py | 4 ++-- src/test_knn.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/knn.py b/src/knn.py index 299a97b..e88e93b 100644 --- a/src/knn.py +++ b/src/knn.py @@ -54,6 +54,6 @@ def _classify(self, res_list): def _distance(self, row1, row2): """Calcute the distance between two rows.""" dist = 0.0 - for col in row1.columns.values[:-1]: - dist += (row1[col] - row2[col]) ** 2 + for i in range(len(row1) - 1): + dist += (row1[i] - row2[i]) ** 2 return sqrt(dist) diff --git a/src/test_knn.py b/src/test_knn.py index 4f9802b..a6b3139 100644 --- a/src/test_knn.py +++ b/src/test_knn.py @@ -2,6 +2,7 @@ import pytest import os import pandas as pd +from math import sqrt BAD_Ks = [-1, "whoops", 0] @@ -21,3 +22,20 @@ def test_initialize_k_nearest_good_k(): from knn import KNearestNeighbors k = KNearestNeighbors(DATA, 2) assert type(k) is KNearestNeighbors + + +def test_distance_calc(): + """Test correctness of distance calc funciton.""" + from knn import KNearestNeighbors + rows = [[2, 2, 1], [0, 0, 1]] + data = pd.DataFrame(data=rows, columns=['x', 'y', 'class']) + test_data = KNearestNeighbors(data) + assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) + +def test_distance_calc_zero(): + """Test correctness of distance calc funciton when distance is zero.""" + from knn import KNearestNeighbors + rows = [[2, 2, 1], [0, 0, 1]] + data = pd.DataFrame(data=rows, columns=['x', 'y', 'class']) + test_data = KNearestNeighbors(data) + assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) \ No newline at end of file From dfb41ecbba8673f52f70e557e0969214c1b296cd Mon Sep 17 00:00:00 2001 From: pasaunders Date: Tue, 21 Feb 2017 10:36:50 -0800 Subject: [PATCH 14/19] edge case test of distance function - zero distance --- src/test_knn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/test_knn.py b/src/test_knn.py index a6b3139..498d71e 100644 --- a/src/test_knn.py +++ b/src/test_knn.py @@ -32,10 +32,11 @@ def test_distance_calc(): test_data = KNearestNeighbors(data) assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) + def test_distance_calc_zero(): """Test correctness of distance calc funciton when distance is zero.""" from knn import KNearestNeighbors - rows = [[2, 2, 1], [0, 0, 1]] + rows = [[2, 2, 1], [2, 2, 1]] data = pd.DataFrame(data=rows, columns=['x', 'y', 'class']) test_data = KNearestNeighbors(data) - assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) \ No newline at end of file + assert test_data._distance(data.loc[0], data.loc[1]) == 0 From 5d2d7be2771ad98ab07308356bca990f9ee7e9e5 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Tue, 21 Feb 2017 11:10:46 -0800 Subject: [PATCH 15/19] Troubleshooting knn.py and wrote test_simple_prediction --- src/knn.py | 10 +++++----- src/test_knn.py | 28 +++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/knn.py b/src/knn.py index e88e93b..0eb9ab0 100644 --- a/src/knn.py +++ b/src/knn.py @@ -26,12 +26,12 @@ def predict(self, test_data, tk=None): tk = self.k if type(test_data) is not pd.DataFrame: try: - test_data = pd.DataFrame(test_data) + test_data = pd.Series(test_data) except pd.PandasError: raise ValueError("BAD DATA YA TURKEY") distances = [] - for row in self.data.iterrorws(): - distances.append(row[-1], self._distance(row, test_data)) + for row in self.data.iterrows(): + distances.append((row[1][-1], self._distance(row[1], test_data))) distances.sort(key=lambda x: x[1]) my_class = self._classify(distances[:tk]) if my_class: @@ -41,11 +41,11 @@ def predict(self, test_data, tk=None): def _classify(self, res_list): """Classify an object given a set of data about its classes.""" - classes = (item[0] for item in res_list) + classes = {item[0] for item in res_list} class_counts = [] for a_class in classes: class_counts.append((a_class, len([item for item in res_list if item[0] == a_class]))) - class_counts.sort(key=lambda x: x[1], desc=True) + class_counts.sort(key=lambda x: x[1], reverse=True) if class_counts[0][1] == class_counts[1][1]: return else: diff --git a/src/test_knn.py b/src/test_knn.py index a6b3139..f1a9bae 100644 --- a/src/test_knn.py +++ b/src/test_knn.py @@ -8,6 +8,23 @@ BAD_Ks = [-1, "whoops", 0] DATA = pd.read_csv(os.path.abspath('src/flowers_data.csv')) +SIMPLE_COLUMNS = ["x", "y", "class"] +SIMPLE_DATA = [[6, 6, 0], + [5, 5, 0], + [4, 4, 0], + [3, 3, 1], + [2, 2, 1], + [1, 1, 1], + [0, 0, 1]] + +@pytest.fixture +def simple_knn(): + """Create a default knn with flowers data.""" + data = pd.DataFrame(SIMPLE_DATA, columns=SIMPLE_COLUMNS) + from knn import KNearestNeighbors + k = KNearestNeighbors(data) + return k + def test_initialize_k_nearest_bad_k(): """Test initializing with bad k value raises error.""" @@ -32,10 +49,19 @@ def test_distance_calc(): test_data = KNearestNeighbors(data) assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) + def test_distance_calc_zero(): """Test correctness of distance calc funciton when distance is zero.""" from knn import KNearestNeighbors rows = [[2, 2, 1], [0, 0, 1]] data = pd.DataFrame(data=rows, columns=['x', 'y', 'class']) test_data = KNearestNeighbors(data) - assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) \ No newline at end of file + assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) + + +def test_simple_prediction(simple_knn): + """Test a simple prediction.""" + knn = simple_knn + test_data = [0.5, 0.5] + prediction = knn.predict(test_data) + assert prediction == 1 From 1baad41f58d810cbacba8405bb0584141b164903 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Tue, 21 Feb 2017 11:38:04 -0800 Subject: [PATCH 16/19] classify unit test --- src/knn.py | 1 + src/test_knn.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/src/knn.py b/src/knn.py index 0eb9ab0..ac77ee8 100644 --- a/src/knn.py +++ b/src/knn.py @@ -33,6 +33,7 @@ def predict(self, test_data, tk=None): for row in self.data.iterrows(): distances.append((row[1][-1], self._distance(row[1], test_data))) distances.sort(key=lambda x: x[1]) + # import pdb; pdb.set_trace() my_class = self._classify(distances[:tk]) if my_class: return my_class diff --git a/src/test_knn.py b/src/test_knn.py index 7cbf843..f10b78e 100644 --- a/src/test_knn.py +++ b/src/test_knn.py @@ -17,6 +17,7 @@ [1, 1, 1], [0, 0, 1]] + @pytest.fixture def simple_knn(): """Create a default knn with flowers data.""" @@ -59,6 +60,18 @@ def test_distance_calc_zero(): assert test_data._distance(data.loc[0], data.loc[1]) == 0 +def test_classify(simple_knn): + """Test _classify method returns expected value.""" + knn = simple_knn + data = pd.DataFrame(data=SIMPLE_DATA, columns=SIMPLE_COLUMNS) + test_data = [0.5, 0.5] + distances = [] + for row in data.iterrows(): + distances.append((row[1][-1], knn._distance(row[1], test_data))) + distances.sort(key=lambda x: x[1]) + assert knn._classify(distances[:5]) == 1 + + def test_simple_prediction(simple_knn): """Test a simple prediction.""" knn = simple_knn From 7248c2e2d39888d7179aa42c6579de0c7b5c20e6 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Tue, 21 Feb 2017 12:45:42 -0800 Subject: [PATCH 17/19] run test predictions of the entire flowers dataset, tests pass. --- src/knn.py | 2 +- src/test_knn.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/knn.py b/src/knn.py index ac77ee8..6728e30 100644 --- a/src/knn.py +++ b/src/knn.py @@ -47,7 +47,7 @@ def _classify(self, res_list): for a_class in classes: class_counts.append((a_class, len([item for item in res_list if item[0] == a_class]))) class_counts.sort(key=lambda x: x[1], reverse=True) - if class_counts[0][1] == class_counts[1][1]: + if len(class_counts) > 1 and class_counts[0][1] == class_counts[1][1]: return else: return class_counts[0][0] diff --git a/src/test_knn.py b/src/test_knn.py index f10b78e..7bd7c7e 100644 --- a/src/test_knn.py +++ b/src/test_knn.py @@ -78,3 +78,14 @@ def test_simple_prediction(simple_knn): test_data = [0.5, 0.5] prediction = knn.predict(test_data) assert prediction == 1 + + +def test_flowers_integration(): + """Test knn predictions using flowers data.""" + from knn import KNearestNeighbors + total_data = DATA.drop('target', axis=1) + calibration_number = len(total_data) // 5 + new_data = total_data.sample(n=calibration_number) + k = KNearestNeighbors(new_data) + for row in new_data.iterrows(): + assert k.predict(row[1][:-1]) == row[1][-1] From 5a924230ddf253b35c9698d153d29f06f1d2e237 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Tue, 21 Feb 2017 19:47:21 -0800 Subject: [PATCH 18/19] cleared corpse code, more semantic vars --- .travis.yml | 1 + setup.py | 2 +- src/knn.py | 11 +++++------ 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 8447992..6eb9590 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ python: # command to install dependencies install: + - pip install pandas - pip install -e .[test] # - pip install coveralls # command to run tests diff --git a/setup.py b/setup.py index 27ad944..f96dcf6 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ author_email="", license="MIT", package_dir={'': 'src'}, - py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "pandas"], + py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "wheel", "numpy" , "pandas"], install_requires=[], extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]}, entry_points={} diff --git a/src/knn.py b/src/knn.py index 6728e30..b42739b 100644 --- a/src/knn.py +++ b/src/knn.py @@ -20,10 +20,10 @@ def __init__(self, data, k=5): else: self.data = data - def predict(self, test_data, tk=None): + def predict(self, test_data, test_k_val=None): """Given data, categorize the data by its k nearest neighbors.""" - if tk is None: - tk = self.k + if test_k_val is None: + test_k_val = self.k if type(test_data) is not pd.DataFrame: try: test_data = pd.Series(test_data) @@ -33,12 +33,11 @@ def predict(self, test_data, tk=None): for row in self.data.iterrows(): distances.append((row[1][-1], self._distance(row[1], test_data))) distances.sort(key=lambda x: x[1]) - # import pdb; pdb.set_trace() - my_class = self._classify(distances[:tk]) + my_class = self._classify(distances[:test_k_val]) if my_class: return my_class else: - self.predict(test_data, tk - 1) + self.predict(test_data, test_k_val - 1) def _classify(self, res_list): """Classify an object given a set of data about its classes.""" From e841260612f8da7750cafe0dadc9f11c3e779ac7 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Tue, 21 Feb 2017 20:21:35 -0800 Subject: [PATCH 19/19] added readme section --- README.MD | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/README.MD b/README.MD index 1187fff..6bcf093 100644 --- a/README.MD +++ b/README.MD @@ -46,4 +46,20 @@ in Python containing the following methods: __init__ : O(n) get: O(1) + O(k) set: O(k) + O(1) - _hash: O(m) \ No newline at end of file + _hash: O(m) + + + +K Nearest Neighbors classifier: + +The KNearestNeighbors class initilaizes with the required argument +data, which contains the dataset to base calculations on, and the optional +argument k. K defaults to five and determines the number of neighbors the +algorithm uses to classify new data. + +The predict method is the only one intended for the end user. It accepts the +data to be tested as an argument, with the optional keyword argument of +test_k_val. If test_k_val is not assigned, it defaults to the k value +determined at instantiation. If data classification results in a tie, the +k value is reduced by one and reevaluated. Predict returns a value corrisponding +to one of the groups in the instantiation data. \ No newline at end of file