diff --git a/src/tabpfn_extensions/rf_pfn/sklearn_based_decision_tree_tabpfn.py b/src/tabpfn_extensions/rf_pfn/sklearn_based_decision_tree_tabpfn.py
index 14df899b..b58334ba 100644
--- a/src/tabpfn_extensions/rf_pfn/sklearn_based_decision_tree_tabpfn.py
+++ b/src/tabpfn_extensions/rf_pfn/sklearn_based_decision_tree_tabpfn.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import logging
 import warnings
 
 # For type checking only
@@ -37,6 +38,8 @@
 )
 from tabpfn_extensions.utils import softmax
 
+logger = logging.getLogger(__name__)
+
 ###############################################################################
 #                            BASE DECISION TREE                               #
 ###############################################################################
@@ -292,6 +295,8 @@ def _fit(
         self : DecisionTreeTabPFNBase
             The fitted model.
         """
+        if self.verbose:
+            logger.info("Starting DecisionTreeTabPFN fit process...")
         # Initialize attributes (per scikit-learn conventions)
         self._leaf_nodes = []
         self._leaf_train_data = {}
@@ -309,6 +314,8 @@ def _fit(
             y,
             ensure_all_finite=False,  # scikit-learn sets self.n_features_in_ automatically
         )
+        if self.verbose:
+            logger.info(f"Input data shape: X={X.shape}, y={y.shape}")
 
         if self.task_type == "multiclass":
             self.classes_ = unique_labels(y)
@@ -337,6 +344,10 @@ def _fit(
 
         # If adaptive_tree is on, do a train/validation split
        if self.adaptive_tree:
+            if self.verbose:
+                logger.info(
+                    "Adaptive tree is enabled. Preparing train/validation split."
+                )
             stratify = y_ if (self.task_type == "multiclass") else None
 
             # Basic checks for classification to see if splitting is feasible
@@ -344,6 +355,10 @@ def _fit(
                 unique_classes, counts = np.unique(y_, return_counts=True)
                 # Disable adaptive tree in extreme cases
                 if counts.min() == 1 or len(unique_classes) < 2:
+                    if self.verbose:
+                        logger.info(
+                            "Disabling adaptive tree: minimum class count is 1 or only one class present."
+                        )
                     self.adaptive_tree = False
                 elif len(unique_classes) > int(len(y_) * self.adaptive_tree_test_size):
                     self.adaptive_tree_test_size = min(
@@ -351,6 +366,8 @@ def _fit(
                         len(unique_classes) / len(y_) * 1.5,
                     )
             if len(y_) < 10:
+                if self.verbose:
+                    logger.info("Disabling adaptive tree: fewer than 10 samples.")
                 self.adaptive_tree = False
 
             if self.adaptive_tree:
@@ -372,9 +389,18 @@ def _fit(
                     random_state=self.random_state,
                     stratify=stratify,
                 )
+                if self.verbose:
+                    logger.info(
+                        "Train/Valid split created: "
+                        f"Train size={len(y_train)}, Valid size={len(y_valid)}"
+                    )
 
                 # Safety check - if split is empty, revert
                 if len(y_train) == 0 or len(y_valid) == 0:
+                    if self.verbose:
+                        logger.info(
+                            "Disabling adaptive tree: train or validation split is empty."
+                        )
                     self.adaptive_tree = False
                     X_train, X_preproc_train, y_train, sw_train = (
                         X,
@@ -390,6 +416,10 @@ def _fit(
                 and self.adaptive_tree
                 and (len(np.unique(y_train)) != len(np.unique(y_valid)))
             ):
+                if self.verbose:
+                    logger.info(
+                        "Disabling adaptive tree: train and validation sets have different classes."
+                    )
                 self.adaptive_tree = False
         else:
             # If we were disabled, keep all data as training
@@ -402,6 +432,8 @@ def _fit(
             X_valid = X_preproc_valid = y_valid = sw_valid = None
         else:
             # Not adaptive, everything is train
+            if self.verbose:
+                logger.info("Adaptive tree is disabled. Using all data for training.")
             X_train, X_preproc_train, y_train, sw_train = (
                 X,
                 X_preprocessed,
@@ -411,9 +443,15 @@ def _fit(
             X_valid = X_preproc_valid = y_valid = sw_valid = None
 
         # Build the sklearn decision tree
+        if self.verbose:
+            logger.info("Fitting the initial scikit-learn decision tree structure...")
         self._decision_tree = self._init_decision_tree()
         self._decision_tree.fit(X_preproc_train, y_train, sample_weight=sw_train)
         self._tree = self._decision_tree  # for sklearn compatibility
+        if self.verbose:
+            logger.info(
+                f"Decision tree fitting complete. Tree has {self._tree.tree_.node_count} nodes."
+            )
 
         # Keep references for potential post-fitting (leaf-level fitting)
         self.X = X
@@ -431,6 +469,8 @@ def _fit(
 
         # We will do a leaf-fitting step on demand (lazy) in predict
         self._need_post_fit = True
+        if self.verbose:
+            logger.info("Leaf fitting is deferred until the first predict() call.")
 
         # If verbose, optionally do it right away:
         if self.verbose:
@@ -453,7 +493,7 @@ def _init_decision_tree(self) -> BaseDecisionTree:
     def _post_fit(self) -> None:
         """Hook after the decision tree is fitted. Can be used for final prints/logs."""
         if self.verbose:
-            pass
+            logger.info("Base tree structure has been fitted.")
 
     def _preprocess_data_for_tree(self, X: np.ndarray) -> np.ndarray:
         """Handle missing data prior to feeding into the decision tree.
@@ -612,15 +652,25 @@ def _predict_internal(
         """
         # If we haven't yet done the final leaf fit, do it here
         if self._need_post_fit:
+            if self.verbose:
+                logger.info("First predict call: executing deferred leaf fitting.")
             self._need_post_fit = False
             if self.adaptive_tree:
                 # Fit leaves on train data, check performance on valid data if available
+                if self.verbose:
+                    logger.info(
+                        "Fitting leaves on training data for adaptive pruning..."
+                    )
                 self.fit_leaves(self.train_X, self.train_y)
                 if (
                     hasattr(self, "valid_X")
                     and self.valid_X is not None
                     and self.valid_y is not None
                 ):
+                    if self.verbose:
+                        logger.info(
+                            "Evaluating node performance on validation set for pruning decisions."
+                        )
                     # Force a pass to evaluate node performance
                     # so we can prune or decide node updates
                     self._predict_internal(
@@ -629,6 +679,8 @@ def _predict_internal(
                         check_input=False,
                     )
                 # Now fit leaves again using the entire dataset (train + valid, effectively)
+                if self.verbose:
+                    logger.info("Fitting leaves on the full dataset.")
                 self.fit_leaves(self.X, self.y)
 
         # Assign TabPFNs categorical features if needed
@@ -638,6 +690,10 @@ def _predict_internal(
         # Find leaf membership in X
         X_leaf_nodes = self._apply_tree(X)
         n_samples, n_nodes, n_estims = X_leaf_nodes.shape
+        if self.verbose:
+            logger.info(
+                f"Starting prediction for {n_samples} samples across {n_nodes} nodes."
+            )
 
         # Track intermediate predictions
         y_prob: dict[int, dict[int, np.ndarray]] = {}
@@ -693,6 +749,13 @@ def _predict_internal(
                     X_leaf_nodes[test_sample_indices, leaf_id + 1 :, est_id].sum()
                     == 0.0
                 )
+                if self.verbose:
+                    logger.info(
+                        f"Processing Node {leaf_id}: "
+                        f"Train Samples={X_train_leaf.shape[0]}, "
+                        f"Test Samples={len(test_sample_indices)}, "
+                        f"Is Final Leaf={is_leaf}"
+                    )
 
                 # If it's not a leaf and we are not fitting internal nodes, skip
                 # (unless leaf_id==0 and we do a top-level check for adaptive_tree)
@@ -701,6 +764,10 @@ def _predict_internal(
                     and (not self.fit_nodes)
                     and not (leaf_id == 0 and self.adaptive_tree)
                 ):
+                    if self.verbose:
+                        logger.info(
+                            f" -> Skipping Node {leaf_id}: Not a final leaf and fit_nodes is False."
+                        )
                     if do_pruning:
                         self._node_prediction_type[est_id][leaf_id] = "previous"
                     continue
@@ -717,6 +784,10 @@ def _predict_internal(
                         should_skip_previously_pruned = True
 
                 if should_skip_previously_pruned:
+                    if self.verbose:
+                        logger.info(
+                            f" -> Skipping Node {leaf_id}: Node was previously pruned."
+                        )
                     continue
 
                 # Skip if classification is missing a class
@@ -725,6 +796,10 @@ def _predict_internal(
                     and len(np.unique(y_train_leaf)) < self.n_classes_
                     and self.adaptive_tree_skip_class_missing
                 ):
+                    if self.verbose:
+                        logger.info(
+                            f" -> Skipping Node {leaf_id}: Not all classes are present in training data."
+                        )
                     self._node_prediction_type[est_id][leaf_id] = "previous"
                     continue
 
@@ -741,6 +816,10 @@ def _predict_internal(
                         and not is_leaf
                     )
                 ):
+                    if self.verbose:
+                        logger.info(
+                            f" -> Skipping Node {leaf_id}: Does not meet sample size requirements for adaptive fitting."
+                        )
                     if do_pruning:
                         self._node_prediction_type[est_id][leaf_id] = "previous"
                     continue
@@ -789,10 +868,18 @@ def _predict_internal(
                             y,
                             y_prob[est_id][leaf_id],
                         )
+                        if self.verbose:
+                            logger.info(
+                                f" -> Pruning Result for Node {leaf_id}: "
+                                f"Type='{self._node_prediction_type[est_id][leaf_id]}', "
+                                f"Score={y_metric[est_id][leaf_id]:.4f}"
+                            )
                     else:
                         # If not validating and not adaptive, just use replacement
                         y_prob[est_id][leaf_id] = y_prob_replacement
 
 
+        if self.verbose:
+            logger.info("Prediction process finished.")
         # Final predictions come from the last estimators last node
         return y_prob[n_estims - 1][n_nodes - 1]
@@ -1144,12 +1231,18 @@ def _predict_leaf(
 
         # If only one class, fill probability 1.0 for that class
         if len(classes_in_leaf) == 1:
+            if self.verbose:
+                logger.info(
+                    f" -> Node {leaf_id}: Only one class present. Predicting 1.0 for class {classes_in_leaf[0]}."
+                )
             y_eval_prob[indices, classes_in_leaf[0]] = 1.0
             return y_eval_prob
 
         # Otherwise, fit TabPFN
         leaf_seed = leaf_id + self.tree_seed
         try:
+            if self.verbose:
+                logger.info(f" -> Node {leaf_id}: Fitting TabPFNClassifier.")
             self.tabpfn.random_state = leaf_seed
             self.tabpfn.fit(X_train_leaf, y_train_leaf)
 
@@ -1176,6 +1269,10 @@ def _predict_leaf(
                 "One node has constant features for TabPFN. Using class-ratio fallback.",
                 stacklevel=2,
             )
+            if self.verbose:
+                logger.warning(
+                    f" -> Node {leaf_id}: TabPFN failed due to constant features. Using class ratio fallback."
+                )
             _, counts = np.unique(y_train_leaf, return_counts=True)
             ratio = counts / counts.sum()
             for i, c in enumerate(classes_in_leaf):
@@ -1225,7 +1322,7 @@ def predict_proba(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
     def _post_fit(self) -> None:
         """Optional hook after the decision tree is fitted."""
         if self.verbose:
-            pass
+            logger.info("Classifier tree structure has been fitted.")
 
 
 ###############################################################################
@@ -1347,23 +1444,37 @@ def _predict_leaf(
 
         # If no training data or just 1 sample, fall back to 0 or single value
         if len(X_train_leaf) < 1:
+            if self.verbose:
+                logger.info(
+                    f" -> Node {leaf_id}: No training samples. Predicting 0.0."
+                )
             warnings.warn(
                 f"Leaf {leaf_id} has zero training samples. Returning 0.0 predictions.",
                 stacklevel=2,
             )
             return y_eval
         elif len(X_train_leaf) == 1:
+            if self.verbose:
+                logger.info(
+                    f" -> Node {leaf_id}: Only one training sample. Predicting its value."
+                )
             y_eval[indices] = y_train_leaf[0]
             return y_eval
 
         # If all y are identical, return that constant
         if np.all(y_train_leaf == y_train_leaf[0]):
+            if self.verbose:
+                logger.info(
+                    f" -> Node {leaf_id}: All target values are constant. Predicting {y_train_leaf[0]}."
+                )
             y_eval[indices] = y_train_leaf[0]
             return y_eval
 
         # Fit TabPFNRegressor
         leaf_seed = leaf_id + self.tree_seed
         try:
+            if self.verbose:
+                logger.info(f" -> Node {leaf_id}: Fitting TabPFNRegressor.")
            self.tabpfn.random_state = leaf_seed
            self.tabpfn.fit(X_train_leaf, y_train_leaf)
 
@@ -1383,6 +1494,10 @@ def _predict_leaf(
                 f"TabPFN fit/predict failed at leaf {leaf_id}: {e}. Using mean fallback.",
                 stacklevel=2,
             )
+            if self.verbose:
+                logger.warning(
+                    f" -> Node {leaf_id}: TabPFN failed ({e}). Using mean fallback."
+                )
             y_eval[indices] = np.mean(y_train_leaf)
             return y_eval
 
@@ -1436,4 +1551,4 @@ def predict_full(self, X: np.ndarray) -> np.ndarray:
     def _post_fit(self) -> None:
         """Optional hook after the regressor's tree is fitted."""
         if self.verbose:
-            pass
+            logger.info("Regressor tree structure has been fitted.")
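Note for reviewers: the added messages go through the module-level `logger` rather than `print`, so `verbose=True` only produces visible output once the host application configures Python's standard `logging` machinery. A minimal sketch of how a caller might surface the new output; the class name follows this file's naming, and the constructor arguments (`tabpfn`, `verbose`, `adaptive_tree`) are inferred from the attributes used above, not defined by this patch:

import logging

from tabpfn import TabPFNClassifier  # assumed dependency; matches the self.tabpfn usage above
from tabpfn_extensions.rf_pfn.sklearn_based_decision_tree_tabpfn import (
    DecisionTreeTabPFNClassifier,
)

# Emit INFO-level records (the level used by the new logger.info calls).
logging.basicConfig(level=logging.INFO)

# verbose=True enables the new log statements; adaptive_tree=True also
# exercises the train/validation split and pruning log paths.
clf = DecisionTreeTabPFNClassifier(
    tabpfn=TabPFNClassifier(),  # assumed constructor argument
    verbose=True,
    adaptive_tree=True,
)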