Skip to content
Open

Knn #21

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
2f72f7d
Initial commit for decision tree.
CCallahanIV Feb 8, 2017
2ad032d
including data file.
CCallahanIV Feb 8, 2017
a944f66
Wrote out initial tree class.
CCallahanIV Feb 8, 2017
cb95f60
added some stuff.
CCallahanIV Feb 9, 2017
38e6948
gini and split functions added
pasaunders Feb 9, 2017
7b6964d
test database.
CCallahanIV Feb 9, 2017
8529b82
Wrote and tested test_split and calculate gini.
CCallahanIV Feb 10, 2017
e8e6f4c
Wrote out draft of fit, predict, and helper functions.
CCallahanIV Feb 10, 2017
772a3e6
testing _get_split
pasaunders Feb 10, 2017
ac1b615
Merge branch 'decision-tree' of https://github.com/CCallahanIV/data-s…
pasaunders Feb 10, 2017
5646a18
Pushing up tinkering changes and troubleshooting gini splits.
CCallahanIV Feb 11, 2017
36f5025
Troubleshooting weird behavior. Passing the torch.
CCallahanIV Feb 11, 2017
149a405
stopping fiddling for now.
CCallahanIV Feb 12, 2017
6fbc80a
Wrote out first draft of knn.
CCallahanIV Feb 12, 2017
132c87e
testing distance function, debugging distance fucntion.
pasaunders Feb 21, 2017
dfb41ec
edge case test of distance function - zero distance
pasaunders Feb 21, 2017
5d2d7be
Troubleshooting knn.py and wrote test_simple_prediction
CCallahanIV Feb 21, 2017
323abe6
fixed merge conlflict in test_knn
CCallahanIV Feb 21, 2017
1baad41
classify unit test
pasaunders Feb 21, 2017
7248c2e
run test predictions of the entire flowers dataset, tests pass.
pasaunders Feb 21, 2017
5a92423
cleared corpse code, more semantic vars
pasaunders Feb 22, 2017
e841260
added readme section
pasaunders Feb 22, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ python:

# command to install dependencies
install:
- pip install pandas
- pip install -e .[test]
# - pip install coveralls
# command to run tests
Expand Down
18 changes: 17 additions & 1 deletion README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,20 @@ in Python containing the following methods:
__init__ : O(n)
get: O(1) + O(k)
set: O(k) + O(1)
_hash: O(m)
_hash: O(m)



K Nearest Neighbors classifier:

The KNearestNeighbors class initializes with the required argument
data, which contains the dataset to base calculations on, and the optional
argument k. K defaults to five and determines the number of neighbors the
algorithm uses to classify new data.

The predict method is the only one intended for the end user. It accepts the
data to be tested as an argument, with the optional keyword argument of
test_k_val. If test_k_val is not assigned, it defaults to the k value
determined at instantiation. If data classification results in a tie, the
k value is reduced by one and reevaluated. Predict returns a value corresponding
to one of the groups in the instantiation data.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
author_email="",
license="MIT",
package_dir={'': 'src'},
py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst"],
py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "wheel", "numpy" , "pandas"],
install_requires=[],
extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]},
entry_points={}
Expand Down
132 changes: 132 additions & 0 deletions src/decision_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pandas as pd


class TreeNode(object):
    """Node for a decision tree classifier.

    Holds the data subset routed to this node, the chosen split
    (column, value, gini score), optional child nodes, and an optional
    majority-class label used when the node acts as a leaf.
    """

    def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None):
        """Initialize a node object for a decision tree classifier."""
        self.data = data
        self.split_value = split_value
        self.split_gini = split_gini
        self.split_col = split_col
        self.left = left
        self.right = right
        self.label = label

    def _has_children(self):
        """Return True if this node has at least one child, else False."""
        return bool(self.right or self.left)


