diff --git a/src/tabpfn/finetuning/finetuned_base.py b/src/tabpfn/finetuning/finetuned_base.py
index 648c23141..1fe2b1bb0 100644
--- a/src/tabpfn/finetuning/finetuned_base.py
+++ b/src/tabpfn/finetuning/finetuned_base.py
@@ -136,6 +136,10 @@ class FinetunedTabPFNBase(BaseEstimator, ABC):
             data batches. This is helpful in most cases because, e.g., the column
             order will stay the same across batches. If False, the preprocessing
             will use a different random seed for each batch.
+        validation_frequency: How often (in epochs) to run validation. If set to
+            an integer N, validation is run every N epochs. If None, validation is
+            disabled entirely, which also disables early stopping. Defaults to 1
+            (validate every epoch).
     """

     def __init__(  # noqa: PLR0913
@@ -163,6 +167,7 @@
         use_activation_checkpointing: bool = True,
         save_checkpoint_interval: int | None = 10,
         use_fixed_preprocessing_seed: bool = True,
+        validation_frequency: int | None = 1,
     ):
         super().__init__()
         self.device = device
@@ -188,6 +193,7 @@ def __init__(  # noqa: PLR0913
         self.save_checkpoint_interval = save_checkpoint_interval
         self.meta_batch_size = META_BATCH_SIZE
         self.use_fixed_preprocessing_seed = use_fixed_preprocessing_seed
+        self.validation_frequency = validation_frequency

         if self.use_fixed_preprocessing_seed and not (
             self.n_estimators_finetune
@@ -528,16 +534,26 @@ def _fit(  # noqa: C901,PLR0912
         use_amp = self.device.startswith("cuda") and torch.cuda.is_available()
         scaler = GradScaler() if use_amp else None  # type: ignore

-        logger.info("--- 🚀 Eval default model ---")
-        eval_result = self._evaluate_model(
-            validation_eval_config,
-            X_train,  # pyright: ignore[reportArgumentType]
-            y_train,  # pyright: ignore[reportArgumentType]
-            X_val,  # pyright: ignore[reportArgumentType]
-            y_val,  # pyright: ignore[reportArgumentType]
-        )
-        self._log_epoch_evaluation(-1, eval_result, mean_train_loss=None)
-        best_metric: float = eval_result.primary
+        if self.validation_frequency is not None:
+            logger.info("--- 🚀 Eval default model ---")
+            eval_result = self._evaluate_model(
+                validation_eval_config,
+                X_train,  # pyright: ignore[reportArgumentType]
+                y_train,  # pyright: ignore[reportArgumentType]
+                X_val,  # pyright: ignore[reportArgumentType]
+                y_val,  # pyright: ignore[reportArgumentType]
+            )
+            self._log_epoch_evaluation(-1, eval_result, mean_train_loss=None)
+            best_metric: float = eval_result.primary
+        else:
+            if self.early_stopping:
+                warnings.warn(
+                    "`early_stopping` is enabled but `validation_frequency` is None. "
+                    "Early stopping requires validation; it will be disabled.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+            best_metric = self._get_initial_best_metric()

         static_seed, rng = infer_random_state(self.random_state)
         preprocessing_random_state = (
@@ -684,61 +700,67 @@
                 epoch_loss_sum / epoch_batches if epoch_batches > 0 else None
             )

-            eval_result = self._evaluate_model(
-                validation_eval_config,
-                X_train,  # pyright: ignore[reportArgumentType]
-                y_train,  # pyright: ignore[reportArgumentType]
-                X_val,  # pyright: ignore[reportArgumentType]
-                y_val,  # pyright: ignore[reportArgumentType]
+            run_validation = (
+                self.validation_frequency is not None
+                and (epoch + 1) % self.validation_frequency == 0
             )

-            self._log_epoch_evaluation(epoch, eval_result, mean_train_loss)
+            if run_validation:
+                eval_result = self._evaluate_model(
+                    validation_eval_config,
+                    X_train,  # pyright: ignore[reportArgumentType]
+                    y_train,  # pyright: ignore[reportArgumentType]
+                    X_val,  # pyright: ignore[reportArgumentType]
+                    y_val,  # pyright: ignore[reportArgumentType]
+                )

-            primary_metric = eval_result.primary
+                self._log_epoch_evaluation(epoch, eval_result, mean_train_loss)

-            if output_dir is not None and not np.isnan(primary_metric):
-                save_interval_checkpoint = (
-                    self.save_checkpoint_interval is not None
-                    and (epoch + 1) % self.save_checkpoint_interval == 0
-                )
+                primary_metric = eval_result.primary

-                is_best = self._is_improvement(primary_metric, best_metric)
-
-                if save_interval_checkpoint or is_best:
-                    save_checkpoint(
-                        estimator=self.finetuned_estimator_,
-                        output_dir=output_dir,
-                        epoch=epoch + 1,
-                        optimizer=optimizer,
-                        metrics=self._get_checkpoint_metrics(eval_result),
-                        train_size=train_size,
-                        is_best=is_best,
-                        save_interval_checkpoint=save_interval_checkpoint,
+                if output_dir is not None and not np.isnan(primary_metric):
+                    save_interval_checkpoint = (
+                        self.save_checkpoint_interval is not None
+                        and (epoch + 1) % self.save_checkpoint_interval == 0
                     )

-            if self.early_stopping and not np.isnan(primary_metric):
-                if self._is_improvement(primary_metric, best_metric):
-                    best_metric = primary_metric
-                    patience_counter = 0
-                    best_model = copy.deepcopy(self.finetuned_estimator_)
-                else:
-                    patience_counter += 1
-                    logger.info(
-                        "⚠️ No improvement for %s epochs. Best %s: %.4f",
-                        patience_counter,
-                        self._metric_name,
-                        best_metric,
-                    )
+                    is_best = self._is_improvement(primary_metric, best_metric)
+
+                    if save_interval_checkpoint or is_best:
+                        save_checkpoint(
+                            estimator=self.finetuned_estimator_,
+                            output_dir=output_dir,
+                            epoch=epoch + 1,
+                            optimizer=optimizer,
+                            metrics=self._get_checkpoint_metrics(eval_result),
+                            train_size=train_size,
+                            is_best=is_best,
+                            save_interval_checkpoint=save_interval_checkpoint,
+                        )

-                if patience_counter >= self.early_stopping_patience:
-                    logger.info(
-                        "🛑 Early stopping triggered. Best %s: %.4f",
-                        self._metric_name,
-                        best_metric,
-                    )
-                    if best_model is not None:
-                        self.finetuned_estimator_ = best_model
-                    break
+                if self.early_stopping and not np.isnan(primary_metric):
+                    if self._is_improvement(primary_metric, best_metric):
+                        best_metric = primary_metric
+                        patience_counter = 0
+                        best_model = copy.deepcopy(self.finetuned_estimator_)
+                    else:
+                        patience_counter += 1
+                        logger.info(
+                            "⚠️ No improvement for %s epochs. Best %s: %.4f",
+                            patience_counter,
+                            self._metric_name,
+                            best_metric,
+                        )
+
+                    if patience_counter >= self.early_stopping_patience:
+                        logger.info(
+                            "🛑 Early stopping triggered. Best %s: %.4f",
+                            self._metric_name,
+                            best_metric,
+                        )
+                        if best_model is not None:
+                            self.finetuned_estimator_ = best_model
+                        break

             if self.time_limit is not None:
                 elapsed_time = time.monotonic() - start_time
diff --git a/src/tabpfn/finetuning/finetuned_classifier.py b/src/tabpfn/finetuning/finetuned_classifier.py
index e052a0066..3808bd713 100644
--- a/src/tabpfn/finetuning/finetuned_classifier.py
+++ b/src/tabpfn/finetuning/finetuned_classifier.py
@@ -115,6 +115,10 @@ class FinetunedTabPFNClassifier(FinetunedTabPFNBase, ClassifierMixin):
             data batches. This is helpful in most cases because, e.g., the column
             order will stay the same across batches. If False, the preprocessing
             will use a different random seed for each batch.
+        validation_frequency: How often (in epochs) to run validation. If set to
+            an integer N, validation is run every N epochs. If None, validation is
+            disabled entirely, which also disables early stopping. Defaults to 1
+            (validate every epoch).

     FinetunedTabPFNClassifier specific arguments:

@@ -150,6 +154,7 @@ def __init__(  # noqa: PLR0913
         use_activation_checkpointing: bool = True,
         save_checkpoint_interval: int | None = 10,
         use_fixed_preprocessing_seed: bool = True,
+        validation_frequency: int | None = 1,
         extra_classifier_kwargs: dict[str, Any] | None = None,
         eval_metric: Literal["roc_auc", "log_loss"] | None = None,
     ):
@@ -176,6 +181,7 @@ def __init__(  # noqa: PLR0913
             use_activation_checkpointing=use_activation_checkpointing,
             save_checkpoint_interval=save_checkpoint_interval,
             use_fixed_preprocessing_seed=use_fixed_preprocessing_seed,
+            validation_frequency=validation_frequency,
         )
         self.extra_classifier_kwargs = extra_classifier_kwargs
         self.eval_metric = eval_metric
diff --git a/src/tabpfn/finetuning/finetuned_regressor.py b/src/tabpfn/finetuning/finetuned_regressor.py
index 0dc07060a..7b81a1c93 100644
--- a/src/tabpfn/finetuning/finetuned_regressor.py
+++ b/src/tabpfn/finetuning/finetuned_regressor.py
@@ -282,6 +282,10 @@ class FinetunedTabPFNRegressor(FinetunedTabPFNBase, RegressorMixin):
             data batches. This is helpful in most cases because, e.g., the column
             order will stay the same across batches. If False, the preprocessing
             will use a different random seed for each batch.
+        validation_frequency: How often (in epochs) to run validation. If set to
+            an integer N, validation is run every N epochs. If None, validation is
+            disabled entirely, which also disables early stopping. Defaults to 1
+            (validate every epoch).

     FinetunedTabPFNRegressor specific arguments:

@@ -333,6 +337,7 @@ def __init__(  # noqa: PLR0913
         use_activation_checkpointing: bool = True,
         save_checkpoint_interval: int | None = 10,
         use_fixed_preprocessing_seed: bool = True,
+        validation_frequency: int | None = 1,
         extra_regressor_kwargs: dict[str, Any] | None = None,
         ce_loss_weight: float = 0.0,
         crps_loss_weight: float = 1.0,
@@ -366,6 +371,7 @@ def __init__(  # noqa: PLR0913
             use_activation_checkpointing=use_activation_checkpointing,
             save_checkpoint_interval=save_checkpoint_interval,
             use_fixed_preprocessing_seed=use_fixed_preprocessing_seed,
+            validation_frequency=validation_frequency,
         )
         self.extra_regressor_kwargs = extra_regressor_kwargs
         self.eval_metric = eval_metric