diff --git a/.travis.yml b/.travis.yml index 8447992..6eb9590 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ python: # command to install dependencies install: + - pip install pandas - pip install -e .[test] # - pip install coveralls # command to run tests diff --git a/README.MD b/README.MD index 1187fff..6bcf093 100644 --- a/README.MD +++ b/README.MD @@ -46,4 +46,20 @@ in Python containing the following methods: __init__ : O(n) get: O(1) + O(k) set: O(k) + O(1) - _hash: O(m) \ No newline at end of file + _hash: O(m) + + + +K Nearest Neighbors classifier: + +The KNearestNeighbors class initializes with the required argument +data, which contains the dataset to base calculations on, and the optional +argument k. K defaults to five and determines the number of neighbors the +algorithm uses to classify new data. + +The predict method is the only one intended for the end user. It accepts the +data to be tested as an argument, with the optional keyword argument of +test_k_val. If test_k_val is not assigned, it defaults to the k value +determined at instantiation. If data classification results in a tie, the +k value is reduced by one and reevaluated. Predict returns a value corresponding +to one of the groups in the instantiation data. 
\ No newline at end of file diff --git a/setup.py b/setup.py index 24d582d..f96dcf6 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ author_email="", license="MIT", package_dir={'': 'src'}, - py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst"], + py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "wheel", "numpy" , "pandas"], install_requires=[], extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]}, entry_points={} diff --git a/src/decision_tree.py b/src/decision_tree.py new file mode 100644 index 0000000..410d235 --- /dev/null +++ b/src/decision_tree.py @@ -0,0 +1,132 @@ +import pandas as pd + + +class TreeNode(object): + """Define a Node object for use in a decision tree classifier.""" + + def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None): + """Initialize a node object for a decision tree classifier.""" + self.left = left + self.right = right + self.data = data + self.split_value = split_value + self.split_gini = split_gini + self.split_col = split_col + self.label = label + + def _has_children(self): + """Return True or False if Node has children.""" + if self.right or self.left: + return True + return False + + +class DecisionTree(object): + """Define a Decision Tree class object.""" + + def __init__(self, max_depth, min_leaf_size): + """Initialize a Decision Tree object.""" + self.max_depth = max_depth + self.min_leaf_size = min_leaf_size + self.root = None + + def fit(self, data): + """Create a tree to fit the data.""" + split_col, split_value, split_gini, split_groups = self._get_split(data) + new_node = TreeNode(data, split_value, split_gini, split_col) + self.root = new_node + self.root.left = self._build_tree(split_groups[0]) + self.root.right = self._build_tree(split_groups[1]) + + def _build_tree(self, data, depth_count=1): + """Given a node, build 
the tree.""" + split_col, split_value, split_gini, split_groups = self._get_split(data) + new_node = TreeNode(data, split_value, split_gini, split_col) + try: + new_node.label = data[data.columns[-1]].mode()[0] + except: + pass + if self._can_split(split_gini, depth_count, len(data)): + print("splitting") + new_node.left = self._build_tree(split_groups[0], depth_count + 1) + new_node.right = self._build_tree(split_groups[1], depth_count + 1) + else: + print("terminating") + return new_node + return new_node + + def _can_split(self, gini, depth_count, data_size): + """Given a gini value, determine whether or not tree can split.""" + if gini == 0.0: + print("Bad gini") + return False + elif depth_count >= self.max_depth: + print("bad depth") + return False + elif data_size <= self.min_leaf_size: + print("bad data size") + return False + else: + return True + + def _depth(self, start=''): + """Return the integer depth of the BST.""" + def depth_wrapped(start): + if start is None: + return 0 + else: + right_depth = depth_wrapped(start.right) + left_depth = depth_wrapped(start.left) + return max(right_depth, left_depth) + 1 + if start is '': + return depth_wrapped(self.root) + else: + return depth_wrapped(start) + + def _calculate_gini(self, groups, class_values): + """Calculate gini for a given data_set.""" + gini = 0.0 + for class_value in class_values: + for group in groups: + size = len(group) + if size == 0: + continue + import pdb;pdb.set_trace() + proportion = len(group[group[group.columns[-1]] == class_value]) / float(size) + gini += (proportion * (1.0 - proportion)) + return gini + + def _get_split(self, data): + """Choose a split point with lowest gini index.""" + classes = data[data.columns[-1]].unique() + split_col, split_value, split_gini, split_groups =\ + float('inf'), float('inf'), float('inf'), None + for col in data.columns.values[:-1]: + for row in data.iterrows(): + groups = self._test_split(col, row[1][col], data) + gini = 
self._calculate_gini(groups, classes) + if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: + split_col, split_value, split_gini, split_groups =\ + col, row[1][col], gini, groups + return split_col, split_value, split_gini, split_groups + + def _test_split(self, col, value, data): + """Given a dataset, column index, and value, split the dataset.""" + left, right = pd.DataFrame(columns=data.columns), pd.DataFrame(columns=data.columns) + for row in data.iterrows(): + if row[1][col] < value: + left = left.append(row[1]) + else: + right = right.append(row[1]) + return left, right + + def predict(self, data): + """Given data, return labels for that data.""" + curr_node = self.root + while curr_node._has_children(): + if data[curr_node.label]: + curr_node = curr_node.right + else: + curr_node = curr_node.left + return curr_node.label + diff --git a/src/flowers_data.csv b/src/flowers_data.csv new file mode 100644 index 0000000..63fed67 --- /dev/null +++ b/src/flowers_data.csv @@ -0,0 +1,101 @@ +petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names +1.4,0.2,5.1,3.5,0,setosa +1.4,0.2,4.9,3.0,0,setosa +1.3,0.2,4.7,3.2,0,setosa +1.5,0.2,4.6,3.1,0,setosa +1.4,0.2,5.0,3.6,0,setosa +1.7,0.4,5.4,3.9,0,setosa +1.4,0.3,4.6,3.4,0,setosa +1.5,0.2,5.0,3.4,0,setosa +1.4,0.2,4.4,2.9,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.5,0.2,5.4,3.7,0,setosa +1.6,0.2,4.8,3.4,0,setosa +1.4,0.1,4.8,3.0,0,setosa +1.1,0.1,4.3,3.0,0,setosa +1.2,0.2,5.8,4.0,0,setosa +1.5,0.4,5.7,4.4,0,setosa +1.3,0.4,5.4,3.9,0,setosa +1.4,0.3,5.1,3.5,0,setosa +1.7,0.3,5.7,3.8,0,setosa +1.5,0.3,5.1,3.8,0,setosa +1.7,0.2,5.4,3.4,0,setosa +1.5,0.4,5.1,3.7,0,setosa +1.0,0.2,4.6,3.6,0,setosa +1.7,0.5,5.1,3.3,0,setosa +1.9,0.2,4.8,3.4,0,setosa +1.6,0.2,5.0,3.0,0,setosa +1.6,0.4,5.0,3.4,0,setosa +1.5,0.2,5.2,3.5,0,setosa +1.4,0.2,5.2,3.4,0,setosa +1.6,0.2,4.7,3.2,0,setosa +1.6,0.2,4.8,3.1,0,setosa +1.5,0.4,5.4,3.4,0,setosa +1.5,0.1,5.2,4.1,0,setosa +1.4,0.2,5.5,4.2,0,setosa 
+1.5,0.1,4.9,3.1,0,setosa +1.2,0.2,5.0,3.2,0,setosa +1.3,0.2,5.5,3.5,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.3,0.2,4.4,3.0,0,setosa +1.5,0.2,5.1,3.4,0,setosa +1.3,0.3,5.0,3.5,0,setosa +1.3,0.3,4.5,2.3,0,setosa +1.3,0.2,4.4,3.2,0,setosa +1.6,0.6,5.0,3.5,0,setosa +1.9,0.4,5.1,3.8,0,setosa +1.4,0.3,4.8,3.0,0,setosa +1.6,0.2,5.1,3.8,0,setosa +1.4,0.2,4.6,3.2,0,setosa +1.5,0.2,5.3,3.7,0,setosa +1.4,0.2,5.0,3.3,0,setosa +4.7,1.4,7.0,3.2,1,versicolor +4.5,1.5,6.4,3.2,1,versicolor +4.9,1.5,6.9,3.1,1,versicolor +4.0,1.3,5.5,2.3,1,versicolor +4.6,1.5,6.5,2.8,1,versicolor +4.5,1.3,5.7,2.8,1,versicolor +4.7,1.6,6.3,3.3,1,versicolor +3.3,1.0,4.9,2.4,1,versicolor +4.6,1.3,6.6,2.9,1,versicolor +3.9,1.4,5.2,2.7,1,versicolor +3.5,1.0,5.0,2.0,1,versicolor +4.2,1.5,5.9,3.0,1,versicolor +4.0,1.0,6.0,2.2,1,versicolor +4.7,1.4,6.1,2.9,1,versicolor +3.6,1.3,5.6,2.9,1,versicolor +4.4,1.4,6.7,3.1,1,versicolor +4.5,1.5,5.6,3.0,1,versicolor +4.1,1.0,5.8,2.7,1,versicolor +4.5,1.5,6.2,2.2,1,versicolor +3.9,1.1,5.6,2.5,1,versicolor +4.8,1.8,5.9,3.2,1,versicolor +4.0,1.3,6.1,2.8,1,versicolor +4.9,1.5,6.3,2.5,1,versicolor +4.7,1.2,6.1,2.8,1,versicolor +4.3,1.3,6.4,2.9,1,versicolor +4.4,1.4,6.6,3.0,1,versicolor +4.8,1.4,6.8,2.8,1,versicolor +5.0,1.7,6.7,3.0,1,versicolor +4.5,1.5,6.0,2.9,1,versicolor +3.5,1.0,5.7,2.6,1,versicolor +3.8,1.1,5.5,2.4,1,versicolor +3.7,1.0,5.5,2.4,1,versicolor +3.9,1.2,5.8,2.7,1,versicolor +5.1,1.6,6.0,2.7,1,versicolor +4.5,1.5,5.4,3.0,1,versicolor +4.5,1.6,6.0,3.4,1,versicolor +4.7,1.5,6.7,3.1,1,versicolor +4.4,1.3,6.3,2.3,1,versicolor +4.1,1.3,5.6,3.0,1,versicolor +4.0,1.3,5.5,2.5,1,versicolor +4.4,1.2,5.5,2.6,1,versicolor +4.6,1.4,6.1,3.0,1,versicolor +4.0,1.2,5.8,2.6,1,versicolor +3.3,1.0,5.0,2.3,1,versicolor +4.2,1.3,5.6,2.7,1,versicolor +4.2,1.2,5.7,3.0,1,versicolor +4.2,1.3,5.7,2.9,1,versicolor +4.3,1.3,6.2,2.9,1,versicolor +3.0,1.1,5.1,2.5,1,versicolor +4.1,1.3,5.7,2.8,1,versicolor diff --git a/src/knn.py b/src/knn.py new file mode 100644 index 0000000..b42739b 
--- /dev/null +++ b/src/knn.py @@ -0,0 +1,59 @@ +"""Implement a K-Nearest Neighbors algorithm.""" +import pandas as pd +from math import sqrt + + +class KNearestNeighbors(object): + """Define a K-Nearest Neighbors object.""" + + def __init__(self, data, k=5): + """Initialize a k nearest neighbors object.""" + if type(k) is not int or k <= 0: + raise ValueError("Please initalize with positive integer.") + else: + self.k = k + if type(data) is not pd.DataFrame: + try: + self.data = pd.DataFrame(data) + except pd.PandasError: + pass + else: + self.data = data + + def predict(self, test_data, test_k_val=None): + """Given data, categorize the data by its k nearest neighbors.""" + if test_k_val is None: + test_k_val = self.k + if type(test_data) is not pd.DataFrame: + try: + test_data = pd.Series(test_data) + except pd.PandasError: + raise ValueError("BAD DATA YA TURKEY") + distances = [] + for row in self.data.iterrows(): + distances.append((row[1][-1], self._distance(row[1], test_data))) + distances.sort(key=lambda x: x[1]) + my_class = self._classify(distances[:test_k_val]) + if my_class: + return my_class + else: + self.predict(test_data, test_k_val - 1) + + def _classify(self, res_list): + """Classify an object given a set of data about its classes.""" + classes = {item[0] for item in res_list} + class_counts = [] + for a_class in classes: + class_counts.append((a_class, len([item for item in res_list if item[0] == a_class]))) + class_counts.sort(key=lambda x: x[1], reverse=True) + if len(class_counts) > 1 and class_counts[0][1] == class_counts[1][1]: + return + else: + return class_counts[0][0] + + def _distance(self, row1, row2): + """Calculate the distance between two rows.""" + dist = 0.0 + for i in range(len(row1) - 1): + dist += (row1[i] - row2[i]) ** 2 + return sqrt(dist) diff --git a/src/test_dataset2.csv b/src/test_dataset2.csv new file mode 100644 index 0000000..e2a182d --- /dev/null +++ b/src/test_dataset2.csv @@ -0,0 +1,11 @@ +x,y,class_names 
+2.771244718,1.784783929,0 +1.728571309,1.169761413,0 +3.678319846,2.81281357,0 +3.961043357,2.61995032,0 +2.999208922,2.209014212,0 +7.497545867,3.162953546,1 +9.00220326,3.339047188,1 +7.444542326,0.476683375,1 +10.12493903,3.234550982,1 +6.642287351,3.319983761,1 \ No newline at end of file diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py new file mode 100644 index 0000000..1c83f42 --- /dev/null +++ b/src/test_decision_tree.py @@ -0,0 +1,103 @@ +import pytest +import pandas as pd + + +DATASET2 = pd.read_csv("src/test_dataset2.csv") +DATASET2_VALUES = [ + (2.771, 'x'), + (1.728, 'x'), + (3.678, 'x'), + (3.961, 'x'), + (2.999, 'x'), + (7.497, 'x'), + (9.002, 'x'), + (7.444, 'x'), + (10.124, 'x'), + (6.642, 'x'), + (1.784, 'y'), + (1.168, 'y'), + (2.812, 'y'), + (2.619, 'y'), + (2.209, 'y'), + (3.162, 'y'), + (3.339, 'y'), + (0.476, 'y'), + (3.234, 'y'), + (3.319, 'y'), +] + +DATASET2_GINI = [ + 0.494, + 0.500, + 0.408, + 0.278, + 0.469, + 0.408, + 0.469, + 0.278, + 0.494, + 0.000, + 1.000, + 0.494, + 0.640, + 0.819, + 0.934, + 0.278, + 0.494, + 0.500, + 0.408, + 0.469, +] + +# X1 < 2.771 Gini=0.494 +# X1 < 1.729 Gini=0.500 +# X1 < 3.678 Gini=0.408 +# X1 < 3.961 Gini=0.278 +# X1 < 2.999 Gini=0.469 +# X1 < 7.498 Gini=0.408 +# X1 < 9.002 Gini=0.469 +# X1 < 7.445 Gini=0.278 +# X1 < 10.125 Gini=0.494 +# X1 < 6.642 Gini=0.000 +# X2 < 1.785 Gini=1.000 +# X2 < 1.170 Gini=0.494 +# X2 < 2.813 Gini=0.640 +# X2 < 2.620 Gini=0.819 +# X2 < 2.209 Gini=0.934 +# X2 < 3.163 Gini=0.278 +# X2 < 3.339 Gini=0.494 +# X2 < 0.477 Gini=0.500 +# X2 < 3.235 Gini=0.408 +# X2 < 3.320 Gini=0.469 + + +def test_test_split(): + """Test _test_split method with test dataset.""" + from decision_tree import DecisionTree + data = pd.DataFrame([[1.0, 2.0, '1'], [3.0, 4.0, '0'], [5.0, 6.0, '1'], [7.0, 8.0, '0']]) + left = data[data[data.columns[0]] < 3] + right = data[data[data.columns[0]] >= 3] + dtree = DecisionTree(1, 1) + assert dtree._test_split(0, 3, data)[0].equals(left) + assert 
dtree._test_split(0, 3, data)[1].equals(right) + + +def test_calculate_gini(): + """Test calculate gini with known data set.""" + from decision_tree import DecisionTree + dtree = DecisionTree(1, 1) + data = pd.DataFrame(DATASET2) + for i in range(len(DATASET2_VALUES)): + left, right = dtree._test_split(DATASET2_VALUES[i][1], DATASET2_VALUES[i][0], data) + assert round(dtree._calculate_gini([left, right], [0.0, 1.0]), 3) == DATASET2_GINI[i] + + +def test__get_split(): + """Test get optimal split point.""" + from decision_tree import DecisionTree + data_table = pd.DataFrame(DATASET2) + dtree = DecisionTree(1, 1) + split = dtree._get_split(data_table) + # import pdb; pdb.set_trace() + for i in range(len(split[2])): + assert split[2][i].to_dict() == dtree._test_split(0, 5, data_table)[i].to_dict() \ No newline at end of file diff --git a/src/test_knn.py b/src/test_knn.py new file mode 100644 index 0000000..7bd7c7e --- /dev/null +++ b/src/test_knn.py @@ -0,0 +1,91 @@ +"""Test the K-NN algorithm.""" +import pytest +import os +import pandas as pd +from math import sqrt + + +BAD_Ks = [-1, "whoops", 0] +DATA = pd.read_csv(os.path.abspath('src/flowers_data.csv')) + +SIMPLE_COLUMNS = ["x", "y", "class"] +SIMPLE_DATA = [[6, 6, 0], + [5, 5, 0], + [4, 4, 0], + [3, 3, 1], + [2, 2, 1], + [1, 1, 1], + [0, 0, 1]] + + +@pytest.fixture +def simple_knn(): + """Create a default knn with flowers data.""" + data = pd.DataFrame(SIMPLE_DATA, columns=SIMPLE_COLUMNS) + from knn import KNearestNeighbors + k = KNearestNeighbors(data) + return k + + +def test_initialize_k_nearest_bad_k(): + """Test initializing with bad k value raises error.""" + from knn import KNearestNeighbors + for test_item in BAD_Ks: + with pytest.raises(ValueError): + KNearestNeighbors(DATA, test_item) + + +def test_initialize_k_nearest_good_k(): + """Test initializing with good k value.""" + from knn import KNearestNeighbors + k = KNearestNeighbors(DATA, 2) + assert type(k) is KNearestNeighbors + + +def 
test_distance_calc(): + """Test correctness of distance calc function.""" + from knn import KNearestNeighbors + rows = [[2, 2, 1], [0, 0, 1]] + data = pd.DataFrame(data=rows, columns=['x', 'y', 'class']) + test_data = KNearestNeighbors(data) + assert test_data._distance(data.loc[0], data.loc[1]) == sqrt(8) + + +def test_distance_calc_zero(): + """Test correctness of distance calc function when distance is zero.""" + from knn import KNearestNeighbors + rows = [[2, 2, 1], [2, 2, 1]] + data = pd.DataFrame(data=rows, columns=['x', 'y', 'class']) + test_data = KNearestNeighbors(data) + assert test_data._distance(data.loc[0], data.loc[1]) == 0 + + +def test_classify(simple_knn): + """Test _classify method returns expected value.""" + knn = simple_knn + data = pd.DataFrame(data=SIMPLE_DATA, columns=SIMPLE_COLUMNS) + test_data = [0.5, 0.5] + distances = [] + for row in data.iterrows(): + distances.append((row[1][-1], knn._distance(row[1], test_data))) + distances.sort(key=lambda x: x[1]) + assert knn._classify(distances[:5]) == 1 + + +def test_simple_prediction(simple_knn): + """Test a simple prediction.""" + knn = simple_knn + test_data = [0.5, 0.5] + prediction = knn.predict(test_data) + assert prediction == 1 + + +def test_flowers_integration(): + """Test knn predictions using flowers data.""" + from knn import KNearestNeighbors + total_data = DATA.drop('target', axis=1) + calibration_number = len(total_data) // 5 + new_data = total_data.sample(n=calibration_number) + k = KNearestNeighbors(new_data) + for row in new_data.iterrows(): + assert k.predict(row[1][:-1]) == row[1][-1]