Skip to content

Commit f6a2d44

Browse files
authored
Merge pull request #13 from stdiff/dev
Dev
2 parents 41b744c + 18d0fd7 commit f6a2d44

8 files changed

Lines changed: 304 additions & 344 deletions

File tree

.travis.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ install:
1010
- sudo apt-get install graphviz
1111
- pip install -r requirements.txt
1212
script:
13-
- python -m unittest test/test_processing.py
14-
- python -m unittest test/test_utilities.py
15-
- python -m unittest test/test_modeling.py
13+
- coverage run -m unittest test/test_processing.py
14+
- coverage run -m unittest test/test_utilities.py
15+
- coverage run -m unittest test/test_modeling.py
16+
- bash <(curl -s https://codecov.io/bash)
1617
- python setup.py sdist bdist_wheel
1718
- pip install dist/adhoc-0.2-py3-none-any.whl
1819
- cd notebooks
1920
- python usage-processing.py
2021
- python usage-modeling.py
22+
23+

adhoc/modeling.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
## for show_tree
1616
from io import StringIO
1717
from sklearn.tree import export_graphviz
18+
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
1819
from IPython.display import Image
1920
import pydot
2021

@@ -178,17 +179,18 @@ def show_coefficients(grid:GridSearchCV, columns:List[str]) -> pd.DataFrame:
178179
return df_coef
179180

180181

181-
def show_tree(grid:GridSearchCV, columns:List[str]) -> Image:
182+
def show_tree(model:Union[DecisionTreeRegressor, DecisionTreeClassifier, GridSearchCV],
183+
columns:Union[List[str],pd.Index]) -> Image:
182184
"""
183185
Visualize the given trained DecisionTree model on Jupyter.
184186
This function requires also pydot.
185187
186-
:param grid: GridSearchCV instance with a Tree model
188+
:param model: DecisionTree model or GridSearchCV instance with a Tree model
187189
:param columns: names of columns
188190
:return: Image instance (for Jupyter)
189191
"""
190-
191-
model = pick_the_last_estimator(grid)
192+
if isinstance(model, GridSearchCV):
193+
model = pick_the_last_estimator(model)
192194

193195
dot_data = StringIO()
194196
export_graphviz(model, out_file=dot_data, feature_names=columns,
@@ -197,17 +199,20 @@ def show_tree(grid:GridSearchCV, columns:List[str]) -> Image:
197199
return Image(graph.create_png())
198200

199201

200-
def show_feature_importance(grid:GridSearchCV, columns:List[str]) -> pd.Series:
202+
def show_feature_importance(estimator:BaseEstimator,
203+
columns:Union[List[str],pd.Index]) -> pd.Series:
201204
"""
202205
Return the series of feature importance of given random forest model.
203206
XGB model and DecisionTree model can be accepted as well.
204207
205-
:param grid: fitted GridSearchCV instance with `feature_importances_` attribute
206-
:param columns: list of column names
208+
:param estimator: fitted Estimator with `feature_importances_` attribute
209+
:param columns: list (or pandas.Index) of column names
207210
:return: Series of feature importance
208211
"""
209-
## TODO: change the name of the method. (Do not use show_)
210-
model = pick_the_last_estimator(grid)
212+
if isinstance(estimator, GridSearchCV):
213+
model = pick_the_last_estimator(estimator)
214+
else:
215+
model = estimator
211216

212217
if hasattr(model,"feature_importances_"):
213218
s = pd.Series(model.feature_importances_, index=columns, name="importance")
@@ -216,8 +221,6 @@ def show_feature_importance(grid:GridSearchCV, columns:List[str]) -> pd.Series:
216221
raise AttributeError("Your model does not have an attribute 'feature_importances_'.")
217222

218223

219-
220-
221224
def recover_label(data:pd.DataFrame, field2columns:Dict[str,List[str]],
222225
sep:str=None, other:str="other", inplace=True) -> pd.DataFrame:
223226
"""

adhoc/utilities.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
import tempfile
1010
import shutil
11+
import re
1112

1213
import numpy as np
1314
import pandas as pd
@@ -156,10 +157,18 @@ def fetch_adult_dataset(csv_path:Path):
156157
raise Exception("You seem to have downloaded a wrong file")
157158

158159

160+
def grep_data(data:pd.DataFrame, column:str, expr:str) -> pd.DataFrame:
161+
"""
162+
Pick the rows with a specified expression and return them as
163+
the subset of the given DataFrame.
159164
160-
161-
162-
pass
165+
:param data: panda's DataFrame
166+
:param column: column to check
167+
:param expr: expression to find (passed to re.search)
168+
:return: copy of the matched rows
169+
"""
170+
s_matched = data[column].apply(lambda s: True if re.search(expr,str(s)) else False)
171+
return data[s_matched].copy()
163172

164173

165174
def facet_grid_scatter_plot(data:pd.DataFrame, row:str, col:str,

0 commit comments

Comments
 (0)