Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 51 additions & 44 deletions src/tabpfn_extensions/unsupervised/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,80 +187,87 @@ class OutlierDetectionUnsupervisedExperiment(Experiment):
def plot(self):
    """Plot a pairwise grid of scatterplots over all features.

    The upper triangle is colored by the estimated log-density
    (``self.data["log_p"]``); the lower triangle by its rank
    (``self.data["log_p_rank"]``), so low-density (outlier) points
    stand out in both views.

    NOTE(review): reconstructed from a pasted diff — the paste contained
    both the old (``p``/``p_rank``) and new (``log_p``/``log_p_rank``)
    hue lines; this is the merged (new) version.
    """
    # Create a grid of jointplots using PairGrid
    g = sns.PairGrid(self.data, vars=self.feature_names)
    g.map_upper(sns.scatterplot, s=5, alpha=0.5, hue=self.data["log_p"])
    g.map_lower(sns.scatterplot, s=5, alpha=0.5, hue=self.data["log_p_rank"])
    g.add_legend()

def plot_two(self, **kwargs):
    """Scatter-plot the first two features, bucketed by estimated log-density.

    Rows are labelled Low / Medium / High by comparing ``log_p`` against
    its ``outlier_thresh_p`` (default 0.02) and ``outlier_thresh_p_1``
    (default 0.1) quantiles; the Low and Medium tails are oversampled
    (by 1/quantile) so rare outliers remain visible in the plot.

    Args:
        **kwargs: ``outlier_thresh_p`` and ``outlier_thresh_p_1`` override
            the two quantile thresholds.

    Returns:
        The matplotlib ``Axes`` holding the scatterplot.

    NOTE(review): reconstructed from a pasted diff that interleaved the
    old (probability-space, JointGrid) and new (log-space, plain Axes)
    implementations; this is the merged (new) version.
    """
    outlier_thresh_p = kwargs.get("outlier_thresh_p", 0.02)
    outlier_thresh_p_1 = kwargs.get("outlier_thresh_p_1", 0.1)

    # np.quantile returns NaN if any rank position falls on -inf (since
    # interpolation across -inf yields -inf - -inf = NaN). Clamp -inf to
    # the finite minimum just for the quantile computation; the original
    # log_p series is preserved for bucketing, where x < thresh keeps
    # -inf rows correctly classified as Low.
    log_p_series = self.data["log_p"]
    finite_mask = np.isfinite(log_p_series)
    if finite_mask.any() and not finite_mask.all():
        finite_floor = float(log_p_series[finite_mask].min())
        log_p_for_quantile = log_p_series.where(finite_mask, finite_floor)
    else:
        log_p_for_quantile = log_p_series

    outlier_thresh = np.quantile(log_p_for_quantile, outlier_thresh_p)
    outlier_thresh_1 = np.quantile(log_p_for_quantile, outlier_thresh_p_1)

    def outlier_f(x, thresh_0, thresh_1):
        # Bucket a single log-density value; NaN propagates unchanged.
        if np.isnan(x):
            return np.nan
        if x < thresh_0:
            return f"Low ({round(100 * (outlier_thresh_p), 2)} Percentile)"
        if x < thresh_1:
            return f"Medium ({round(100 * (outlier_thresh_p_1), 2)} Percentile)"
        return "High"

    self.data["outlier"] = self.data["log_p"].map(
        partial(outlier_f, thresh_0=outlier_thresh, thresh_1=outlier_thresh_1),
    )
    # Oversample the data with outlier = True
    oversample_low = self.data[
        self.data["outlier"].map(lambda x: "Low" in x)
    ].sample(frac=1 / (outlier_thresh_p), replace=True)
    oversample_med = self.data[
        self.data["outlier"].map(lambda x: "Medium" in x)
    ].sample(frac=1 / (outlier_thresh_p_1), replace=True)
    data_ = pd.concat(
        [
            self.data[self.data["outlier"].map(lambda x: "High" in x)],
            oversample_low,
            oversample_med,
        ],
    )
    fig, ax = plt.subplots(figsize=(DEFAULT_HEIGHT, DEFAULT_HEIGHT))
    sns.scatterplot(
        data=data_,
        hue="outlier",
        x=self.feature_names[0],
        y=self.feature_names[1],
        s=50,
        alpha=0.5,
        ax=ax,
    )

    ax.set_title("outlier detection")

    # Replace seaborn's default legend with a compact, framed one.
    ax.get_legend().remove()
    handles, labels = ax.get_legend_handles_labels()
    leg = ax.legend(
        handles=handles,
        labels=labels,
        loc="upper left",
        title="Estimated data log(density)",
        fontsize="small",
        title_fontsize="small",
        borderpad=0.6,
        handletextpad=0.5,
    )
    leg.get_frame().set_facecolor("white")
    leg.get_frame().set_edgecolor("none")
    leg.get_frame().set_alpha(1)
    fig.tight_layout()

    return ax