class DecisionTree(object):
    """A CART-style decision tree classifier for pandas DataFrames.

    The last column of the training DataFrame is treated as the class
    label; all other columns are assumed to hold numeric features.
    """

    def __init__(self, max_depth, min_leaf_size):
        """Initialize a Decision Tree object.

        max_depth: maximum number of split levels in the tree.
        min_leaf_size: groups with this many rows or fewer are not split.
        """
        self.max_depth = max_depth
        self.min_leaf_size = min_leaf_size
        self.root = None

    def fit(self, data):
        """Build a tree fitting the DataFrame and store it on self.root."""
        split_col, split_value, split_gini, split_groups = self._get_split(data)
        self.root = TreeNode(data, split_value, split_gini, split_col)
        if split_groups is None:
            # Degenerate data (no split leaves both sides non-empty):
            # the root itself becomes a leaf labeled with the majority class.
            try:
                self.root.label = data[data.columns[-1]].mode()[0]
            except (IndexError, KeyError):
                pass
            return
        self.root.left = self._build_tree(split_groups[0])
        self.root.right = self._build_tree(split_groups[1])

    def _build_tree(self, data, depth_count=1):
        """Recursively build and return the subtree for one group of rows."""
        split_col, split_value, split_gini, split_groups = self._get_split(data)
        new_node = TreeNode(data, split_value, split_gini, split_col)
        try:
            # Majority class of this group; used as the prediction label
            # if this node ends up a leaf.
            new_node.label = data[data.columns[-1]].mode()[0]
        except (IndexError, KeyError):
            # Empty group: no label can be computed.
            pass
        if split_groups is not None and self._can_split(split_gini, depth_count, len(data)):
            new_node.left = self._build_tree(split_groups[0], depth_count + 1)
            new_node.right = self._build_tree(split_groups[1], depth_count + 1)
        return new_node

    def _can_split(self, gini, depth_count, data_size):
        """Return True if a node with this gini/depth/size may split further."""
        if gini == 0.0:
            return False  # group is already pure
        if depth_count >= self.max_depth:
            return False
        if data_size <= self.min_leaf_size:
            return False
        return True

    def _depth(self, start=''):
        """Return the integer depth of the tree (0 for an empty tree).

        start: node to measure from; the '' sentinel default means
        "measure from the root".
        """
        def depth_wrapped(node):
            if node is None:
                return 0
            return max(depth_wrapped(node.right), depth_wrapped(node.left)) + 1
        # Compare by value, not identity: `start is ''` relied on
        # CPython string interning.
        if isinstance(start, str) and start == '':
            return depth_wrapped(self.root)
        return depth_wrapped(start)

    def _calculate_gini(self, groups, class_values):
        """Calculate the summed gini impurity across groups for a split."""
        gini = 0.0
        for class_value in class_values:
            for group in groups:
                size = len(group)
                if size == 0:
                    continue
                proportion = len(group[group[group.columns[-1]] == class_value]) / float(size)
                gini += (proportion * (1.0 - proportion))
        return gini

    def _get_split(self, data):
        """Choose the (column, value) split point with the lowest gini index.

        Returns (split_col, split_value, split_gini, split_groups);
        split_groups is None when no candidate split leaves both sides
        non-empty.
        """
        classes = data[data.columns[-1]].unique()
        split_col, split_value, split_gini, split_groups =\
            float('inf'), float('inf'), float('inf'), None
        for col in data.columns.values[:-1]:
            for row in data.iterrows():
                groups = self._test_split(col, row[1][col], data)
                gini = self._calculate_gini(groups, classes)
                # Require both sides non-empty so the split makes progress.
                if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0:
                    split_col, split_value, split_gini, split_groups =\
                        col, row[1][col], gini, groups
        return split_col, split_value, split_gini, split_groups

    def _test_split(self, col, value, data):
        """Split data on a column: rows with row[col] < value go left.

        Uses a boolean mask instead of row-by-row DataFrame.append
        (removed in pandas 2.0); preserves row order and dtypes.
        """
        mask = data[col] < value
        return data[mask], data[~mask]

    def predict(self, data):
        """Return the predicted class label for a single row of features."""
        curr_node = self.root
        while curr_node._has_children():
            # Mirror _test_split's convention: values below the node's
            # split_value descend left, otherwise right.
            if data[curr_node.split_col] < curr_node.split_value:
                curr_node = curr_node.left
            else:
                curr_node = curr_node.right
        return curr_node.label

