-
Notifications
You must be signed in to change notification settings - Fork 2
Knn #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pasaunders
wants to merge
22
commits into
master
Choose a base branch
from
knn
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Knn #21
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
2f72f7d
Initial commit for decision tree.
CCallahanIV 2ad032d
including data file.
CCallahanIV a944f66
Wrote out initial tree class.
CCallahanIV cb95f60
added some stuff.
CCallahanIV 38e6948
gini and split functions added
pasaunders 7b6964d
test database.
CCallahanIV 8529b82
Wrote and tested test_split and calculate gini.
CCallahanIV e8e6f4c
Wrote out draft of fit, predict, and helper functions.
CCallahanIV 772a3e6
testing _get_split
pasaunders ac1b615
Merge branch 'decision-tree' of https://github.com/CCallahanIV/data-s…
pasaunders 5646a18
Pushing up tinkering changes and troubleshooting gini splits.
CCallahanIV 36f5025
Troubleshooting weird behavior. Passing the torch.
CCallahanIV 149a405
stopping fiddling for now.
CCallahanIV 6fbc80a
Wrote out first draft of knn.
CCallahanIV 132c87e
testing distance function, debugging distance fucntion.
pasaunders dfb41ec
edge case test of distance function - zero distance
pasaunders 5d2d7be
Troubleshooting knn.py and wrote test_simple_prediction
CCallahanIV 323abe6
fixed merge conlflict in test_knn
CCallahanIV 1baad41
classify unit test
pasaunders 7248c2e
run test predictions of the entire flowers dataset, tests pass.
pasaunders 5a92423
cleared corpse code, more semantic vars
pasaunders e841260
added readme section
pasaunders File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import pandas as pd | ||
|
|
||
|
|
||
| class TreeNode(object): | ||
| """Define a Node object for use in a decision tree classifier.""" | ||
|
|
||
| def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None): | ||
| """Initialize a node object for a decision tree classifier.""" | ||
| self.left = left | ||
| self.right = right | ||
| self.data = data | ||
| self.split_value = split_value | ||
| self.split_gini = split_gini | ||
| self.split_col = split_col | ||
| self.label = label | ||
|
|
||
| def _has_children(self): | ||
| """Return True or False if Node has children.""" | ||
| if self.right or self.left: | ||
| return True | ||
| return False | ||
|
|
||
|
|
||
| class DecisionTree(object): | ||
| """Define a Decision Tree class object.""" | ||
|
|
||
| def __init__(self, max_depth, min_leaf_size): | ||
| """Initialize a Decision Tree object.""" | ||
| self.max_depth = max_depth | ||
| self.min_leaf_size = min_leaf_size | ||
| self.root = None | ||
|
|
||
| def fit(self, data): | ||
| """Create a tree to fit the data.""" | ||
| split_col, split_value, split_gini, split_groups = self._get_split(data) | ||
| new_node = TreeNode(data, split_value, split_gini, split_col) | ||
| self.root = new_node | ||
| self.root.left = self._build_tree(split_groups[0]) | ||
| self.root.right = self._build_tree(split_groups[1]) | ||
|
|
||
| def _build_tree(self, data, depth_count=1): | ||
| """Given a node, build the tree.""" | ||
| split_col, split_value, split_gini, split_groups = self._get_split(data) | ||
| new_node = TreeNode(data, split_value, split_gini, split_col) | ||
| try: | ||
| new_node.label = data[data.columns[-1]].mode()[0] | ||
| except: | ||
| pass | ||
| if self._can_split(split_gini, depth_count, len(data)): | ||
| print("splitting") | ||
| new_node.left = self._build_tree(split_groups[0], depth_count + 1) | ||
| new_node.right = self._build_tree(split_groups[1], depth_count + 1) | ||
| else: | ||
| print("terminating") | ||
| return new_node | ||
| return new_node | ||
|
|
||
| def _can_split(self, gini, depth_count, data_size): | ||
| """Given a gini value, determine whether or not tree can split.""" | ||
| if gini == 0.0: | ||
| print("Bad gini") | ||
| return False | ||
| elif depth_count >= self.max_depth: | ||
| print("bad depth") | ||
| return False | ||
| elif data_size <= self.min_leaf_size: | ||
| print("bad data size") | ||
| return False | ||
| else: | ||
| return True | ||
|
|
||
| def _depth(self, start=''): | ||
| """Return the integer depth of the BST.""" | ||
| def depth_wrapped(start): | ||
| if start is None: | ||
| return 0 | ||
| else: | ||
| right_depth = depth_wrapped(start.right) | ||
| left_depth = depth_wrapped(start.left) | ||
| return max(right_depth, left_depth) + 1 | ||
| if start is '': | ||
| return depth_wrapped(self.root) | ||
| else: | ||
| return depth_wrapped(start) | ||
|
|
||
| def _calculate_gini(self, groups, class_values): | ||
| """Calculate gini for a given data_set.""" | ||
| gini = 0.0 | ||
| for class_value in class_values: | ||
| for group in groups: | ||
| size = len(group) | ||
| if size == 0: | ||
| continue | ||
| import pdb;pdb.set_trace() | ||
| proportion = len(group[group[group.columns[-1]] == class_value]) / float(size) | ||
| gini += (proportion * (1.0 - proportion)) | ||
| return gini | ||
|
|
||
| def _get_split(self, data): | ||
| """Choose a split point with lowest gini index.""" | ||
| classes = data[data.columns[-1]].unique() | ||
| split_col, split_value, split_gini, split_groups =\ | ||
| float('inf'), float('inf'), float('inf'), None | ||
| for col in data.columns.values[:-1]: | ||
| for row in data.iterrows(): | ||
| groups = self._test_split(col, row[1][col], data) | ||
| gini = self._calculate_gini(groups, classes) | ||
| if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: | ||
| split_col, split_value, split_gini, split_groups =\ | ||
| col, row[1][col], gini, groups | ||
| return split_col, split_value, split_gini, split_groups | ||
|
|
||
| def _test_split(self, col, value, data): | ||
| """Given a dataset, column index, and value, split the dataset.""" | ||
| left, right = pd.DataFrame(columns=data.columns), pd.DataFrame(columns=data.columns) | ||
| for row in data.iterrows(): | ||
| if row[1][col] < value: | ||
| left = left.append(row[1]) | ||
| else: | ||
| right = right.append(row[1]) | ||
| return left, right | ||
|
|
||
| def predict(self, data): | ||
| """Given data, return labels for that data.""" | ||
| curr_node = self.root | ||
| while curr_node._has_children(): | ||
| if data[curr_node.label]: | ||
| curr_node = curr_node.right | ||
| else: | ||
| curr_node = curr_node.left | ||
| return curr_node.label | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names | ||
| 1.4,0.2,5.1,3.5,0,setosa | ||
| 1.4,0.2,4.9,3.0,0,setosa | ||
| 1.3,0.2,4.7,3.2,0,setosa | ||
| 1.5,0.2,4.6,3.1,0,setosa | ||
| 1.4,0.2,5.0,3.6,0,setosa | ||
| 1.7,0.4,5.4,3.9,0,setosa | ||
| 1.4,0.3,4.6,3.4,0,setosa | ||
| 1.5,0.2,5.0,3.4,0,setosa | ||
| 1.4,0.2,4.4,2.9,0,setosa | ||
| 1.5,0.1,4.9,3.1,0,setosa | ||
| 1.5,0.2,5.4,3.7,0,setosa | ||
| 1.6,0.2,4.8,3.4,0,setosa | ||
| 1.4,0.1,4.8,3.0,0,setosa | ||
| 1.1,0.1,4.3,3.0,0,setosa | ||
| 1.2,0.2,5.8,4.0,0,setosa | ||
| 1.5,0.4,5.7,4.4,0,setosa | ||
| 1.3,0.4,5.4,3.9,0,setosa | ||
| 1.4,0.3,5.1,3.5,0,setosa | ||
| 1.7,0.3,5.7,3.8,0,setosa | ||
| 1.5,0.3,5.1,3.8,0,setosa | ||
| 1.7,0.2,5.4,3.4,0,setosa | ||
| 1.5,0.4,5.1,3.7,0,setosa | ||
| 1.0,0.2,4.6,3.6,0,setosa | ||
| 1.7,0.5,5.1,3.3,0,setosa | ||
| 1.9,0.2,4.8,3.4,0,setosa | ||
| 1.6,0.2,5.0,3.0,0,setosa | ||
| 1.6,0.4,5.0,3.4,0,setosa | ||
| 1.5,0.2,5.2,3.5,0,setosa | ||
| 1.4,0.2,5.2,3.4,0,setosa | ||
| 1.6,0.2,4.7,3.2,0,setosa | ||
| 1.6,0.2,4.8,3.1,0,setosa | ||
| 1.5,0.4,5.4,3.4,0,setosa | ||
| 1.5,0.1,5.2,4.1,0,setosa | ||
| 1.4,0.2,5.5,4.2,0,setosa | ||
| 1.5,0.1,4.9,3.1,0,setosa | ||
| 1.2,0.2,5.0,3.2,0,setosa | ||
| 1.3,0.2,5.5,3.5,0,setosa | ||
| 1.5,0.1,4.9,3.1,0,setosa | ||
| 1.3,0.2,4.4,3.0,0,setosa | ||
| 1.5,0.2,5.1,3.4,0,setosa | ||
| 1.3,0.3,5.0,3.5,0,setosa | ||
| 1.3,0.3,4.5,2.3,0,setosa | ||
| 1.3,0.2,4.4,3.2,0,setosa | ||
| 1.6,0.6,5.0,3.5,0,setosa | ||
| 1.9,0.4,5.1,3.8,0,setosa | ||
| 1.4,0.3,4.8,3.0,0,setosa | ||
| 1.6,0.2,5.1,3.8,0,setosa | ||
| 1.4,0.2,4.6,3.2,0,setosa | ||
| 1.5,0.2,5.3,3.7,0,setosa | ||
| 1.4,0.2,5.0,3.3,0,setosa | ||
| 4.7,1.4,7.0,3.2,1,versicolor | ||
| 4.5,1.5,6.4,3.2,1,versicolor | ||
| 4.9,1.5,6.9,3.1,1,versicolor | ||
| 4.0,1.3,5.5,2.3,1,versicolor | ||
| 4.6,1.5,6.5,2.8,1,versicolor | ||
| 4.5,1.3,5.7,2.8,1,versicolor | ||
| 4.7,1.6,6.3,3.3,1,versicolor | ||
| 3.3,1.0,4.9,2.4,1,versicolor | ||
| 4.6,1.3,6.6,2.9,1,versicolor | ||
| 3.9,1.4,5.2,2.7,1,versicolor | ||
| 3.5,1.0,5.0,2.0,1,versicolor | ||
| 4.2,1.5,5.9,3.0,1,versicolor | ||
| 4.0,1.0,6.0,2.2,1,versicolor | ||
| 4.7,1.4,6.1,2.9,1,versicolor | ||
| 3.6,1.3,5.6,2.9,1,versicolor | ||
| 4.4,1.4,6.7,3.1,1,versicolor | ||
| 4.5,1.5,5.6,3.0,1,versicolor | ||
| 4.1,1.0,5.8,2.7,1,versicolor | ||
| 4.5,1.5,6.2,2.2,1,versicolor | ||
| 3.9,1.1,5.6,2.5,1,versicolor | ||
| 4.8,1.8,5.9,3.2,1,versicolor | ||
| 4.0,1.3,6.1,2.8,1,versicolor | ||
| 4.9,1.5,6.3,2.5,1,versicolor | ||
| 4.7,1.2,6.1,2.8,1,versicolor | ||
| 4.3,1.3,6.4,2.9,1,versicolor | ||
| 4.4,1.4,6.6,3.0,1,versicolor | ||
| 4.8,1.4,6.8,2.8,1,versicolor | ||
| 5.0,1.7,6.7,3.0,1,versicolor | ||
| 4.5,1.5,6.0,2.9,1,versicolor | ||
| 3.5,1.0,5.7,2.6,1,versicolor | ||
| 3.8,1.1,5.5,2.4,1,versicolor | ||
| 3.7,1.0,5.5,2.4,1,versicolor | ||
| 3.9,1.2,5.8,2.7,1,versicolor | ||
| 5.1,1.6,6.0,2.7,1,versicolor | ||
| 4.5,1.5,5.4,3.0,1,versicolor | ||
| 4.5,1.6,6.0,3.4,1,versicolor | ||
| 4.7,1.5,6.7,3.1,1,versicolor | ||
| 4.4,1.3,6.3,2.3,1,versicolor | ||
| 4.1,1.3,5.6,3.0,1,versicolor | ||
| 4.0,1.3,5.5,2.5,1,versicolor | ||
| 4.4,1.2,5.5,2.6,1,versicolor | ||
| 4.6,1.4,6.1,3.0,1,versicolor | ||
| 4.0,1.2,5.8,2.6,1,versicolor | ||
| 3.3,1.0,5.0,2.3,1,versicolor | ||
| 4.2,1.3,5.6,2.7,1,versicolor | ||
| 4.2,1.2,5.7,3.0,1,versicolor | ||
| 4.2,1.3,5.7,2.9,1,versicolor | ||
| 4.3,1.3,6.2,2.9,1,versicolor | ||
| 3.0,1.1,5.1,2.5,1,versicolor | ||
| 4.1,1.3,5.7,2.8,1,versicolor |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| """Implement a K-Nearest Neighbors aglorithm.""" | ||
| import pandas as pd | ||
| from math import sqrt | ||
|
|
||
|
|
||
| class KNearestNeighbors(object): | ||
| """Define a K-Nearest Neighbors object.""" | ||
|
|
||
| def __init__(self, data, k=5): | ||
| """Initialize a k nearest neighbors object.""" | ||
| if type(k) is not int or k <= 0: | ||
| raise ValueError("Please initalize with positive integer.") | ||
| else: | ||
| self.k = k | ||
| if type(data) is not pd.DataFrame: | ||
| try: | ||
| self.data = pd.DataFrame(data) | ||
| except pd.PandasError: | ||
| pass | ||
| else: | ||
| self.data = data | ||
|
|
||
| def predict(self, test_data, test_k_val=None): | ||
| """Given data, categorize the data by its k nearest neighbors.""" | ||
| if test_k_val is None: | ||
| test_k_val = self.k | ||
| if type(test_data) is not pd.DataFrame: | ||
| try: | ||
| test_data = pd.Series(test_data) | ||
| except pd.PandasError: | ||
| raise ValueError("BAD DATA YA TURKEY") | ||
| distances = [] | ||
| for row in self.data.iterrows(): | ||
| distances.append((row[1][-1], self._distance(row[1], test_data))) | ||
| distances.sort(key=lambda x: x[1]) | ||
| my_class = self._classify(distances[:test_k_val]) | ||
| if my_class: | ||
| return my_class | ||
| else: | ||
| self.predict(test_data, test_k_val - 1) | ||
|
|
||
| def _classify(self, res_list): | ||
| """Classify an object given a set of data about its classes.""" | ||
| classes = {item[0] for item in res_list} | ||
| class_counts = [] | ||
| for a_class in classes: | ||
| class_counts.append((a_class, len([item for item in res_list if item[0] == a_class]))) | ||
| class_counts.sort(key=lambda x: x[1], reverse=True) | ||
| if len(class_counts) > 1 and class_counts[0][1] == class_counts[1][1]: | ||
| return | ||
| else: | ||
| return class_counts[0][0] | ||
|
|
||
| def _distance(self, row1, row2): | ||
| """Calcute the distance between two rows.""" | ||
| dist = 0.0 | ||
| for i in range(len(row1) - 1): | ||
| dist += (row1[i] - row2[i]) ** 2 | ||
| return sqrt(dist) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| x,y,class_names | ||
| 2.771244718,1.784783929,0 | ||
| 1.728571309,1.169761413,0 | ||
| 3.678319846,2.81281357,0 | ||
| 3.961043357,2.61995032,0 | ||
| 2.999208922,2.209014212,0 | ||
| 7.497545867,3.162953546,1 | ||
| 9.00220326,3.339047188,1 | ||
| 7.444542326,0.476683375,1 | ||
| 10.12493903,3.234550982,1 | ||
| 6.642287351,3.319983761,1 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're missing out on using the power of Numpy (or pandas) here to broadcast mathematical operations. If row1 and row2 are numpy arrays, then you could just have
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Written this way to account for the difference in length between rows. Test data is submitted without a "classification" column. Present data has such columns.