From bad0611a00f78734089f9c5f66558d8a2b62a424 Mon Sep 17 00:00:00 2001 From: ClementBourt Date: Tue, 12 May 2026 12:30:43 +0200 Subject: [PATCH 1/3] Refactor outliers() to log-space chain rule to avoid overflow TabPFNUnsupervisedModel.outliers() previously combined per-feature densities via exp(sum(log(p))) using a 1/pdf trick to keep intermediates small. On wider datasets this still overflows when densities are >> 1 (flagged with an existing TODO in the source). - outliers_single_permutation_() now returns log p directly via -criterion.forward(...), dropping the 1/pdf hack and the explicit log(pred) / clamp steps. - outliers() combines permutations via the log-sum-exp identity: log(mean(p_k)) = logsumexp(log p_k) - log(K). - Score semantics: lower scores indicate more likely outliers; callable signatures unchanged. - experiments.py demo adapted to new score semantics: column p -> scores, percentile comparisons flipped (< instead of >), and quantile computation handles -inf rows. --- .../unsupervised/experiments.py | 95 ++++++++++--------- .../unsupervised/unsupervised.py | 77 +++++---------- 2 files changed, 75 insertions(+), 97 deletions(-) diff --git a/src/tabpfn_extensions/unsupervised/experiments.py b/src/tabpfn_extensions/unsupervised/experiments.py index 7ec8e648..b3168809 100644 --- a/src/tabpfn_extensions/unsupervised/experiments.py +++ b/src/tabpfn_extensions/unsupervised/experiments.py @@ -187,42 +187,49 @@ class OutlierDetectionUnsupervisedExperiment(Experiment): def plot(self): # Create a grid of jointplots using PairGrid g = sns.PairGrid(self.data, vars=self.feature_names) - g.map_upper(sns.scatterplot, s=5, alpha=0.5, hue=self.data["p"]) - g.map_lower(sns.scatterplot, s=5, alpha=0.5, hue=self.data["p_rank"]) + g.map_upper(sns.scatterplot, s=5, alpha=0.5, hue=self.data["scores"]) + g.map_lower(sns.scatterplot, s=5, alpha=0.5, hue=self.data["score_rank"]) g.add_legend() def plot_two(self, **kwargs): - outlier_thresh_p = kwargs.get("outlier_thresh_p", 0.98) - outlier_thresh = np.quantile( - self.data["p"][self.data["p"] > 0], - outlier_thresh_p, - ) - - outlier_thresh_p_1 = kwargs.get("outlier_thresh_p_1", 0.9) - outlier_thresh_1 = np.quantile( - self.data["p"][self.data["p"] > 0], - outlier_thresh_p_1, - ) + outlier_thresh_p = kwargs.get("outlier_thresh_p", 0.02) + outlier_thresh_p_1 = kwargs.get("outlier_thresh_p_1", 0.1) + + # np.quantile returns NaN if any rank position falls on -inf (since + # interpolation across -inf yields -inf - -inf = NaN). Clamp -inf to + # the finite minimum just for the quantile computation; the original + # scores series is preserved for bucketing, where x < thresh keeps + # -inf rows correctly classified as Low. + scores_series = self.data["scores"] + finite_mask = np.isfinite(scores_series) + if finite_mask.any() and not finite_mask.all(): + finite_floor = float(scores_series[finite_mask].min()) + scores_for_quantile = scores_series.where(finite_mask, finite_floor) + else: + scores_for_quantile = scores_series + + outlier_thresh = np.quantile(scores_for_quantile, outlier_thresh_p) + outlier_thresh_1 = np.quantile(scores_for_quantile, outlier_thresh_p_1) def outlier_f(x, thresh_0, thresh_1): if np.isnan(x): return np.nan - if x > thresh_0: - return f"Low ({round(100 * (1 - outlier_thresh_p), 2)} Percentile)" - if x > thresh_1: - return f"Medium ({round(100 * (1 - outlier_thresh_p_1), 2)} Percentile)" + if x < thresh_0: + return f"Low ({round(100 * (outlier_thresh_p), 2)} Percentile)" + if x < thresh_1: + return f"Medium ({round(100 * (outlier_thresh_p_1), 2)} Percentile)" return "High" - self.data["outlier"] = self.data["p"].map( + self.data["outlier"] = self.data["scores"].map( partial(outlier_f, thresh_0=outlier_thresh, thresh_1=outlier_thresh_1), ) # Oversample the data with outlier = True oversample_low = self.data[ self.data["outlier"].map(lambda x: "Low" in x) - ].sample(frac=1 / (1 - outlier_thresh_p), replace=True) + ].sample(frac=1 / (outlier_thresh_p), replace=True) oversample_med = self.data[ self.data["outlier"].map(lambda x: "Medium" in x) - ].sample(frac=1 / (1 - outlier_thresh_p_1), replace=True) + ].sample(frac=1 / (outlier_thresh_p_1), replace=True) data_ = pd.concat( [ self.data[self.data["outlier"].map(lambda x: "High" in x)], @@ -230,37 +237,37 @@ def outlier_f(x, thresh_0, thresh_1): oversample_med, ], ) - g = sns.JointGrid( + fig, ax = plt.subplots(figsize=(DEFAULT_HEIGHT, DEFAULT_HEIGHT)) + sns.scatterplot( data=data_, - hue="outlier", x=self.feature_names[0], y=self.feature_names[1], - height=DEFAULT_HEIGHT, + hue="outlier", + s=50, + alpha=0.5, + ax=ax, ) - g.fig.suptitle("Data Density Estimation") - g.fig.tight_layout() - g.fig.subplots_adjust(top=0.95) # Reduce plot to make room - - g.plot_joint(sns.scatterplot, s=50, alpha=0.5) - g.plot_marginals(sns.histplot, kde=True, stat="density") + ax.set_title("outlier detection") - # Remove the original legend created by plot_joint - g.ax_joint.get_legend().remove() - - # Create a new legend on the joint plot axis with no frame and no title - handles, labels = g.ax_joint.get_legend_handles_labels() - leg = g.ax_joint.legend( + ax.get_legend().remove() + handles, labels = ax.get_legend_handles_labels() + leg = ax.legend( handles=handles, labels=labels, - loc="upper right", - title="Estimated density (percentile)", + loc="upper left", + title="Estimated data log(density)", + fontsize="small", + title_fontsize="small", + borderpad=0.6, + handletextpad=0.5, ) leg.get_frame().set_facecolor("white") leg.get_frame().set_edgecolor("none") - leg.get_frame().set_alpha(1) # Make the legend background completely opaque + leg.get_frame().set_alpha(1) + fig.tight_layout() - return g + return ax def run( self, @@ -291,16 +298,16 @@ def run( tabpfn.set_categorical_features(categorical_features) tabpfn.fit(self.X) - self.p = tabpfn.outliers(self.X, n_permutations=n_permutations) + self.scores = tabpfn.outliers(self.X, n_permutations=n_permutations) - p_rank = self.p.argsort().argsort() + score_rank = self.scores.argsort().argsort() self.data = pd.DataFrame( torch.cat( - [self.p[:, np.newaxis], p_rank[:, np.newaxis], self.X], + [self.scores[:, np.newaxis], score_rank[:, np.newaxis], self.X], dim=1, ).numpy(), - columns=["p", "p_rank", *self.feature_names], + columns=["scores", "score_rank", *self.feature_names], ) if kwargs.get("should_plot", True): @@ -312,4 +319,4 @@ def run( # Skip plotting if matplotlib is not available pass - return {"outlier_scores": self.p.numpy()} + return {"outlier_scores": self.scores.numpy()} diff --git a/src/tabpfn_extensions/unsupervised/unsupervised.py b/src/tabpfn_extensions/unsupervised/unsupervised.py index d6253603..2aab756c 100644 --- a/src/tabpfn_extensions/unsupervised/unsupervised.py +++ b/src/tabpfn_extensions/unsupervised/unsupervised.py @@ -593,6 +593,7 @@ def outliers_single_permutation_( X: torch.tensor, feature_permutation: list[int] | tuple[int], ) -> torch.tensor: + """Compute the chain-rule log-density / log-probability of each row under one permutation.""" log_p = torch.zeros_like( X[:, 0], ) # Start with a log probability of 0 (log(1) = 0) @@ -631,46 +632,33 @@ def outliers_single_permutation_( ): # Check bounds again per sample # Proper tensor construction to avoid warning pred[idx] = torch.as_tensor(prob_row[y_idx]) + log_pred = torch.log(pred) else: pred = model.predict(X_predict, output_type="full") - - # Get logits tensor properly logits = pred["logits"] logits_tensor = logits.clone().detach() - y_tensor = y_predict.clone().detach().to(logits.device) - - # TODO: We use 1/pdf here because pdf() returns probability densities that - # can be >> 1, causing exp(sum(log(p))) to overflow. Using 1/p keeps values - # small and numerically stable. Ideally, refactor to work in log space - # throughout and avoid exponentiating altogether. - pred = (1.0 / pred["criterion"].pdf(logits_tensor, y_tensor)).to( + # criterion.forward returns the NLL, so -forward is log p_θ directly. + log_pred = -pred["criterion"].forward(logits_tensor, y_tensor).to( log_p.device ) - # Handle zero or negative probabilities (avoid log(0)) - pred = torch.clamp(pred, min=1e-10) - - # Convert probabilities to log probabilities - log_pred = torch.log(pred) - - # Add log probabilities instead of multiplying probabilities log_p = log_p + log_pred - return log_p, torch.exp(log_p) + return log_p def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tensor: - """Calculate outlier scores based on probability density functions for continuous features. + """Calculate outlier scores from numerical features only. This method filters out categorical features and only considers numerical features - for outlier detection using probability density functions. + for outlier detection. Args: X: Input data tensor n_permutations: Number of permutations to use for the outlier calculation Returns: - Tensor of outlier scores (lower values indicate more likely outliers) + Outlier scores (lower values indicate more likely outliers). """ X_store = copy.deepcopy(self.X_) mask = torch.ones_like(X_store).bool() @@ -680,15 +668,15 @@ def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso mask[self.categorical_features] = False X = X[mask] - pdf = self.outliers(X, n_permutations=n_permutations) + scores_pdf = self.outliers(X, n_permutations=n_permutations) self.X_ = X_store - return pdf + return scores_pdf def outliers_pmf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tensor: - """Calculate outlier scores based on probability mass functions for categorical features. + """Calculate outlier scores from categorical features only. This method filters out numerical features and only considers categorical features - for outlier detection using probability mass functions. + for outlier detection. Args: X: Input data tensor @@ -705,9 +693,9 @@ def outliers_pmf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso mask[self.categorical_features] = True X = X[mask] - pmf = self.outliers(X, n_permutations=n_permutations) + scores_pmf = self.outliers(X, n_permutations=n_permutations) self.X_ = X_store - return pmf + return scores_pmf @set_extension("unsupervised:outliers") def outliers( @@ -715,12 +703,9 @@ def outliers( X: torch.Tensor | np.ndarray | pd.DataFrame, n_permutations: int = 10, ) -> torch.Tensor: - """Calculate outlier scores for each sample in the input data. + """Calculate outlier scores as the log of the arithmetic mean (AM) of the densities across the permutations used to approximate the chain rule. - This is the preferred implementation for outlier detection, which calculates - sample probability for each sample in X by multiplying the probabilities of - each feature according to chain rule of probability. Lower probabilities - indicate samples that are more likely to be outliers. + The logsumexp trick is used to compute the log of the AM to address the risk of over- or underflow. Parameters: X: Union[torch.Tensor, np.ndarray, pd.DataFrame] @@ -731,8 +716,7 @@ def outliers( Returns: torch.Tensor: - Tensor of outlier scores (lower values indicate more likely outliers), - shape (n_samples,) + Tensor of outlier scores as log(AM(densities)), (lower values indicate more likely outliers), shape (n_samples,). Raises: RuntimeError: If the model initialization fails @@ -756,31 +740,18 @@ def outliers( # Use fewer permutations in fast mode actual_n_permutations = 1 if fast_mode else n_permutations - densities: list[torch.Tensor | np.ndarray] = [] + log_densities: list[torch.Tensor] = [] for perm in efficient_random_permutation(all_features, actual_n_permutations): - perm_density_log, perm_density = self.outliers_single_permutation_( + log_p = self.outliers_single_permutation_( X, feature_permutation=perm, ) - densities.append(perm_density) - - # Average the densities across all permutations - # Handle potential infinite values by replacing them with large finite values - densities_clean: list[torch.Tensor] = [ - torch.nan_to_num(d, nan=0.0, posinf=1e30, neginf=1e-30) - if torch.is_tensor(d) - else torch.nan_to_num( - torch.tensor(d, dtype=torch.float32), - nan=0.0, - posinf=1e30, - neginf=1e-30, - ) - for d in densities - ] + log_densities.append(log_p) - # Stack the clean tensors and compute mean - densities_tensor = torch.stack(densities_clean) - return densities_tensor.mean(dim=0) + # AM combiner via the log-sum-exp identity. + return torch.logsumexp(torch.stack(log_densities), dim=0) - np.log( + actual_n_permutations + ) @set_extension("unsupervised:synthetic") def generate_synthetic_data( From 18c9ce966ada1c7e5e549a9c84cb0e8b3e65c4d3 Mon Sep 17 00:00:00 2001 From: ClementBourt Date: Tue, 12 May 2026 12:30:43 +0200 Subject: [PATCH 2/3] Rename outlier output: scores -> log_p The values returned by outliers()/outliers_pdf()/outliers_pmf() are log-densities, not unitless scores. Rename the experiment attribute, DataFrame columns, and local variables to match what the values are. - experiments.py: Experiment.scores -> Experiment.log_p; DataFrame columns "scores"/"score_rank" -> "log_p"/"log_p_rank"; run() return dict key "outlier_scores" -> "log_p". - unsupervised.py: locals scores_pdf/scores_pmf -> log_pdf/log_pmf inside outliers_pdf()/outliers_pmf(); docstrings updated. Public method names (outliers, outliers_pdf, outliers_pmf) are unchanged. --- .../unsupervised/experiments.py | 32 +++++++++---------- .../unsupervised/unsupervised.py | 14 ++++---- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/tabpfn_extensions/unsupervised/experiments.py b/src/tabpfn_extensions/unsupervised/experiments.py index b3168809..72f676e0 100644 --- a/src/tabpfn_extensions/unsupervised/experiments.py +++ b/src/tabpfn_extensions/unsupervised/experiments.py @@ -187,8 +187,8 @@ class OutlierDetectionUnsupervisedExperiment(Experiment): def plot(self): # Create a grid of jointplots using PairGrid g = sns.PairGrid(self.data, vars=self.feature_names) - g.map_upper(sns.scatterplot, s=5, alpha=0.5, hue=self.data["scores"]) - g.map_lower(sns.scatterplot, s=5, alpha=0.5, hue=self.data["score_rank"]) + g.map_upper(sns.scatterplot, s=5, alpha=0.5, hue=self.data["log_p"]) + g.map_lower(sns.scatterplot, s=5, alpha=0.5, hue=self.data["log_p_rank"]) g.add_legend() def plot_two(self, **kwargs): @@ -198,18 +198,18 @@ def plot_two(self, **kwargs): # np.quantile returns NaN if any rank position falls on -inf (since # interpolation across -inf yields -inf - -inf = NaN). Clamp -inf to # the finite minimum just for the quantile computation; the original - # scores series is preserved for bucketing, where x < thresh keeps + # log_p series is preserved for bucketing, where x < thresh keeps # -inf rows correctly classified as Low. - scores_series = self.data["scores"] - finite_mask = np.isfinite(scores_series) + log_p_series = self.data["log_p"] + finite_mask = np.isfinite(log_p_series) if finite_mask.any() and not finite_mask.all(): - finite_floor = float(scores_series[finite_mask].min()) - scores_for_quantile = scores_series.where(finite_mask, finite_floor) + finite_floor = float(log_p_series[finite_mask].min()) + log_p_for_quantile = log_p_series.where(finite_mask, finite_floor) else: - scores_for_quantile = scores_series + log_p_for_quantile = log_p_series - outlier_thresh = np.quantile(scores_for_quantile, outlier_thresh_p) - outlier_thresh_1 = np.quantile(scores_for_quantile, outlier_thresh_p_1) + outlier_thresh = np.quantile(log_p_for_quantile, outlier_thresh_p) + outlier_thresh_1 = np.quantile(log_p_for_quantile, outlier_thresh_p_1) def outlier_f(x, thresh_0, thresh_1): if np.isnan(x): @@ -220,7 +220,7 @@ def outlier_f(x, thresh_0, thresh_1): return f"Medium ({round(100 * (outlier_thresh_p_1), 2)} Percentile)" return "High" - self.data["outlier"] = self.data["scores"].map( + self.data["outlier"] = self.data["log_p"].map( partial(outlier_f, thresh_0=outlier_thresh, thresh_1=outlier_thresh_1), ) # Oversample the data with outlier = True @@ -298,16 +298,16 @@ def run( tabpfn.set_categorical_features(categorical_features) tabpfn.fit(self.X) - self.scores = tabpfn.outliers(self.X, n_permutations=n_permutations) + self.log_p = tabpfn.outliers(self.X, n_permutations=n_permutations) - score_rank = self.scores.argsort().argsort() + log_p_rank = self.log_p.argsort().argsort() self.data = pd.DataFrame( torch.cat( - [self.scores[:, np.newaxis], score_rank[:, np.newaxis], self.X], + [self.log_p[:, np.newaxis], log_p_rank[:, np.newaxis], self.X], dim=1, ).numpy(), - columns=["scores", "score_rank", *self.feature_names], + columns=["log_p", "log_p_rank", *self.feature_names], ) if kwargs.get("should_plot", True): @@ -319,4 +319,4 @@ def run( # Skip plotting if matplotlib is not available pass - return {"outlier_scores": self.scores.numpy()} + return {"log_p": self.log_p.numpy()} diff --git a/src/tabpfn_extensions/unsupervised/unsupervised.py b/src/tabpfn_extensions/unsupervised/unsupervised.py index 2aab756c..9d409b34 100644 --- a/src/tabpfn_extensions/unsupervised/unsupervised.py +++ b/src/tabpfn_extensions/unsupervised/unsupervised.py @@ -648,7 +648,7 @@ def outliers_single_permutation_( return log_p def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tensor: - """Calculate outlier scores from numerical features only. + """Calculate the log_pdf from numerical features only. This method filters out categorical features and only considers numerical features for outlier detection. @@ -658,7 +658,7 @@ def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso n_permutations: Number of permutations to use for the outlier calculation Returns: - Outlier scores (lower values indicate more likely outliers). + log_pdf (lower values indicate more likely outliers). """ X_store = copy.deepcopy(self.X_) mask = torch.ones_like(X_store).bool() @@ -668,12 +668,12 @@ def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso mask[self.categorical_features] = False X = X[mask] - scores_pdf = self.outliers(X, n_permutations=n_permutations) + log_pdf = self.outliers(X, n_permutations=n_permutations) self.X_ = X_store - return scores_pdf + return log_pdf def outliers_pmf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tensor: - """Calculate outlier scores from categorical features only. + """Calculate log_pmf from categorical features only. This method filters out numerical features and only considers categorical features for outlier detection. @@ -693,9 +693,9 @@ def outliers_pmf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso mask[self.categorical_features] = True X = X[mask] - scores_pmf = self.outliers(X, n_permutations=n_permutations) + log_pmf = self.outliers(X, n_permutations=n_permutations) self.X_ = X_store - return scores_pmf + return log_pmf @set_extension("unsupervised:outliers") def outliers( From dbc5d327d7e32988fb926f7e14bdd598a3edeeef Mon Sep 17 00:00:00 2001 From: ClementBourt Date: Wed, 13 May 2026 16:45:56 +0200 Subject: [PATCH 3/3] style: apply ruff format to unsupervised.py --- src/tabpfn_extensions/unsupervised/unsupervised.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tabpfn_extensions/unsupervised/unsupervised.py b/src/tabpfn_extensions/unsupervised/unsupervised.py index 9d409b34..3ea8bec0 100644 --- a/src/tabpfn_extensions/unsupervised/unsupervised.py +++ b/src/tabpfn_extensions/unsupervised/unsupervised.py @@ -639,8 +639,8 @@ def outliers_single_permutation_( logits_tensor = logits.clone().detach() y_tensor = y_predict.clone().detach().to(logits.device) # criterion.forward returns the NLL, so -forward is log p_θ directly. - log_pred = -pred["criterion"].forward(logits_tensor, y_tensor).to( - log_p.device + log_pred = ( + -pred["criterion"].forward(logits_tensor, y_tensor).to(log_p.device) ) log_p = log_p + log_pred