101 changes: 101 additions & 0 deletions src/flowers_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names
1.4,0.2,5.1,3.5,0,setosa
1.4,0.2,4.9,3.0,0,setosa
1.3,0.2,4.7,3.2,0,setosa
1.5,0.2,4.6,3.1,0,setosa
1.4,0.2,5.0,3.6,0,setosa
1.7,0.4,5.4,3.9,0,setosa
1.4,0.3,4.6,3.4,0,setosa
1.5,0.2,5.0,3.4,0,setosa
1.4,0.2,4.4,2.9,0,setosa
1.5,0.1,4.9,3.1,0,setosa
1.5,0.2,5.4,3.7,0,setosa
1.6,0.2,4.8,3.4,0,setosa
1.4,0.1,4.8,3.0,0,setosa
1.1,0.1,4.3,3.0,0,setosa
1.2,0.2,5.8,4.0,0,setosa
1.5,0.4,5.7,4.4,0,setosa
1.3,0.4,5.4,3.9,0,setosa
1.4,0.3,5.1,3.5,0,setosa
1.7,0.3,5.7,3.8,0,setosa
1.5,0.3,5.1,3.8,0,setosa
1.7,0.2,5.4,3.4,0,setosa
1.5,0.4,5.1,3.7,0,setosa
1.0,0.2,4.6,3.6,0,setosa
1.7,0.5,5.1,3.3,0,setosa
1.9,0.2,4.8,3.4,0,setosa
1.6,0.2,5.0,3.0,0,setosa
1.6,0.4,5.0,3.4,0,setosa
1.5,0.2,5.2,3.5,0,setosa
1.4,0.2,5.2,3.4,0,setosa
1.6,0.2,4.7,3.2,0,setosa
1.6,0.2,4.8,3.1,0,setosa
1.5,0.4,5.4,3.4,0,setosa
1.5,0.1,5.2,4.1,0,setosa
1.4,0.2,5.5,4.2,0,setosa
1.5,0.1,4.9,3.1,0,setosa
1.2,0.2,5.0,3.2,0,setosa
1.3,0.2,5.5,3.5,0,setosa
1.5,0.1,4.9,3.1,0,setosa
1.3,0.2,4.4,3.0,0,setosa
1.5,0.2,5.1,3.4,0,setosa
1.3,0.3,5.0,3.5,0,setosa
1.3,0.3,4.5,2.3,0,setosa
1.3,0.2,4.4,3.2,0,setosa
1.6,0.6,5.0,3.5,0,setosa
1.9,0.4,5.1,3.8,0,setosa
1.4,0.3,4.8,3.0,0,setosa
1.6,0.2,5.1,3.8,0,setosa
1.4,0.2,4.6,3.2,0,setosa
1.5,0.2,5.3,3.7,0,setosa
1.4,0.2,5.0,3.3,0,setosa
4.7,1.4,7.0,3.2,1,versicolor
4.5,1.5,6.4,3.2,1,versicolor
4.9,1.5,6.9,3.1,1,versicolor
4.0,1.3,5.5,2.3,1,versicolor
4.6,1.5,6.5,2.8,1,versicolor
4.5,1.3,5.7,2.8,1,versicolor
4.7,1.6,6.3,3.3,1,versicolor
3.3,1.0,4.9,2.4,1,versicolor
4.6,1.3,6.6,2.9,1,versicolor
3.9,1.4,5.2,2.7,1,versicolor
3.5,1.0,5.0,2.0,1,versicolor
4.2,1.5,5.9,3.0,1,versicolor
4.0,1.0,6.0,2.2,1,versicolor
4.7,1.4,6.1,2.9,1,versicolor
3.6,1.3,5.6,2.9,1,versicolor
4.4,1.4,6.7,3.1,1,versicolor
4.5,1.5,5.6,3.0,1,versicolor
4.1,1.0,5.8,2.7,1,versicolor
4.5,1.5,6.2,2.2,1,versicolor
3.9,1.1,5.6,2.5,1,versicolor
4.8,1.8,5.9,3.2,1,versicolor
4.0,1.3,6.1,2.8,1,versicolor
4.9,1.5,6.3,2.5,1,versicolor
4.7,1.2,6.1,2.8,1,versicolor
4.3,1.3,6.4,2.9,1,versicolor
4.4,1.4,6.6,3.0,1,versicolor
4.8,1.4,6.8,2.8,1,versicolor
5.0,1.7,6.7,3.0,1,versicolor
4.5,1.5,6.0,2.9,1,versicolor
3.5,1.0,5.7,2.6,1,versicolor
3.8,1.1,5.5,2.4,1,versicolor
3.7,1.0,5.5,2.4,1,versicolor
3.9,1.2,5.8,2.7,1,versicolor
5.1,1.6,6.0,2.7,1,versicolor
4.5,1.5,5.4,3.0,1,versicolor
4.5,1.6,6.0,3.4,1,versicolor
4.7,1.5,6.7,3.1,1,versicolor
4.4,1.3,6.3,2.3,1,versicolor
4.1,1.3,5.6,3.0,1,versicolor
4.0,1.3,5.5,2.5,1,versicolor
4.4,1.2,5.5,2.6,1,versicolor
4.6,1.4,6.1,3.0,1,versicolor
4.0,1.2,5.8,2.6,1,versicolor
3.3,1.0,5.0,2.3,1,versicolor
4.2,1.3,5.6,2.7,1,versicolor
4.2,1.2,5.7,3.0,1,versicolor
4.2,1.3,5.7,2.9,1,versicolor
4.3,1.3,6.2,2.9,1,versicolor
3.0,1.1,5.1,2.5,1,versicolor
4.1,1.3,5.7,2.8,1,versicolor
59 changes: 59 additions & 0 deletions src/knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Implement a K-Nearest Neighbors aglorithm."""
import pandas as pd
from math import sqrt


class KNearestNeighbors(object):
    """Define a K-Nearest Neighbors classifier.

    The last column of the training data holds the class labels; all
    other columns are numeric features. Query rows carry only the
    feature columns.
    """

    def __init__(self, data, k=5):
        """Initialize a k nearest neighbors object.

        data: a DataFrame (or anything pd.DataFrame accepts) whose last
            column holds class labels.
        k: positive integer count of neighbors consulted per prediction.

        Raises ValueError for a non-positive/non-int k or unconvertible data.
        """
        if type(k) is not int or k <= 0:
            raise ValueError("Please initialize with a positive integer.")
        self.k = k
        if type(data) is not pd.DataFrame:
            try:
                self.data = pd.DataFrame(data)
            except (ValueError, TypeError) as err:
                # pd.PandasError does not exist (it raised AttributeError),
                # and silently passing left self.data unset; fail loudly.
                raise ValueError("data must be convertible to a DataFrame") from err
        else:
            self.data = data

    def predict(self, test_data, test_k_val=None):
        """Given data, categorize the data by its k nearest neighbors.

        test_k_val defaults to self.k. On a tied vote, the prediction is
        retried with one fewer neighbor until a single class wins.
        Returns one of the class labels from the training data.
        """
        if test_k_val is None:
            test_k_val = self.k
        if type(test_data) is not pd.DataFrame:
            try:
                test_data = pd.Series(test_data)
            except (ValueError, TypeError):
                raise ValueError("BAD DATA YA TURKEY")
        distances = []
        for row in self.data.iterrows():
            # Pair each training row's class label (last column) with its
            # distance to the query row.
            distances.append((row[1].iloc[-1], self._distance(row[1], test_data)))
        distances.sort(key=lambda x: x[1])
        my_class = self._classify(distances[:test_k_val])
        # Compare against None explicitly: a falsy class label such as 0
        # must not be mistaken for a tie.
        if my_class is not None:
            return my_class
        # Tie: retry with one fewer neighbor. The original recursion
        # discarded its result (missing `return`), so ties returned None.
        return self.predict(test_data, test_k_val - 1)

    def _classify(self, res_list):
        """Classify an object given (label, distance) pairs.

        Returns the majority label, or None when the top two labels tie.
        """
        classes = {item[0] for item in res_list}
        class_counts = []
        for a_class in classes:
            class_counts.append((a_class, len([item for item in res_list if item[0] == a_class])))
        class_counts.sort(key=lambda x: x[1], reverse=True)
        if len(class_counts) > 1 and class_counts[0][1] == class_counts[1][1]:
            return None
        return class_counts[0][0]

    def _distance(self, row1, row2):
        """Calculate the Euclidean distance between two rows.

        Iterates over all but row1's last entry because training rows
        carry a trailing classification column that query rows lack.
        Uses .iloc: plain integer indexing on a label-indexed Series
        relied on the deprecated positional fallback.
        """
        dist = 0.0
        for i in range(len(row1) - 1):
            dist += (row1.iloc[i] - row2.iloc[i]) ** 2
        return sqrt(dist)
11 changes: 11 additions & 0 deletions src/test_dataset2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
x,y,class_names
2.771244718,1.784783929,0
1.728571309,1.169761413,0
3.678319846,2.81281357,0
3.961043357,2.61995032,0
2.999208922,2.209014212,0
7.497545867,3.162953546,1
9.00220326,3.339047188,1
7.444542326,0.476683375,1
10.12493903,3.234550982,1
6.642287351,3.319983761,1
Loading