1515## for show_tree
1616from io import StringIO
1717from sklearn .tree import export_graphviz
18+ from sklearn .tree import DecisionTreeRegressor , DecisionTreeClassifier
1819from IPython .display import Image
1920import pydot
2021
@@ -178,17 +179,18 @@ def show_coefficients(grid:GridSearchCV, columns:List[str]) -> pd.DataFrame:
178179 return df_coef
179180
180181
181- def show_tree (grid :GridSearchCV , columns :List [str ]) -> Image :
182+ def show_tree (model :Union [DecisionTreeRegressor , DecisionTreeClassifier , GridSearchCV ],
183+ columns :Union [List [str ],pd .Index ]) -> Image :
182184 """
183185 Visualize the given trained DecisionTree model on Jupyter.
184186 This function requires also pydot.
185187
186- :param grid: GridSearchCV instance with a Tree model
188+ :param model: DecisionTree model or GridSearchCV instance with a Tree model
187189 :param columns: names of columns
188190 :return: Image instance (for Jupyter)
189191 """
190-
191- model = pick_the_last_estimator (grid )
192+ if isinstance ( model , GridSearchCV ):
193+ model = pick_the_last_estimator (model )
192194
193195 dot_data = StringIO ()
194196 export_graphviz (model , out_file = dot_data , feature_names = columns ,
@@ -197,17 +199,20 @@ def show_tree(grid:GridSearchCV, columns:List[str]) -> Image:
197199 return Image (graph .create_png ())
198200
199201
200- def show_feature_importance (grid :GridSearchCV , columns :List [str ]) -> pd .Series :
202+ def show_feature_importance (estimator :BaseEstimator ,
203+ columns :Union [List [str ],pd .Index ]) -> pd .Series :
201204 """
202205 Return the series of feature importance of given random forest model.
203206 XGB model and DecisionTree model can be accepted as well.
204207
205- :param grid : fitted GridSearchCV instance with `feature_importances_` attribute
206- :param columns: list of column names
208+ :param estimator : fitted Estimator with `feature_importances_` attribute
209+ :param columns: list (or pandas.Index) of column names
207210 :return: Series of feature importance
208211 """
209- ## TODO: change the name of the method. (Do not use show_)
210- model = pick_the_last_estimator (grid )
212+ if isinstance (estimator , GridSearchCV ):
213+ model = pick_the_last_estimator (estimator )
214+ else :
215+ model = estimator
211216
212217 if hasattr (model ,"feature_importances_" ):
213218 s = pd .Series (model .feature_importances_ , index = columns , name = "importance" )
@@ -216,8 +221,6 @@ def show_feature_importance(grid:GridSearchCV, columns:List[str]) -> pd.Series:
216221 raise AttributeError ("Your model does not have an attribute 'feature_importances_'." )
217222
218223
219-
220-
221224def recover_label (data :pd .DataFrame , field2columns :Dict [str ,List [str ]],
222225 sep :str = None , other :str = "other" , inplace = True ) -> pd .DataFrame :
223226 """
0 commit comments