def run(
self,
Expand Down Expand Up @@ -291,16 +298,16 @@ def run(
tabpfn.set_categorical_features(categorical_features)

tabpfn.fit(self.X)
self.p = tabpfn.outliers(self.X, n_permutations=n_permutations)
self.log_p = tabpfn.outliers(self.X, n_permutations=n_permutations)

p_rank = self.p.argsort().argsort()
log_p_rank = self.log_p.argsort().argsort()

self.data = pd.DataFrame(
torch.cat(
[self.p[:, np.newaxis], p_rank[:, np.newaxis], self.X],
[self.log_p[:, np.newaxis], log_p_rank[:, np.newaxis], self.X],
dim=1,
).numpy(),
columns=["p", "p_rank", *self.feature_names],
columns=["log_p", "log_p_rank", *self.feature_names],
)

if kwargs.get("should_plot", True):
Expand All @@ -312,4 +319,4 @@ def run(
# Skip plotting if matplotlib is not available
pass

return {"outlier_scores": self.p.numpy()}
return {"log_p": self.log_p.numpy()}
79 changes: 25 additions & 54 deletions src/tabpfn_extensions/unsupervised/unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ def outliers_single_permutation_(
X: torch.tensor,
feature_permutation: list[int] | tuple[int],
) -> torch.tensor:
"""Compute the chain-rule log-density / log-probability of each row under one permutation."""
log_p = torch.zeros_like(
X[:, 0],
) # Start with a log probability of 0 (log(1) = 0)
Expand Down Expand Up @@ -631,46 +632,33 @@ def outliers_single_permutation_(
): # Check bounds again per sample
# Proper tensor construction to avoid warning
pred[idx] = torch.as_tensor(prob_row[y_idx])
log_pred = torch.log(pred)
else:
pred = model.predict(X_predict, output_type="full")

# Get logits tensor properly
logits = pred["logits"]
logits_tensor = logits.clone().detach()

y_tensor = y_predict.clone().detach().to(logits.device)

# TODO: We use 1/pdf here because pdf() returns probability densities that
# can be >> 1, causing exp(sum(log(p))) to overflow. Using 1/p keeps values
# small and numerically stable. Ideally, refactor to work in log space
# throughout and avoid exponentiating altogether.
pred = (1.0 / pred["criterion"].pdf(logits_tensor, y_tensor)).to(
log_p.device
# criterion.forward returns the NLL, so -forward is log p_θ directly.
log_pred = (
-pred["criterion"].forward(logits_tensor, y_tensor).to(log_p.device)
)

# Handle zero or negative probabilities (avoid log(0))
pred = torch.clamp(pred, min=1e-10)

# Convert probabilities to log probabilities
log_pred = torch.log(pred)

# Add log probabilities instead of multiplying probabilities
log_p = log_p + log_pred

return log_p, torch.exp(log_p)
return log_p

def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tensor:
"""Calculate outlier scores based on probability density functions for continuous features.
"""Calculate the log_pdf from numerical features only.

This method filters out categorical features and only considers numerical features
for outlier detection using probability density functions.
for outlier detection.

Args:
X: Input data tensor
n_permutations: Number of permutations to use for the outlier calculation

Returns:
Tensor of outlier scores (lower values indicate more likely outliers)
log_pdf (lower values indicate more likely outliers).
"""
X_store = copy.deepcopy(self.X_)
mask = torch.ones_like(X_store).bool()
Expand All @@ -680,15 +668,15 @@ def outliers_pdf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso
mask[self.categorical_features] = False
X = X[mask]

pdf = self.outliers(X, n_permutations=n_permutations)
log_pdf = self.outliers(X, n_permutations=n_permutations)
Comment thread
ClementBourt marked this conversation as resolved.
self.X_ = X_store
return pdf
return log_pdf

def outliers_pmf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tensor:
"""Calculate outlier scores based on probability mass functions for categorical features.
"""Calculate log_pmf from categorical features only.

This method filters out numerical features and only considers categorical features
for outlier detection using probability mass functions.
for outlier detection.

Args:
X: Input data tensor
Expand All @@ -705,22 +693,19 @@ def outliers_pmf(self, X: torch.Tensor, n_permutations: int = 10) -> torch.Tenso
mask[self.categorical_features] = True
X = X[mask]

pmf = self.outliers(X, n_permutations=n_permutations)
log_pmf = self.outliers(X, n_permutations=n_permutations)
self.X_ = X_store
return pmf
return log_pmf

@set_extension("unsupervised:outliers")
def outliers(
self,
X: torch.Tensor | np.ndarray | pd.DataFrame,
n_permutations: int = 10,
) -> torch.Tensor:
"""Calculate outlier scores for each sample in the input data.
"""Calculate outlier scores as the log of the arithmetic mean (AM) of the densities across the permutations used to approximate the chain rule.

This is the preferred implementation for outlier detection, which calculates
sample probability for each sample in X by multiplying the probabilities of
each feature according to chain rule of probability. Lower probabilities
indicate samples that are more likely to be outliers.
The logsumexp trick is used to compute the log of the AM to address the risk of over- or underflow.

Parameters:
X: Union[torch.Tensor, np.ndarray, pd.DataFrame]
Expand All @@ -731,8 +716,7 @@ def outliers(

Returns:
torch.Tensor:
Tensor of outlier scores (lower values indicate more likely outliers),
shape (n_samples,)
Tensor of outlier scores as log(AM(densities)), (lower values indicate more likely outliers), shape (n_samples,).

Raises:
RuntimeError: If the model initialization fails
Expand All @@ -756,31 +740,18 @@ def outliers(
# Use fewer permutations in fast mode
actual_n_permutations = 1 if fast_mode else n_permutations

densities: list[torch.Tensor | np.ndarray] = []
log_densities: list[torch.Tensor] = []
for perm in efficient_random_permutation(all_features, actual_n_permutations):
perm_density_log, perm_density = self.outliers_single_permutation_(
log_p = self.outliers_single_permutation_(
X,
feature_permutation=perm,
)
densities.append(perm_density)

# Average the densities across all permutations
# Handle potential infinite values by replacing them with large finite values
densities_clean: list[torch.Tensor] = [
torch.nan_to_num(d, nan=0.0, posinf=1e30, neginf=1e-30)
if torch.is_tensor(d)
else torch.nan_to_num(
torch.tensor(d, dtype=torch.float32),
nan=0.0,
posinf=1e30,
neginf=1e-30,
)
for d in densities
]
log_densities.append(log_p)

# Stack the clean tensors and compute mean
densities_tensor = torch.stack(densities_clean)
return densities_tensor.mean(dim=0)
# AM combiner via the log-sum-exp identity.
return torch.logsumexp(torch.stack(log_densities), dim=0) - np.log(
actual_n_permutations
)

@set_extension("unsupervised:synthetic")
def generate_synthetic_data(
Expand Down
Loading