diff --git a/README.md b/README.md
index 5bc2be9..49576ac 100644
--- a/README.md
+++ b/README.md
@@ -1,36 +1,45 @@
-[](https://www.codefactor.io/repository/github/diyago/tabular-data-generation)
-[](https://github.com/psf/black)
-[](https://opensource.org/licenses/Apache-2.0)
-[](https://pypi.org/project/tabgan/)
-[](https://pypi.org/project/tabgan/)
-[](https://pepy.tech/project/tabgan)
-[](https://github.com/diyago/Tabular-data-generation/actions/workflows/codeql.yml)
+
+
+
+High-quality synthetic tabular data generation
-A powerful library for generating high-quality synthetic tabular data using state-of-the-art generative models including GANs, Diffusion models, and Large Language Models.
+
-
+---
## Overview
-TabGAN is a comprehensive Python library that provides a unified interface for generating high-quality synthetic tabular data. It integrates multiple state-of-the-art generative approaches to address diverse data synthesis requirements:
+TabGAN provides a unified Python interface for generating synthetic tabular data using multiple state-of-the-art generative approaches:
-- **GANs**: Conditional Tabular GAN (CTGAN) for modeling complex multivariate distributions with mixed data types
-- **Diffusion Models**: Forest Diffusion for high-fidelity synthetic data generation with tree-based gradient boosting
-- **LLMs**: GReaT (Generative Realistic Tabular data) framework leveraging language models for realistic tabular data synthesis
-- **Time-Series**: TimeGAN support for temporal data generation preserving sequential dependencies
+| Approach | Backend | Strengths |
+|----------|---------|-----------|
+| **GANs** | Conditional Tabular GAN (CTGAN) | Mixed data types, complex multivariate distributions |
+| **Diffusion Models** | ForestDiffusion (tree-based gradient boosting) | High-fidelity generation for structured data |
+| **Large Language Models** | GReaT framework | Capturing semantic dependencies, conditional text generation |
+| **Baseline** | Random sampling with replacement | Quick benchmarking and comparison |
-*Related Research: [Tabular GANs for uneven distribution (arXiv:2010.00638)](https://arxiv.org/abs/2010.00638)*
+All generators share a common pipeline: **generate → post-process → adversarial filter**, ensuring synthetic data stays close to the real data distribution.
+
+*Based on the paper: [Tabular GANs for uneven distribution](https://arxiv.org/abs/2010.00638) (arXiv:2010.00638)*
## Key Features
-- **Multiple Generative Architectures**: Seamlessly switch between GANs, Diffusion Models, and LLMs via a unified API
-- **Adversarial Filtering**: Built-in adversarial validation to ensure synthetic data preserves predictive utility
-- **Mixed Data Type Support**: Native handling of continuous, categorical, and text columns
-- **Conditional Generation**: Generate data conditioned on specific column values or distributions
-- **Scalable Processing**: Efficient batch processing for large-scale datasets
-- **Quality Validation**: Integrated metrics for comparing synthetic against original data distributions
+- **Unified API** — switch between GANs, diffusion models, and LLMs with a single parameter change
+- **Adversarial filtering** — built-in LightGBM-based validation keeps synthetic samples distribution-consistent
+- **Mixed data types** — native handling of continuous, categorical, and free-text columns
+- **Conditional generation** — generate text conditioned on categorical attributes via LLM prompting
+- **LLM API support** — integrate with LM Studio, OpenAI, Ollama, or any OpenAI-compatible endpoint
+- **Quality validation** — compare original and synthetic distributions with a single function call
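
The adversarial filtering idea can be sketched standalone with scikit-learn (a simplified illustration of the concept, not TabGAN's internal implementation, and assuming scikit-learn is installed): a classifier learns to separate real rows from synthetic ones, and only the synthetic rows that look most "real" to it are kept.

```python
# Simplified sketch of adversarial filtering (illustrative only, not TabGAN's code):
# keep the synthetic rows that are hardest to distinguish from real data.
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(42)
real = pd.DataFrame(rng.normal(0, 1, size=(200, 3)), columns=["a", "b", "c"])
synth = pd.DataFrame(rng.normal(0.3, 1.2, size=(400, 3)), columns=["a", "b", "c"])

X = pd.concat([real, synth], ignore_index=True)
y = np.r_[np.ones(len(real)), np.zeros(len(synth))]  # 1 = real, 0 = synthetic

clf = RandomForestClassifier(n_estimators=50, random_state=42).fit(X, y)
p_real = clf.predict_proba(synth)[:, 1]  # how "real" each synthetic row looks

keep = synth.iloc[np.argsort(-p_real)[:200]]  # keep the 200 most realistic rows
print(keep.shape)
```

TabGAN performs this step internally when `use_adversarial=True`; the sketch only shows why filtering pushes synthetic samples toward the real distribution.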
## Installation
@@ -45,12 +54,10 @@ import pandas as pd
import numpy as np
from tabgan.sampler import GANGenerator
-# Create sample data
train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("ABCD"))
target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y"))
test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
-# Generate synthetic data
new_train, new_target = GANGenerator().generate_data_pipe(train, target, test)
```
@@ -59,28 +66,28 @@ new_train, new_target = GANGenerator().generate_data_pipe(train, target, test)
| Generator | Description | Best For |
|-----------|-------------|----------|
| `GANGenerator` | CTGAN-based generation | General tabular data with mixed types |
-| `ForestDiffusionGenerator` | Diffusion models + tree-based methods | Complex tabular structures |
-| `LLMGenerator` | Large Language Model based | Capturing complex dependencies |
-| `OriginalGenerator` | Baseline sampler | Baseline comparisons |
+| `ForestDiffusionGenerator` | Diffusion models with tree-based methods | Complex tabular structures |
+| `LLMGenerator` | Large Language Model based | Semantic dependencies, text columns |
+| `OriginalGenerator` | Baseline random sampler | Benchmarking and comparison |
## API Reference
-### Sampler Parameters
+### Common Parameters
-All generators accept these common parameters:
+All generators accept the following parameters:
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
-| `gen_x_times` | float | 1.1 | Multiplier for data generation amount |
-| `cat_cols` | list | None | Column names to treat as categorical |
-| `bot_filter_quantile` | float | 0.001 | Bottom quantile for post-process filtering |
-| `top_filter_quantile` | float | 0.999 | Top quantile for post-process filtering |
-| `is_post_process` | bool | True | Enable post-filtering |
-| `pregeneration_frac` | float | 2 | Pre-generation data multiplier |
-| `only_generated_data` | bool | False | Return only generated data (no original) |
-| `gen_params` | dict | See below | Generator-specific parameters |
+| `gen_x_times` | `float` | `1.1` | Multiplier for synthetic sample count relative to training size |
+| `cat_cols` | `list` | `None` | Column names to treat as categorical |
+| `bot_filter_quantile` | `float` | `0.001` | Lower quantile for post-processing filters |
+| `top_filter_quantile` | `float` | `0.999` | Upper quantile for post-processing filters |
+| `is_post_process` | `bool` | `True` | Enable quantile-based post-filtering |
+| `pregeneration_frac` | `float` | `2` | Oversampling factor before filtering |
+| `only_generated_data` | `bool` | `False` | Return only synthetic rows (exclude originals) |
+| `gen_params` | `dict` | See below | Generator-specific hyperparameters |
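
The `bot_filter_quantile` / `top_filter_quantile` post-filter can be illustrated in isolation (a toy sketch of the idea with made-up data, not the library's exact code): synthetic values falling outside the original data's quantile range are dropped.

```python
# Sketch of quantile-based post-filtering: clip synthetic outliers
# to the value range observed in the original data.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
train = pd.DataFrame({"x": rng.normal(0, 1, 1000)})
synth = pd.DataFrame({"x": rng.normal(0, 3, 1000)})  # heavier tails than train

bot, top = train["x"].quantile(0.001), train["x"].quantile(0.999)
filtered = synth[(synth["x"] >= bot) & (synth["x"] <= top)]

print(len(synth) - len(filtered), "rows removed")
```

With the default quantiles (0.001 and 0.999) only extreme outliers are removed; tightening them trades coverage of the tails for closer agreement with the original range.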
-### Generator-Specific Parameters
+### Generator-Specific Parameters (`gen_params`)
**GANGenerator:**
```python
@@ -92,31 +99,33 @@ All generators accept these common parameters:
{"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}
```
-### generate_data_pipe Method
+### `generate_data_pipe` Method
-| Parameter | Type | Description |
-|-----------|------|-------------|
-| `train_df` | pd.DataFrame | Training features |
-| `target` | pd.DataFrame | Target variable |
-| `test_df` | pd.DataFrame | Test features |
-| `deep_copy` | bool | Make copy of input dataframes |
-| `only_adversarial` | bool | Only perform adversarial filtering |
-| `use_adversarial` | bool | Enable adversarial filtering |
+```python
+new_train, new_target = generator.generate_data_pipe(
+ train_df, # pd.DataFrame - training features
+ target, # pd.DataFrame - target variable (or None)
+ test_df, # pd.DataFrame - test features for distribution alignment
+ deep_copy=True, # bool - copy input DataFrames
+ only_adversarial=False, # bool - skip generation, only filter
+ use_adversarial=True, # bool - enable adversarial filtering
+)
+```
-**Returns:** `Tuple[pd.DataFrame, pd.DataFrame]` - (new_train, new_target)
+**Returns:** `Tuple[pd.DataFrame, pd.DataFrame]` — `(new_train, new_target)`
## Data Format
-TabGAN accepts both `numpy.ndarray` and `pandas.DataFrame` inputs, supporting:
+TabGAN accepts `pandas.DataFrame` inputs with:
-- **Continuous Variables**: Numerical columns with any real-valued data
-- **Categorical Variables**: Discrete columns with a finite set of possible values
+- **Continuous columns** — any real-valued numerical data
+- **Categorical columns** — discrete columns with a finite set of values
-> **Note:** TabGAN internally processes all values as floating-point numbers. For integer-valued outputs, apply rounding after generation.
+> **Note:** TabGAN processes values as floating-point internally. Apply rounding after generation for integer-valued outputs.
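
For instance, integer columns can be restored right after generation (the DataFrame below is a stand-in for generated output):

```python
import pandas as pd

# Synthetic output comes back as floats even for integer-like columns
new_train = pd.DataFrame({"A": [12.7, 3.2, 99.9], "B": [0.1, 1.8, 1.0]})

new_train = new_train.round().astype(int)  # restore integer dtype after generation
print(new_train.dtypes.tolist())
```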
## Examples
-### Basic Usage
+### Basic Usage with All Generators
```python
from tabgan.sampler import OriginalGenerator, GANGenerator, ForestDiffusionGenerator, LLMGenerator
@@ -127,11 +136,14 @@ train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("A
target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y"))
test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
-# Different generators
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test)
-new_train2, new_target2 = GANGenerator(gen_params={"batch_size": 500, "epochs": 10, "patience": 5}).generate_data_pipe(train, target, test)
+new_train2, new_target2 = GANGenerator(
+ gen_params={"batch_size": 500, "epochs": 10, "patience": 5}
+).generate_data_pipe(train, target, test)
new_train3, new_target3 = ForestDiffusionGenerator().generate_data_pipe(train, target, test)
-new_train4, new_target4 = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}).generate_data_pipe(train, target, test)
+new_train4, new_target4 = LLMGenerator(
+ gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}
+).generate_data_pipe(train, target, test)
```
### Full Parameter Example
@@ -149,168 +161,129 @@ new_train, new_target = GANGenerator(
},
pregeneration_frac=2,
only_generated_data=False,
- gen_params={"batch_size": 500, "patience": 25, "epochs": 500}
+ gen_params={"batch_size": 500, "patience": 25, "epochs": 500},
).generate_data_pipe(
train, target, test,
deep_copy=True,
only_adversarial=False,
- use_adversarial=True
+ use_adversarial=True,
)
```
### LLM Conditional Text Generation
-Generate synthetic data with LLMs while controlling text generation based on specific conditions. This uses the internal `_generate_via_prompt` method for novel text generation.
+Generate synthetic rows with novel text values conditioned on categorical attributes:
```python
import pandas as pd
from tabgan.sampler import LLMGenerator
-# Create sample data with text and categorical columns
train = pd.DataFrame({
"Name": ["Anna", "Maria", "Ivan", "Sergey", "Olga", "Boris"],
"Gender": ["F", "F", "M", "M", "F", "M"],
"Age": [25, 30, 35, 40, 28, 32],
- "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"]
+ "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"],
})
-# Generate new names conditioned on Gender, with other features imputed
new_train, _ = LLMGenerator(
- gen_x_times=1.5, # Generate 1.5x the original data
- text_generating_columns=["Name"], # Generate novel names
- conditional_columns=["Gender"], # Condition on Gender column
+ gen_x_times=1.5,
+ text_generating_columns=["Name"], # columns to generate novel text for
+ conditional_columns=["Gender"], # columns that condition text generation
gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500},
- is_post_process=False # Disable post-processing for this example
-).generate_data_pipe(
- train,
- target=None,
- test_df=None,
- only_generated_data=True # Return only generated data
-)
-
-print(new_train)
+ is_post_process=False,
+).generate_data_pipe(train, target=None, test_df=None, only_generated_data=True)
```
-**Parameters for conditional generation:**
-- `text_generating_columns`: List of column names to generate novel text for
-- `conditional_columns`: List of column names that condition the text generation
-
-The model will:
-1. Sample values for conditional columns from their distributions
-2. Impute remaining non-text columns using the LLM
-3. Generate novel text for text columns via prompt-based generation (using `_generate_via_prompt`)
-4. Ensure generated text values are unique (not present in original data)
+**How it works:**
+1. Sample conditional column values from their empirical distributions
+2. Impute remaining non-text columns using the fitted GReaT model
+3. Generate novel text via prompt-based generation
+4. Ensure generated text values differ from the original data
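
Step 1 can be sketched with plain pandas (a toy illustration; the `Gender` column mirrors the example data above): conditional values are drawn with replacement, so their sampled frequencies follow the empirical distribution.

```python
import pandas as pd

train = pd.DataFrame({"Gender": ["F", "F", "M", "M", "F", "M"]})

# Sample conditional-column values in proportion to their observed frequencies
sampled = train["Gender"].sample(n=9, replace=True, random_state=0)
print(sampled.value_counts().to_dict())
```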
### LLM API-Based Text Generation
-Use external LLM APIs (LM Studio, OpenAI, Ollama) for text generation instead of local models. This allows you to leverage powerful models running on remote servers or local API endpoints.
+Use external LLM APIs (LM Studio, OpenAI, Ollama) instead of local models:
```python
import pandas as pd
from tabgan.sampler import LLMGenerator
from tabgan.llm_config import LLMAPIConfig
-# Create sample data
train = pd.DataFrame({
"Name": ["Anna", "Maria", "Ivan", "Sergey", "Olga", "Boris"],
"Gender": ["F", "F", "M", "M", "F", "M"],
"Age": [25, 30, 35, 40, 28, 32],
- "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"]
+ "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"],
})
-# Configure API connection (LM Studio example)
+# LM Studio
api_config = LLMAPIConfig.from_lm_studio(
base_url="http://localhost:1234",
model="google/gemma-3-12b",
- timeout=90
+ timeout=90,
)
-# Or use OpenAI
-# api_config = LLMAPIConfig.from_openai(
-# api_key="your-api-key",
-# model="gpt-4"
-# )
-
-# Or use Ollama
-# api_config = LLMAPIConfig.from_ollama(
-# base_url="http://localhost:11434",
-# model="llama3"
-# )
+# Or OpenAI: LLMAPIConfig.from_openai(api_key="...", model="gpt-4")
+# Or Ollama: LLMAPIConfig.from_ollama(model="llama3")
-# Generate with API-based text generation
new_train, _ = LLMGenerator(
gen_x_times=1.5,
text_generating_columns=["Name"],
conditional_columns=["Gender"],
gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500},
- llm_api_config=api_config, # Use external API for text generation
- is_post_process=False
-).generate_data_pipe(
- train,
- target=None,
- test_df=None,
- only_generated_data=True
-)
-
-print(new_train)
+ llm_api_config=api_config,
+ is_post_process=False,
+).generate_data_pipe(train, target=None, test_df=None, only_generated_data=True)
```
-**Configuration Options:**
+Original: {n_orig} rows × {n_cols} cols | + Synthetic: {n_synth} rows × {n_cols} cols
+ +Generated by TabGAN QualityReport
+ + +""" + + +def _fig_to_base64(fig) -> str: + """Render a matplotlib figure to a base64-encoded PNG data URI.""" + buf = io.BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight", dpi=100) + buf.seek(0) + b64 = base64.b64encode(buf.read()).decode("utf-8") + buf.close() + return f"data:image/png;base64,{b64}" + + +class QualityReport: + """Compare original and synthetic DataFrames across multiple quality axes. + + Args: + original_df: Real data. + synthetic_df: Synthetic data (must have the same columns). + cat_cols: Names of categorical columns (affects stats display). + target_col: Optional target column name for ML utility evaluation. + + Example:: + + report = QualityReport(orig, synth, cat_cols=["gender"]).compute() + report.to_html("report.html") + print(report.summary()) + """ + + def __init__( + self, + original_df: pd.DataFrame, + synthetic_df: pd.DataFrame, + cat_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + ): + shared = [c for c in original_df.columns if c in synthetic_df.columns] + self.original = original_df[shared].copy() + self.synthetic = synthetic_df[shared].copy() + self.cat_cols = [c for c in (cat_cols or []) if c in shared] + self.target_col = target_col + self._results: Dict = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def compute(self) -> "QualityReport": + """Run all metrics and store results internally.""" + self._results["column_stats"] = self._column_stats() + self._results["psi"] = self._psi_per_column() + self._results["correlation"] = self._correlation_comparison() + self._results["ml_utility"] = self._ml_utility_score() + self._results["overall"] = self._overall_score() + return self + + def summary(self) -> Dict: + """Return a dict of all metric results (without images).""" + if not self._results: + self.compute() + return { + "column_stats": self._results["column_stats"], + 
+            "psi": {k: v for k, v in self._results["psi"].items()},
+            "ml_utility": self._results["ml_utility"],
+            "overall_score": self._results["overall"],
+        }
+
+    def to_html(self, path: str) -> None:
+        """Write a self-contained HTML report to *path*."""
+        if not self._results:
+            self.compute()
+
+        psi = self._results["psi"]
+        ml = self._results["ml_utility"]
+
+        # Column stats table
+        stats = self._results["column_stats"]
+        col_rows = []
+        for col, s in stats.items():
+            if s["dtype"] == "numeric":
+                col_rows.append(
+                    f"| Column | Type | "
+                    "Orig Mean/Info | Synth Mean/Info | "
+                    "Orig Std/Info | Synth Std/Info |
|---|
| Column | PSI |
|---|
+        ml_html = (
+            f"Train-on-Synthetic, Test-on-Real AUC: {ml.get('tstr_auc', 'N/A')}"
+            f"Train-on-Real, Test-on-Real AUC (baseline): {ml.get('trtr_auc', 'N/A')}"
+            f"Utility ratio: {ml.get('utility_ratio', 'N/A')}"
+        )
+
+        numeric_psi = [v for k, v in psi.items() if k != "mean" and isinstance(v, (int, float))]
+        mean_psi = float(np.mean(numeric_psi)) if numeric_psi else 0.0
+
+        html = _HTML_TEMPLATE.format(
+            n_orig=len(self.original),
+            n_synth=len(self.synthetic),
+            n_cols=len(self.original.columns),
+            overall_score=self._results["overall"],
+            ml_utility=ml.get("utility_ratio", 0.0),
+            mean_psi=mean_psi,
+            column_stats_html=col_stats_html,
+            psi_html=psi_html,
+            corr_images=corr_html,
+            dist_images=dist_html,
+            ml_detail_html=ml_html,
+        )
+
+        with open(path, "w", encoding="utf-8") as fh:
+            fh.write(html)
+        logging.info(f"Quality report written to {path}")
+
+    # ------------------------------------------------------------------
+    # Private metric methods
+    # ------------------------------------------------------------------
+    def _column_stats(self) -> Dict:
+        stats = {}
+        for col in self.original.columns:
+            if col in self.cat_cols or not pd.api.types.is_numeric_dtype(self.original[col]):
+                stats[col] = {
+                    "dtype": "categorical",
+                    "orig_nunique": int(self.original[col].nunique()),
+                    "synth_nunique": int(self.synthetic[col].nunique()),
+                }
+            else:
+                stats[col] = {
+                    "dtype": "numeric",
+                    "orig_mean": float(self.original[col].mean()),
+                    "synth_mean": float(self.synthetic[col].mean()),
+                    "orig_std": float(self.original[col].std()),
+                    "synth_std": float(self.synthetic[col].std()),
+                    "orig_min": float(self.original[col].min()),
+                    "synth_min": float(self.synthetic[col].min()),
+                    "orig_max": float(self.original[col].max()),
+                    "synth_max": float(self.synthetic[col].max()),
+                }
+        return stats
+
+    def _psi_per_column(self) -> Dict:
+        psi_dict = {}
+        numeric_cols = self.original.select_dtypes(include=[np.number]).columns
+        for col in numeric_cols:
+            try:
+                val = float(calculate_psi(
+                    self.original[col].values,
+                    self.synthetic[col].values,
+                    buckets=10,
+                ))
+            except Exception:
+                val = float("nan")
+            psi_dict[col] = val
+        vals = [v for v in psi_dict.values() if np.isfinite(v)]
+        psi_dict["mean"] = float(np.mean(vals)) if vals else 0.0
+        return psi_dict
+
+    def _correlation_comparison(self) -> Dict:
+        numeric_orig = self.original.select_dtypes(include=[np.number])
+        numeric_synth = self.synthetic.select_dtypes(include=[np.number])
+        if numeric_orig.shape[1] < 2:
+            return {"orig_corr": None, "synth_corr": None, "diff_mean": 0.0}
+        orig_corr = numeric_orig.corr()
+        synth_corr = numeric_synth.corr()
+        diff = (orig_corr - synth_corr).abs()
+        return {
+            "orig_corr": orig_corr,
+            "synth_corr": synth_corr,
+            "diff_mean": float(diff.mean().mean()),
+        }
+
+    def _ml_utility_score(self) -> Dict:
+        """Train-on-Synthetic, Test-on-Real (TSTR) evaluation."""
+        from sklearn.ensemble import GradientBoostingClassifier
+
+        target = self.target_col
+        if target is None or target not in self.original.columns or target not in self.synthetic.columns:
+            return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0}
+
+        orig_num = self.original.select_dtypes(include=[np.number])
+        synth_num = self.synthetic.select_dtypes(include=[np.number])
+
+        if target not in orig_num.columns:
+            return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0}
+
+        X_real = orig_num.drop(columns=[target]).values
+        y_real = orig_num[target].values
+        X_synth = synth_num.drop(columns=[target]).values
+        y_synth = synth_num[target].values
+
+        if len(np.unique(y_real)) < 2 or len(np.unique(y_synth)) < 2:
+            return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0}
+
+        try:
+            clf_real = GradientBoostingClassifier(n_estimators=50, max_depth=3, random_state=42)
+            proba_real = cross_val_predict(clf_real, X_real, y_real, cv=3, method="predict_proba")[:, 1]
+            trtr_auc = float(roc_auc_score(y_real, proba_real))
+
+            clf_synth = GradientBoostingClassifier(n_estimators=50, max_depth=3, random_state=42)
+            clf_synth.fit(X_synth, y_synth)
+            proba_synth = clf_synth.predict_proba(X_real)[:, 1]
+            tstr_auc = float(roc_auc_score(y_real, proba_synth))
+
+            ratio = tstr_auc / trtr_auc if trtr_auc > 0 else 0.0
+        except Exception as e:
+            logging.warning(f"ML utility evaluation failed: {e}")
+            return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0}
+
+        return {
+            "tstr_auc": round(tstr_auc, 4),
+            "trtr_auc": round(trtr_auc, 4),
+            "utility_ratio": round(min(ratio, 1.0), 4),
+        }
+
+    def _overall_score(self) -> float:
+        psi_vals = [v for k, v in self._results["psi"].items()
+                    if k != "mean" and isinstance(v, (int, float)) and np.isfinite(v)]
+        # PSI component: lower is better → score = 1/(1+mean_psi)
+        mean_psi = float(np.mean(psi_vals)) if psi_vals else 0.0
+        psi_score = 1.0 / (1.0 + mean_psi)
+
+        # Correlation component
+        corr_diff = self._results["correlation"]["diff_mean"]
+        corr_score = max(1.0 - corr_diff, 0.0)
+
+        # ML utility component
+        ml_ratio = self._results["ml_utility"].get("utility_ratio", 0.0) or 0.0
+
+        overall = 0.35 * psi_score + 0.30 * corr_score + 0.35 * ml_ratio
+        return round(min(max(overall, 0.0), 1.0), 4)
+
+    # ------------------------------------------------------------------
+    # Chart renderers (require matplotlib — lazy import)
+    # ------------------------------------------------------------------
+    def _render_correlation_plots(self) -> Dict[str, str]:
+        try:
+            import matplotlib
+            matplotlib.use("Agg")
+            import matplotlib.pyplot as plt
+        except ImportError:
+            logging.warning("matplotlib not installed — skipping correlation plots")
+            return {}
+
+        corr = self._results["correlation"]
+        if corr["orig_corr"] is None:
+            return {}
+
+        images = {}
+        for label, matrix in [("Original", corr["orig_corr"]), ("Synthetic", corr["synth_corr"])]:
+            fig, ax = plt.subplots(figsize=(5, 4))
+            im = ax.imshow(matrix.values, cmap="RdBu_r", vmin=-1, vmax=1)
+            ax.set_xticks(range(len(matrix.columns)))
+            ax.set_yticks(range(len(matrix.columns)))
+            ax.set_xticklabels(matrix.columns, rotation=45, ha="right", fontsize=7)
+            ax.set_yticklabels(matrix.columns, fontsize=7)
+            ax.set_title(f"{label} Correlation")
+            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+            images[label] = _fig_to_base64(fig)
+            plt.close(fig)
+
+        return images
+
+    def _render_distribution_plots(self) -> Dict[str, str]:
+        try:
+            import matplotlib
+            matplotlib.use("Agg")
+            import matplotlib.pyplot as plt
+        except ImportError:
+            logging.warning("matplotlib not installed — skipping distribution plots")
+            return {}
+
+        images = {}
+        for col in self.original.columns:
+            if not pd.api.types.is_numeric_dtype(self.original[col]):
+                continue
+            fig, ax = plt.subplots(figsize=(4, 3))
+            ax.hist(self.original[col].dropna(), bins=30, alpha=0.5, label="Original", density=True)
+            ax.hist(self.synthetic[col].dropna(), bins=30, alpha=0.5, label="Synthetic", density=True)
+            ax.set_title(col, fontsize=10)
+            ax.legend(fontsize=7)
+            images[col] = _fig_to_base64(fig)
+            plt.close(fig)
+
+        return images
diff --git a/src/tabgan/sampler.py b/src/tabgan/sampler.py
index a634d21..ce123cf 100644
--- a/src/tabgan/sampler.py
+++ b/src/tabgan/sampler.py
@@ -27,40 +27,32 @@
 __all__ = ["OriginalGenerator", "GANGenerator", "ForestDiffusionGenerator", "LLMGenerator"]
 
-class OriginalGenerator(SampleData):
+class _BaseGenerator(SampleData):
+    """Base factory that stores constructor arguments for the concrete sampler."""
+
+    _sampler_class = None
+
     def __init__(self, *args, **kwargs):
         self.args = args
         self.kwargs = kwargs
 
     def get_object_generator(self) -> Sampler:
-        return SamplerOriginal(*self.args, **self.kwargs)
+        return self._sampler_class(*self.args, **self.kwargs)
 
-class GANGenerator(SampleData):
-    def __init__(self, *args, **kwargs):
-        self.args = args
-        self.kwargs = kwargs
+class OriginalGenerator(_BaseGenerator):
+    _sampler_class = None  # set after SamplerOriginal is defined
 
-    def get_object_generator(self) -> Sampler:
-        return SamplerGAN(*self.args, **self.kwargs)
+class GANGenerator(_BaseGenerator):
+    _sampler_class = None
 
-class ForestDiffusionGenerator(SampleData):
-    def __init__(self, *args, **kwargs):
-        self.args = args
-        self.kwargs = kwargs
-    def get_object_generator(self) -> Sampler:
-        return SamplerDiffusion(*self.args, **self.kwargs)
+class ForestDiffusionGenerator(_BaseGenerator):
+    _sampler_class = None
 
-class LLMGenerator(SampleData):
-    def __init__(self, *args, **kwargs):
-        self.args = args
-        self.kwargs = kwargs
-
-    def get_object_generator(self) -> Sampler:
-        return SamplerLLM(*self.args, **self.kwargs)
+class LLMGenerator(_BaseGenerator):
+    _sampler_class = None
 
 class SamplerOriginal(Sampler):
@@ -71,17 +63,10 @@ def __init__(
         bot_filter_quantile: float = 0.001,
         top_filter_quantile: float = 0.999,
         is_post_process: bool = True,
-        adversarial_model_params: dict = {
-            "metrics": "AUC",
-            "max_depth": 2,
-            "max_bin": 100,
-            "n_estimators": 150,
-            "learning_rate": 0.02,
-            "random_state": 42,
-        },
+        adversarial_model_params: dict = None,
         pregeneration_frac: float = 2,
         only_generated_data: bool = False,
-        gen_params: dict = {"batch_size": 45, 'patience': 25, "epochs": 50, "llm": "distilgpt2"},
+        gen_params: dict = None,
         text_generating_columns: list = None,
         conditional_columns: list = None,
         llm_api_config: LLMAPIConfig = None,
@@ -121,6 +106,17 @@
             the API instead of the local model. Useful for LM Studio, Ollama, OpenAI, etc.
         """
+        if adversarial_model_params is None:
+            adversarial_model_params = {
+                "metrics": "AUC",
+                "max_depth": 2,
+                "max_bin": 100,
+                "n_estimators": 150,
+                "learning_rate": 0.02,
+                "random_state": 42,
+            }
+        if gen_params is None:
+            gen_params = {"batch_size": 45, "patience": 25, "epochs": 50, "llm": "distilgpt2"}
         super().__init__(
             gen_x_times=gen_x_times,
             cat_cols=cat_cols,
@@ -140,10 +136,10 @@
     @staticmethod
     def preprocess_data_df(df) -> pd.DataFrame:
-        logging.info("Input shape: {}".format(df.shape))
-        if isinstance(df, pd.DataFrame) is False:
+        logging.info(f"Input shape: {df.shape}")
+        if not isinstance(df, pd.DataFrame):
             raise ValueError(
-                "Input dataframe aren't pandas dataframes: df is {}".format(type(df))
+                f"Input dataframe is not a pandas DataFrame: got {type(df)}"
            )
         return df
@@ -156,9 +152,7 @@ def preprocess_data(
         self.TEMP_TARGET = target.columns[0]
         if self.TEMP_TARGET in train.columns:
             raise ValueError(
-                "Input train dataframe already have {} column, consider removing it".format(
-                    self.TEMP_TARGET
-                )
+                f"Input train dataframe already has '{self.TEMP_TARGET}' column, consider removing it"
             )
         if "test_similarity" in train.columns:
             raise ValueError(
@@ -171,7 +165,7 @@ def generate_data(
         self, train_df, target, test_df, only_generated_data
     ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         if only_generated_data:
-            Warning(
+            warnings.warn(
                 "For SamplerOriginal setting only_generated_data doesn't change anything, "
                 "because generated data sampled from the train!"
        )
@@ -183,10 +177,8 @@ def generate_data(
         generated_df = generated_df.reset_index(drop=True)
         logging.info(
-            "Generated shape: {} and {}".format(
-                generated_df.drop(self.TEMP_TARGET, axis=1).shape,
-                generated_df[self.TEMP_TARGET].shape,
-            )
+            f"Generated shape: {generated_df.drop(self.TEMP_TARGET, axis=1).shape} "
+            f"and {generated_df[self.TEMP_TARGET].shape}"
         )
         return (
             generated_df.drop(self.TEMP_TARGET, axis=1),
@@ -258,17 +250,14 @@ def _validate_data(train_df, target, test_df):
         if test_df is not None:
             if train_df.shape[0] < 10 or test_df.shape[0] < 10:
                 raise ValueError(
-                    "Shape of train is {} and test is {}. Both should at least 10! "
-                    "Consider disabling adversarial filtering".format(
-                        train_df.shape[0], test_df.shape[0]
-                    )
+                    f"Shape of train is {train_df.shape[0]} and test is {test_df.shape[0]}. "
+                    f"Both should be at least 10! Consider disabling adversarial filtering"
                 )
         if target is not None:
             if train_df.shape[0] != target.shape[0]:
                 raise ValueError(
-                    "Something gone wrong: shape of train_df = {} is not equal to target = {} shape".format(
-                        train_df.shape[0], target.shape[0]
-                    )
+                    f"Shape mismatch: train_df has {train_df.shape[0]} rows "
+                    f"but target has {target.shape[0]} rows"
                 )

     def handle_generated_data(self, train_df, generated_df, only_generated_data):
@@ -302,9 +291,7 @@ def handle_generated_data(self, train_df, generated_df, only_generated_data):
         if not only_generated_data:
             train_df = pd.concat([train_df, generated_df]).reset_index(drop=True)
             logging.info(
-                "Generated shapes: {} plus target".format(
-                    _drop_col_if_exist(train_df, self.TEMP_TARGET).shape
-                )
+                f"Generated shapes: {_drop_col_if_exist(train_df, self.TEMP_TARGET).shape} plus target"
             )
             return (
                 _drop_col_if_exist(train_df, self.TEMP_TARGET),
@@ -312,9 +299,7 @@ def handle_generated_data(self, train_df, generated_df, only_generated_data):
             )
         else:
             logging.info(
-                "Generated shapes: {} plus target".format(
-                    _drop_col_if_exist(generated_df, self.TEMP_TARGET).shape
-                )
+                f"Generated shapes: {_drop_col_if_exist(generated_df, self.TEMP_TARGET).shape} plus target"
             )
             return (
                 _drop_col_if_exist(generated_df, self.TEMP_TARGET),
@@ -326,16 +311,15 @@ class SamplerGAN(SamplerOriginal):
     def check_params(self):
         if self.gen_params["batch_size"] % 10 != 0:
             logging.warning(
-                "Batch size should be divisible to 10, but provided {}. Fixing it".format(
-                    self.gen_params["batch_size"]))
+                f"Batch size should be divisible by 10, but got {self.gen_params['batch_size']}. Fixing it")
             self.gen_params["batch_size"] += 10 - (self.gen_params["batch_size"] % 10)
         if "patience" not in self.gen_params:
-            logging.warning("patience param is not set for GAN params, so setting it to default ""25""")
+            logging.warning("patience param is not set for GAN params, setting default to 25")
             self.gen_params["patience"] = 25
         if "epochs" not in self.gen_params:
-            logging.warning("patience param is not set for GAN params, so setting it to default ""50""")
+            logging.warning("epochs param is not set for GAN params, setting default to 50")
             self.gen_params["epochs"] = 50

     def generate_data(
@@ -389,16 +373,15 @@ def get_column_indexes(df, column_names):
 class SamplerLLM(SamplerOriginal):
     def check_params(self):
         if "llm" not in self.gen_params:
-            logging.warning("llm param is not set for LLM params, so setting it to default ""distilgpt2""")
+            logging.warning("llm param is not set for LLM params, setting default to 'distilgpt2'")
             self.gen_params["llm"] = "distilgpt2"
         if "max_length" not in self.gen_params:
-            logging.warning("max_length param is not set for LLM params, so setting it to default ""500""")
-            self.gen_params["max_length"] = "500"
+            logging.warning("max_length param is not set for LLM params, setting default to 500")
+            self.gen_params["max_length"] = 500
         if self.gen_params["epochs"] < 3:
             logging.warning(
-                "Current set epoch = {} for llm training is too low, setting to 3!""".format(
-                    self.gen_params["epochs"]))
+                f"Current epoch={self.gen_params['epochs']} for LLM training is too low, setting to 3")
             self.gen_params["epochs"] = 3

     def _build_training_frame(self, train_df: pd.DataFrame, target: pd.DataFrame | None) -> pd.DataFrame:
@@ -628,6 +611,13 @@ def _generate_via_prompt(self, prompt: str, great_model_instance, device: str, m
             return ""  # Fallback or re-raise


+# Wire up factory classes to their concrete sampler implementations
+OriginalGenerator._sampler_class = SamplerOriginal
+GANGenerator._sampler_class = SamplerGAN
+ForestDiffusionGenerator._sampler_class = SamplerDiffusion
+LLMGenerator._sampler_class = SamplerLLM
+
+
 if __name__ == "__main__":
     setup_logging(logging.DEBUG)
     train_size = 75
diff --git a/src/tabgan/sklearn_transformer.py b/src/tabgan/sklearn_transformer.py
new file mode 100644
index 0000000..4c6d5f6
--- /dev/null
+++ b/src/tabgan/sklearn_transformer.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+"""
+sklearn-compatible transformer for TabGAN data augmentation.
+
+Allows inserting synthetic data generation into a ``sklearn.pipeline.Pipeline``.
+"""
+
+import logging
+from typing import List, Optional, Type
+
+import numpy as np
+import pandas as pd
+from sklearn.base import BaseEstimator, TransformerMixin
+
+from tabgan.sampler import GANGenerator
+
+__all__ = ["TabGANTransformer"]
+
+
+class TabGANTransformer(BaseEstimator, TransformerMixin):
+    """Augment training data with TabGAN synthetic rows inside an sklearn Pipeline.
+
+    During ``fit`` the generator is trained and synthetic data produced.
+    ``transform`` returns the augmented DataFrame (original + synthetic).
+
+    Because sklearn's ``transform`` only returns X, the augmented target
+    is available via :meth:`get_augmented_target` after ``fit_transform``.
+
+    Args:
+        generator_class: A TabGAN generator class (e.g. ``GANGenerator``).
+        gen_x_times: Multiplier for synthetic sample count.
+        cat_cols: Categorical column names.
+        gen_params: Generator-specific hyperparameters.
+        only_generated_data: If True, return only synthetic rows.
+        constraints: Optional list of ``Constraint`` instances.
+        use_adversarial: Whether to use adversarial filtering.
+        **generator_kwargs: Extra keyword arguments forwarded to the generator.
+
+    Example::
+
+        from sklearn.ensemble import RandomForestClassifier
+        from tabgan.sklearn_transformer import TabGANTransformer
+
+        augmenter = TabGANTransformer(gen_x_times=1.5)
+        X_aug = augmenter.fit_transform(X_train, y_train)
+        y_aug = augmenter.get_augmented_target()
+        model = RandomForestClassifier().fit(X_aug, y_aug)
+
+    Note: placing this transformer directly in front of an estimator in a
+    ``Pipeline`` and calling ``pipe.fit(X, y)`` fails whenever augmentation
+    changes the row count, because sklearn forwards the original ``y`` to the
+    final estimator unchanged. Fetch the matching target via
+    :meth:`get_augmented_target` instead, as shown above.
+    """
+
+    def __init__(
+        self,
+        generator_class: Optional[Type] = None,
+        gen_x_times: float = 1.1,
+        cat_cols: Optional[List[str]] = None,
+        gen_params: Optional[dict] = None,
+        only_generated_data: bool = False,
+        constraints: Optional[list] = None,
+        use_adversarial: bool = True,
+        **generator_kwargs,
+    ):
+        self.generator_class = generator_class
+        self.gen_x_times = gen_x_times
+        self.cat_cols = cat_cols
+        self.gen_params = gen_params
+        self.only_generated_data = only_generated_data
+        self.constraints = constraints
+        self.use_adversarial = use_adversarial
+        self.generator_kwargs = generator_kwargs
+
+        # Internal state (set after fit)
+        self._augmented_X: Optional[pd.DataFrame] = None
+        self._augmented_y: Optional[pd.Series] = None
+
+    def fit(self, X, y=None):
+        """Train the generator and produce synthetic data.
+
+        Args:
+            X: Training features (DataFrame or ndarray).
+            y: Target variable (Series, DataFrame, ndarray, or None).
+ """ + gen_cls = self.generator_class or GANGenerator + + X_df = pd.DataFrame(X).copy() if not isinstance(X, pd.DataFrame) else X.copy() + + target_df = None + if y is not None: + if isinstance(y, pd.DataFrame): + target_df = y.copy() + elif isinstance(y, pd.Series): + target_df = y.to_frame().copy() + else: + target_df = pd.DataFrame(y, columns=["target"]) + + gen_kwargs = dict( + gen_x_times=self.gen_x_times, + cat_cols=self.cat_cols, + only_generated_data=self.only_generated_data, + ) + if self.gen_params is not None: + gen_kwargs["gen_params"] = self.gen_params + gen_kwargs.update(self.generator_kwargs) + + generator = gen_cls(**gen_kwargs) + + new_train, new_target = generator.generate_data_pipe( + X_df, + target_df, + X_df, # use train as test for distribution alignment + use_adversarial=self.use_adversarial, + constraints=self.constraints, + ) + + self._augmented_X = new_train + if new_target is not None and not new_target.isna().all(): + self._augmented_y = ( + new_target.iloc[:, 0] if isinstance(new_target, pd.DataFrame) else new_target + ) + else: + self._augmented_y = None + + return self + + def transform(self, X, y=None): + """Return the augmented training data. + + During training (when ``_augmented_X`` is available), returns the + augmented data. At inference time, returns X unchanged. + """ + if self._augmented_X is not None: + result = self._augmented_X + # Clear after first transform to avoid leaking into predict + self._augmented_X = None + return result + return X + + def fit_transform(self, X, y=None, **fit_params): + """Fit and return augmented data in one step.""" + self.fit(X, y) + return self.transform(X, y) + + def get_augmented_target(self) -> Optional[pd.Series]: + """Return the augmented target produced during ``fit``. + + Call this after ``fit`` or ``fit_transform`` to get the target + values corresponding to the augmented training data. 
+ """ + return self._augmented_y diff --git a/src/tabgan/utils.py b/src/tabgan/utils.py index df8b904..a0f9b19 100644 --- a/src/tabgan/utils.py +++ b/src/tabgan/utils.py @@ -24,11 +24,10 @@ def setup_logging(loglevel): ) -def make_two_digit(num_as_str: str) -> pd.DataFrame: - if len(num_as_str) == 2: - return num_as_str - else: +def make_two_digit(num_as_str: str) -> str: + if len(num_as_str) < 2: return "0" + num_as_str + return num_as_str def get_year_mnth_dt_from_date(df: pd.DataFrame, date_col="Date") -> pd.DataFrame: @@ -66,7 +65,7 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def _sampler(creator, in_train, in_target, in_test) -> None: +def _sampler(creator, in_train, in_target, in_test) -> tuple: _logger = logging.getLogger(__name__) _logger.info("Starting generating data") train, test = creator.generate_data_pipe(in_train, in_target, in_test) @@ -230,7 +229,7 @@ def compare_dataframes(df_original, df_generated): # Combine uniqueness, data quality, and PSI scores (weighted) similarity_score = 0.1 * uniqueness_score + 0.45 * data_quality_score + 0.45 * (1/psi_similarity) - print(uniqueness_score, data_quality_score, psi_similarity) + logging.debug(f"Similarity components: uniqueness={uniqueness_score}, quality={data_quality_score}, psi={psi_similarity}") # Ensure score is between 0 and 1 similarity_score = min(max(similarity_score, 0), 1) diff --git a/tests/test_cli.py b/tests/test_cli.py index b4886d9..6dfe593 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,50 +7,162 @@ import pandas as pd +def _make_cli_env(): + """Return an env dict with PYTHONPATH pointing at the src directory.""" + env = os.environ.copy() + src_path = os.path.join(os.path.dirname(__file__), "..", "src") + env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "") + return env + + +def _run_cli(args, env): + """Run `python -m tabgan.cli` with the given extra arguments.""" + cmd = [sys.executable, "-m", "tabgan.cli"] + args + return 
subprocess.run(cmd, check=False, capture_output=True, text=True, env=env) + + def test_tabgan_generate_cli_creates_output_with_target(): - # Prepare temporary input and output CSV paths with tempfile.TemporaryDirectory() as tmpdir: input_path = os.path.join(tmpdir, "train.csv") output_path = os.path.join(tmpdir, "synthetic.csv") - # Small dummy dataset with a target column df = pd.DataFrame( np.random.randint(0, 10, size=(20, 3)), columns=["A", "B", "target"], ) df.to_csv(input_path, index=False) - # Set up environment with PYTHONPATH pointing to src directory - env = os.environ.copy() - src_path = os.path.join(os.path.dirname(__file__), "..", "src") - env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "") - - # Invoke the installed console script through Python -m to avoid PATH issues in CI - cmd = [ - sys.executable, - "-m", - "tabgan.cli", - "--input-csv", - input_path, - "--target-col", - "target", - "--generator", - "original", - "--gen-x-times", - "1.0", - "--output-csv", - output_path, - ] - - result = subprocess.run(cmd, check=False, capture_output=True, text=True, env=env) + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "original", + "--gen-x-times", "1.0", + "--output-csv", output_path, + ], env) assert result.returncode == 0, f"CLI failed: {result.stderr}" assert os.path.exists(output_path), "Output CSV was not created by CLI" out_df = pd.read_csv(output_path) - - # Target column must be present assert "target" in out_df.columns - # At least as many rows as the original (OriginalGenerator samples with replacement) assert len(out_df) >= len(df) + +def test_cli_gan_generator(): + """CLI with --generator gan produces valid output.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + df = pd.DataFrame( + np.random.randint(0, 50, size=(30, 3)), + columns=["A", "B", 
"target"], + ) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "gan", + "--gen-x-times", "1.0", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (gan) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert "target" in out_df.columns + assert len(out_df) > 0 + + +def test_cli_only_generated_flag(): + """CLI with --only-generated returns output without original rows.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + df = pd.DataFrame( + np.random.randint(0, 10, size=(20, 3)), + columns=["A", "B", "target"], + ) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "original", + "--gen-x-times", "1.0", + "--only-generated", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (only-generated) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert "target" in out_df.columns + assert len(out_df) > 0 + + +def test_cli_without_target(): + """CLI without --target-col should work (target=None path).""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + df = pd.DataFrame( + np.random.randint(0, 10, size=(20, 3)), + columns=["A", "B", "C"], + ) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--generator", "original", + "--gen-x-times", "1.0", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (no target) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = 
pd.read_csv(output_path) + assert len(out_df) > 0 + + +def test_cli_with_cat_cols(): + """CLI with --cat-cols passes categorical column names correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + rng = np.random.RandomState(42) + df = pd.DataFrame({ + "num": rng.randint(0, 100, 30), + "cat": rng.choice(["X", "Y", "Z"], 30), + "target": rng.randint(0, 2, 30), + }) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "original", + "--gen-x-times", "1.0", + "--cat-cols", "cat", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (cat-cols) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert "cat" in out_df.columns + assert "target" in out_df.columns + diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100644 index 0000000..ee7b435 --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +"""Tests for the constraint system.""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.constraints import ( + RangeConstraint, + UniqueConstraint, + FormulaConstraint, + RegexConstraint, + ConstraintEngine, +) +from src.tabgan.sampler import OriginalGenerator + + +class TestRangeConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"age": [5, 25, 150, -1]}) + c = RangeConstraint("age", min_val=0, max_val=120) + mask = c.is_satisfied(df) + self.assertEqual(list(mask), [True, True, False, False]) + + def test_fix_clips_values(self): + df = pd.DataFrame({"age": [5, 25, 150, -1]}) + c = RangeConstraint("age", min_val=0, max_val=120) + fixed = c.fix(df) + self.assertEqual(list(fixed["age"]), [5, 25, 120, 0]) + + def test_min_only(self): + df = pd.DataFrame({"x": 
[-5, 0, 10]}) + c = RangeConstraint("x", min_val=0) + self.assertEqual(list(c.is_satisfied(df)), [False, True, True]) + + def test_max_only(self): + df = pd.DataFrame({"x": [-5, 0, 10]}) + c = RangeConstraint("x", max_val=5) + self.assertEqual(list(c.is_satisfied(df)), [True, True, False]) + + def test_requires_at_least_one_bound(self): + with self.assertRaises(ValueError): + RangeConstraint("x") + + +class TestUniqueConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"id": [1, 2, 3, 2, 1]}) + c = UniqueConstraint("id") + mask = c.is_satisfied(df) + # First occurrences are True, duplicates are False + self.assertEqual(list(mask), [True, True, True, False, False]) + + def test_fix_drops_duplicates(self): + df = pd.DataFrame({"id": [1, 2, 3, 2, 1], "val": [10, 20, 30, 40, 50]}) + c = UniqueConstraint("id") + fixed = c.fix(df) + self.assertEqual(len(fixed), 3) + self.assertEqual(list(fixed["id"]), [1, 2, 3]) + + +class TestFormulaConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"start": [1, 5, 10], "end": [10, 3, 15]}) + c = FormulaConstraint("end > start") + mask = c.is_satisfied(df) + self.assertEqual(list(mask), [True, False, True]) + + def test_fix_filters_violations(self): + df = pd.DataFrame({"start": [1, 5, 10], "end": [10, 3, 15]}) + c = FormulaConstraint("end > start") + fixed = c.fix(df) + self.assertEqual(len(fixed), 2) + self.assertEqual(list(fixed["start"]), [1, 10]) + + +class TestRegexConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"email": ["a@b.com", "invalid", "x@y.org"]}) + c = RegexConstraint("email", r".+@.+\..+") + mask = c.is_satisfied(df) + self.assertEqual(list(mask), [True, False, True]) + + def test_fix_filters(self): + df = pd.DataFrame({"code": ["AB12", "XY34", "bad!"]}) + c = RegexConstraint("code", r"[A-Z]{2}\d{2}") + fixed = c.fix(df) + self.assertEqual(len(fixed), 2) + self.assertEqual(list(fixed["code"]), ["AB12", "XY34"]) + + +class 
TestConstraintEngine(unittest.TestCase): + def test_filter_strategy(self): + df = pd.DataFrame({"age": [5, 150, 30], "id": [1, 2, 3]}) + engine = ConstraintEngine( + [RangeConstraint("age", min_val=0, max_val=120)], + strategy="filter", + ) + result = engine.apply(df) + self.assertEqual(len(result), 2) + + def test_fix_strategy(self): + df = pd.DataFrame({"age": [5, 150, 30], "id": [1, 2, 3]}) + engine = ConstraintEngine( + [RangeConstraint("age", min_val=0, max_val=120)], + strategy="fix", + ) + result = engine.apply(df) + self.assertEqual(len(result), 3) # All rows kept after clipping + self.assertEqual(result["age"].max(), 120) + + def test_multiple_constraints(self): + df = pd.DataFrame({ + "age": [5, 150, 30, 25, 25], + "id": [1, 2, 3, 4, 4], + }) + engine = ConstraintEngine([ + RangeConstraint("age", min_val=0, max_val=120), + UniqueConstraint("id"), + ], strategy="fix") + result = engine.apply(df) + # After fix: age clipped, then unique id kept + self.assertTrue(result["age"].max() <= 120) + self.assertEqual(result["id"].nunique(), len(result)) + + def test_invalid_strategy_raises(self): + with self.assertRaises(ValueError): + ConstraintEngine([], strategy="invalid") + + +class TestConstraintsInPipeline(unittest.TestCase): + def test_generate_data_pipe_with_constraints(self): + rng = np.random.RandomState(42) + train = pd.DataFrame(rng.randint(-10, 200, size=(60, 3)), columns=list("ABC")) + target = pd.DataFrame(rng.randint(0, 2, size=(60, 1)), columns=["Y"]) + test = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + + constraints = [RangeConstraint("A", min_val=0, max_val=100)] + + new_train, new_target = OriginalGenerator(gen_x_times=1.5).generate_data_pipe( + train, target, test, constraints=constraints, + ) + + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreaterEqual(new_train["A"].min(), 0) + self.assertLessEqual(new_train["A"].max(), 100) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/tests/test_generate_data_pipe.py b/tests/test_generate_data_pipe.py new file mode 100644 index 0000000..f08ae76 --- /dev/null +++ b/tests/test_generate_data_pipe.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +""" +Tests for generate_data_pipe parameter combinations and cat_cols handling +in postprocess / adversarial filtering. +""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.sampler import ( + OriginalGenerator, + GANGenerator, + ForestDiffusionGenerator, + SamplerOriginal, +) + + +def _make_data(n_train=80, n_test=80, seed=42): + """Create reproducible train / target / test DataFrames.""" + rng = np.random.RandomState(seed) + train = pd.DataFrame(rng.randint(0, 100, size=(n_train, 4)), columns=list("ABCD")) + target = pd.DataFrame(rng.randint(0, 2, size=(n_train, 1)), columns=["Y"]) + test = pd.DataFrame(rng.randint(0, 100, size=(n_test, 4)), columns=list("ABCD")) + return train, target, test + + +def _make_data_with_cat(n_train=80, n_test=80, seed=42): + """Create data with an explicit categorical column.""" + rng = np.random.RandomState(seed) + train = pd.DataFrame({ + "num1": rng.randint(0, 100, n_train), + "num2": rng.randint(0, 100, n_train), + "cat": rng.choice(["X", "Y", "Z"], n_train), + }) + target = pd.DataFrame({"Y": rng.randint(0, 2, n_train)}) + test = pd.DataFrame({ + "num1": rng.randint(0, 100, n_test), + "num2": rng.randint(0, 100, n_test), + "cat": rng.choice(["X", "Y", "Z"], n_test), + }) + return train, target, test + + +# --------------------------------------------------------------------------- +# generate_data_pipe parameter combinations +# --------------------------------------------------------------------------- +class TestGenerateDataPipeParams(unittest.TestCase): + """Test various parameter combinations of generate_data_pipe.""" + + def test_only_adversarial_true(self): + """only_adversarial=True should skip generation and only filter.""" + train, target, test = _make_data() + new_train, new_target 
= OriginalGenerator(gen_x_times=1.1).generate_data_pipe( + train, target, test, + only_adversarial=True, + use_adversarial=True, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + # With only adversarial filtering on original data, output rows <= input rows + self.assertLessEqual(new_train.shape[0], train.shape[0]) + + def test_use_adversarial_false(self): + """use_adversarial=False should skip adversarial filtering entirely.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.5).generate_data_pipe( + train, target, test, + use_adversarial=False, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + # Without adversarial filtering, we should have more rows than original + self.assertGreater(new_train.shape[0], train.shape[0]) + + def test_deep_copy_false(self): + """deep_copy=False should still produce valid output.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.1).generate_data_pipe( + train.copy(), target.copy(), test.copy(), + deep_copy=False, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + + def test_only_generated_data_true_original(self): + """only_generated_data=True with OriginalGenerator — data is sampled from train.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator( + gen_x_times=1.1, + only_generated_data=True, + ).generate_data_pipe( + train, target, test, + only_generated_data=True, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + + def test_only_generated_data_true_gan(self): + """only_generated_data=True with GANGenerator returns purely synthetic rows.""" + train, target, test = _make_data() + new_train, new_target = GANGenerator( + gen_x_times=1.0, + only_generated_data=True, + gen_params={"batch_size": 50, "patience": 5, "epochs": 2}, + ).generate_data_pipe( + train, target, test, + only_generated_data=True, + ) + 
self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + self.assertEqual(new_train.shape[1], train.shape[1]) + + def test_target_none(self): + """Passing target=None should work for all generators.""" + train, _, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.1).generate_data_pipe( + train, None, test, + ) + self.assertIsNotNone(new_train) + self.assertGreater(new_train.shape[0], 0) + + def test_test_df_none(self): + """Passing test_df=None should skip postprocess and adversarial.""" + train, _, _ = _make_data() + new_train, new_target = OriginalGenerator( + gen_x_times=1.1, + is_post_process=False, + ).generate_data_pipe( + train, None, None, + ) + self.assertIsNotNone(new_train) + self.assertGreater(new_train.shape[0], 0) + + +# --------------------------------------------------------------------------- +# cat_cols in postprocess and adversarial filtering +# --------------------------------------------------------------------------- +class TestCatColsPostprocessAndAdversarial(unittest.TestCase): + """Test that cat_cols are handled correctly in postprocess and adversarial.""" + + def test_postprocess_with_cat_cols(self): + """Postprocessing with cat_cols should filter by category membership.""" + train, target, test = _make_data_with_cat() + + sampler = OriginalGenerator( + gen_x_times=2.0, + cat_cols=["cat"], + ).get_object_generator() + + new_train, new_target, test_df = sampler.preprocess_data( + train.copy(), target.copy(), test.copy() + ) + gen_train, gen_target = sampler.generate_data( + new_train, new_target, test_df, only_generated_data=False + ) + post_train, post_target = sampler.postprocess_data(gen_train, gen_target, test_df) + + self.assertEqual(post_train.shape[0], post_target.shape[0]) + # All categorical values in result should be present in test + result_cats = set(post_train["cat"].unique()) + test_cats = set(test_df["cat"].unique()) + 
self.assertTrue(result_cats.issubset(test_cats)) + + def test_adversarial_with_cat_cols(self): + """Adversarial filtering with cat_cols should produce valid output.""" + train, target, test = _make_data_with_cat() + + sampler = OriginalGenerator( + gen_x_times=2.0, + cat_cols=["cat"], + ).get_object_generator() + + new_train, new_target, test_df = sampler.preprocess_data( + train.copy(), target.copy(), test.copy() + ) + gen_train, gen_target = sampler.generate_data( + new_train, new_target, test_df, only_generated_data=False + ) + post_train, post_target = sampler.postprocess_data(gen_train, gen_target, test_df) + adv_train, adv_target = sampler.adversarial_filtering(post_train, post_target, test_df) + + self.assertEqual(adv_train.shape[0], adv_target.shape[0]) + self.assertGreater(adv_train.shape[0], 0) + + def test_full_pipeline_with_cat_cols(self): + """End-to-end generate_data_pipe with cat_cols.""" + train, target, test = _make_data_with_cat() + new_train, new_target = OriginalGenerator( + gen_x_times=1.5, + cat_cols=["cat"], + ).generate_data_pipe(train, target, test) + + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + self.assertIn("cat", new_train.columns) + + def test_gan_with_cat_cols(self): + """GANGenerator with cat_cols should train and generate correctly.""" + train, target, test = _make_data_with_cat() + new_train, new_target = GANGenerator( + gen_x_times=1.1, + cat_cols=["cat"], + gen_params={"batch_size": 50, "patience": 5, "epochs": 2}, + ).generate_data_pipe(train, target, test) + + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_privacy_metrics.py b/tests/test_privacy_metrics.py new file mode 100644 index 0000000..8e9e9ce --- /dev/null +++ b/tests/test_privacy_metrics.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +"""Tests for privacy metrics.""" + +import unittest 
+ +import numpy as np +import pandas as pd + +from src.tabgan.privacy_metrics import PrivacyMetrics + + +def _make_numeric_data(seed=42): + rng = np.random.RandomState(seed) + original = pd.DataFrame(rng.normal(0, 1, size=(100, 4)), columns=list("ABCD")) + synthetic = pd.DataFrame(rng.normal(0, 1, size=(80, 4)), columns=list("ABCD")) + return original, synthetic + + +class TestDCR(unittest.TestCase): + def test_dcr_returns_expected_keys(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.dcr() + self.assertIn("mean", result) + self.assertIn("median", result) + self.assertIn("5th_percentile", result) + self.assertIn("distances", result) + + def test_identical_data_dcr_near_zero(self): + orig, _ = _make_numeric_data() + pm = PrivacyMetrics(orig, orig.copy()) + result = pm.dcr() + self.assertAlmostEqual(result["mean"], 0.0, places=5) + + def test_distant_data_dcr_positive(self): + rng = np.random.RandomState(42) + orig = pd.DataFrame(rng.normal(0, 1, size=(100, 3)), columns=list("ABC")) + synth = pd.DataFrame(rng.normal(10, 1, size=(80, 3)), columns=list("ABC")) + pm = PrivacyMetrics(orig, synth) + result = pm.dcr() + self.assertGreater(result["mean"], 1.0) + + def test_dcr_with_sample_size(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.dcr(sample_size=20) + self.assertEqual(len(result["distances"]), 20) + + +class TestNNDR(unittest.TestCase): + def test_nndr_returns_expected_keys(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.nndr() + self.assertIn("mean", result) + self.assertIn("median", result) + self.assertIn("ratios", result) + + def test_nndr_values_between_0_and_1(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.nndr() + self.assertGreater(result["mean"], 0.0) + self.assertLessEqual(result["mean"], 1.0) + + +class TestMembershipInference(unittest.TestCase): + def 
test_mi_returns_expected_keys(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.membership_inference_risk() + self.assertIn("auc", result) + self.assertIn("accuracy", result) + + def test_mi_auc_in_range(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.membership_inference_risk() + self.assertGreaterEqual(result["auc"], 0.0) + self.assertLessEqual(result["auc"], 1.0) + + +class TestSummary(unittest.TestCase): + def test_summary_returns_overall_score(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + s = pm.summary() + self.assertIn("overall_privacy_score", s) + self.assertGreaterEqual(s["overall_privacy_score"], 0.0) + self.assertLessEqual(s["overall_privacy_score"], 1.0) + + def test_summary_contains_all_sections(self): + orig, synth = _make_numeric_data() + s = PrivacyMetrics(orig, synth).summary() + self.assertIn("dcr", s) + self.assertIn("nndr", s) + self.assertIn("membership_inference", s) + + def test_with_cat_cols(self): + rng = np.random.RandomState(42) + orig = pd.DataFrame({ + "num": rng.normal(0, 1, 80), + "cat": rng.choice(["A", "B", "C"], 80), + }) + synth = pd.DataFrame({ + "num": rng.normal(0, 1, 60), + "cat": rng.choice(["A", "B", "C"], 60), + }) + s = PrivacyMetrics(orig, synth, cat_cols=["cat"]).summary() + self.assertIn("overall_privacy_score", s) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_quality_report.py b/tests/test_quality_report.py new file mode 100644 index 0000000..53eb00a --- /dev/null +++ b/tests/test_quality_report.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +"""Tests for the quality report.""" + +import os +import tempfile +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.quality_report import QualityReport + + +def _make_data(seed=42): + rng = np.random.RandomState(seed) + original = pd.DataFrame({ + "A": rng.normal(0, 1, 100), + "B": rng.normal(5, 2, 100), 
+ "target": rng.randint(0, 2, 100), + }) + synthetic = pd.DataFrame({ + "A": rng.normal(0, 1, 80), + "B": rng.normal(5, 2, 80), + "target": rng.randint(0, 2, 80), + }) + return original, synthetic + + +class TestQualityReportCompute(unittest.TestCase): + def test_compute_returns_self(self): + orig, synth = _make_data() + report = QualityReport(orig, synth).compute() + self.assertIsInstance(report, QualityReport) + + def test_summary_has_required_keys(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertIn("column_stats", s) + self.assertIn("psi", s) + self.assertIn("ml_utility", s) + self.assertIn("overall_score", s) + + def test_overall_score_range(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertGreaterEqual(s["overall_score"], 0.0) + self.assertLessEqual(s["overall_score"], 1.0) + + def test_column_stats_numeric(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertIn("A", s["column_stats"]) + self.assertEqual(s["column_stats"]["A"]["dtype"], "numeric") + self.assertIn("orig_mean", s["column_stats"]["A"]) + + def test_column_stats_categorical(self): + orig = pd.DataFrame({"cat": ["a", "b", "a", "c"], "num": [1, 2, 3, 4]}) + synth = pd.DataFrame({"cat": ["a", "b", "b", "c"], "num": [1, 2, 3, 4]}) + s = QualityReport(orig, synth, cat_cols=["cat"]).compute().summary() + self.assertEqual(s["column_stats"]["cat"]["dtype"], "categorical") + + def test_psi_per_column(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertIn("A", s["psi"]) + self.assertIn("mean", s["psi"]) + + def test_ml_utility_with_target(self): + orig, synth = _make_data() + s = QualityReport(orig, synth, target_col="target").compute().summary() + ml = s["ml_utility"] + self.assertIn("tstr_auc", ml) + self.assertIn("trtr_auc", ml) + self.assertIn("utility_ratio", ml) + + def 
test_ml_utility_without_target(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertEqual(s["ml_utility"]["utility_ratio"], 0.0) + + +class TestQualityReportHTML(unittest.TestCase): + def test_to_html_creates_file(self): + orig, synth = _make_data() + report = QualityReport(orig, synth, target_col="target").compute() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "report.html") + report.to_html(path) + self.assertTrue(os.path.exists(path)) + with open(path) as f: + content = f.read() + self.assertIn("TabGAN Quality Report", content) + self.assertIn("Overall Score", content) + + def test_html_contains_charts(self): + orig, synth = _make_data() + report = QualityReport(orig, synth).compute() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "report.html") + report.to_html(path) + with open(path) as f: + content = f.read() + # Should contain base64 images + self.assertIn("data:image/png;base64", content) + + def test_auto_compute_on_to_html(self): + """to_html should auto-compute if compute() wasn't called.""" + orig, synth = _make_data() + report = QualityReport(orig, synth) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "report.html") + report.to_html(path) + self.assertTrue(os.path.exists(path)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 2ee45c7..d6ecec0 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -78,22 +78,23 @@ def test__validate_data(self): args = [self.train.head(), self.target.copy(), self.test] self.assertRaises(ValueError, self.sampler._validate_data, *args) - class TestSamplerGAN(TestCase): - def setUp(self): - self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) - self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) - self.test = 
pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) - self.gen = GANGenerator(gen_x_times=15) - self.sampler = self.gen.get_object_generator() - - def test_generate_data(self): - new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), - self.target.copy(), self.test) - gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df) - self.assertEqual(gen_train.shape[0], gen_target.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - self.assertTrue(gen_train.shape[0] > new_train.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + +class TestSamplerGAN(TestCase): + def setUp(self): + self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) + self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) + self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) + self.gen = GANGenerator(gen_x_times=15) + self.sampler = self.gen.get_object_generator() + + def test_generate_data(self): + new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), + self.target.copy(), self.test) + gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df, only_generated_data=False) + self.assertEqual(gen_train.shape[0], gen_target.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + self.assertTrue(gen_train.shape[0] > new_train.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) class TestSamplerLLMConditional(TestCase): @@ -426,37 +427,39 @@ def test_empty_text_or_conditional_columns_use_fallback_sampling(self): mock_conditional.assert_not_called() pd.testing.assert_frame_equal(new_train_df.reset_index(drop=True), dummy_generated.reset_index(drop=True)) - class TestSamplerSamplerDiffusion(TestCase): - def setUp(self): - self.train 
= pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) - self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) - self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) - self.gen = ForestDiffusionGenerator(gen_x_times=15) - self.sampler = self.gen.get_object_generator() - - def test_generate_data(self): - new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), - self.target.copy(), self.test) - gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df) - self.assertEqual(gen_train.shape[0], gen_target.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - self.assertTrue(gen_train.shape[0] > new_train.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - - class TestSamplerSamplerDiffusion(TestCase): - def setUp(self): - self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) - self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) - self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) - self.gen = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", - "max_length": 500}) - self.sampler = self.gen.get_object_generator() - - def test_generate_data(self): - new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), - self.target.copy(), self.test) - gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df) - self.assertEqual(gen_train.shape[0], gen_target.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - self.assertTrue(gen_train.shape[0] > new_train.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + +class TestSamplerDiffusion(TestCase): + def setUp(self): + self.train = pd.DataFrame(np.random.randint(-10, 
150, size=(50, 4)), columns=list('ABCD')) + self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) + self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) + self.gen = ForestDiffusionGenerator(gen_x_times=15) + self.sampler = self.gen.get_object_generator() + + def test_generate_data(self): + new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), + self.target.copy(), self.test) + gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df, only_generated_data=False) + self.assertEqual(gen_train.shape[0], gen_target.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + self.assertTrue(gen_train.shape[0] > new_train.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + + +class TestSamplerLLMDirect(TestCase): + def setUp(self): + self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) + self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) + self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) + self.gen = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", + "max_length": 500}) + self.sampler = self.gen.get_object_generator() + + def test_generate_data(self): + new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), + self.target.copy(), self.test) + gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df, only_generated_data=False) + self.assertEqual(gen_train.shape[0], gen_target.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + self.assertTrue(gen_train.shape[0] > new_train.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) diff --git a/tests/test_sklearn_transformer.py b/tests/test_sklearn_transformer.py new file mode 
100644 index 0000000..2f0449d --- /dev/null +++ b/tests/test_sklearn_transformer.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +"""Tests for the sklearn-compatible TabGANTransformer.""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.sampler import OriginalGenerator, GANGenerator +from src.tabgan.sklearn_transformer import TabGANTransformer +from src.tabgan.constraints import RangeConstraint + + +class TestTabGANTransformerBasic(unittest.TestCase): + def setUp(self): + rng = np.random.RandomState(42) + self.X = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + self.y = pd.Series(rng.randint(0, 2, 60), name="target") + + def test_fit_transform_returns_dataframe(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + result = t.fit_transform(self.X, self.y) + self.assertIsInstance(result, pd.DataFrame) + self.assertGreater(len(result), 0) + + def test_augmented_target_available(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + t.fit(self.X, self.y) + aug_y = t.get_augmented_target() + self.assertIsNotNone(aug_y) + + def test_augmented_shapes_aligned(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + X_aug = t.fit_transform(self.X, self.y) + y_aug = t.get_augmented_target() + self.assertEqual(len(X_aug), len(y_aug)) + + def test_without_target(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + result = t.fit_transform(self.X) + self.assertIsInstance(result, pd.DataFrame) + self.assertIsNone(t.get_augmented_target()) + + def test_transform_at_inference_passthrough(self): + """After first transform, subsequent transforms pass data through.""" + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + t.fit(self.X, self.y) + _ = t.transform(self.X) # First transform consumes augmented data + result2 = t.transform(self.X) # Second should pass through + 
self.assertEqual(len(result2), len(self.X)) + + +class TestTabGANTransformerWithGAN(unittest.TestCase): + def setUp(self): + rng = np.random.RandomState(42) + self.X = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + self.y = pd.Series(rng.randint(0, 2, 60), name="target") + + def test_gan_generator(self): + t = TabGANTransformer( + generator_class=GANGenerator, + gen_x_times=1.0, + gen_params={"batch_size": 50, "patience": 5, "epochs": 2}, + ) + result = t.fit_transform(self.X, self.y) + self.assertGreater(len(result), 0) + self.assertEqual(result.shape[1], self.X.shape[1]) + + +class TestTabGANTransformerWithConstraints(unittest.TestCase): + def test_constraints_applied(self): + rng = np.random.RandomState(42) + X = pd.DataFrame(rng.randint(-50, 200, size=(60, 3)), columns=list("ABC")) + y = pd.Series(rng.randint(0, 2, 60)) + + t = TabGANTransformer( + generator_class=OriginalGenerator, + gen_x_times=1.5, + constraints=[RangeConstraint("A", min_val=0, max_val=100)], + ) + result = t.fit_transform(X, y) + self.assertGreaterEqual(result["A"].min(), 0) + self.assertLessEqual(result["A"].max(), 100) + + +class TestTabGANTransformerSklearnCompat(unittest.TestCase): + def test_get_params(self): + t = TabGANTransformer(gen_x_times=2.0, cat_cols=["A"]) + params = t.get_params() + self.assertEqual(params["gen_x_times"], 2.0) + self.assertEqual(params["cat_cols"], ["A"]) + + def test_set_params(self): + t = TabGANTransformer(gen_x_times=1.0) + t.set_params(gen_x_times=3.0) + self.assertEqual(t.gen_x_times, 3.0) + + def test_in_sklearn_pipeline(self): + """Verify it can be placed in a Pipeline (doesn't run the full pipeline).""" + from sklearn.pipeline import Pipeline + from sklearn.ensemble import RandomForestClassifier + + pipe = Pipeline([ + ("augment", TabGANTransformer( + generator_class=OriginalGenerator, + gen_x_times=1.1, + )), + ("model", RandomForestClassifier(n_estimators=5, random_state=42)), + ]) + # Pipeline should be constructable and 
have proper steps + self.assertEqual(len(pipe.steps), 2) + self.assertEqual(pipe.steps[0][0], "augment") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index c0f2319..443a918 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,14 +4,14 @@ from tabgan.utils import make_two_digit, get_year_mnth_dt_from_date -class TestUtils(unittest.TestCase): +class TestMakeTwoDigit(unittest.TestCase): def test_make_two_digit(self): self.assertEqual(make_two_digit('1'), '01') self.assertEqual(make_two_digit('12'), '12') self.assertEqual(make_two_digit('123'), '123') -class TestUtils(unittest.TestCase): +class TestGetYearMonthDtFromDate(unittest.TestCase): def test_get_year_month_dt_from_date(self): # create a sample dataframe df = pd.DataFrame({ diff --git a/tests/test_utils_extended.py b/tests/test_utils_extended.py new file mode 100644 index 0000000..82de28a --- /dev/null +++ b/tests/test_utils_extended.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +""" +Extended tests for tabgan.utils — covers compare_dataframes, calculate_psi, +collect_dates, and the time-series round-trip workflow. 
+""" + +import unittest + +import numpy as np +import pandas as pd + +from tabgan.utils import ( + calculate_psi, + collect_dates, + compare_dataframes, + get_year_mnth_dt_from_date, +) + + +# --------------------------------------------------------------------------- +# calculate_psi +# --------------------------------------------------------------------------- +class TestCalculatePSI(unittest.TestCase): + def test_identical_distributions_returns_near_zero(self): + arr = np.random.RandomState(42).normal(0, 1, size=500) + psi = calculate_psi(arr, arr.copy(), buckets=10) + # PSI of identical arrays should be ~0 + self.assertAlmostEqual(float(psi), 0.0, places=3) + + def test_shifted_distribution_returns_positive(self): + rng = np.random.RandomState(42) + expected = rng.normal(0, 1, size=1000) + actual = rng.normal(2, 1, size=1000) # shifted mean + psi = calculate_psi(expected, actual, buckets=10) + self.assertGreater(float(psi), 0.1) + + def test_2d_array_axis0(self): + rng = np.random.RandomState(42) + expected = rng.normal(0, 1, size=(200, 3)) + actual = rng.normal(0, 1, size=(200, 3)) + psi_values = calculate_psi(expected, actual, buckets=10, axis=0) + self.assertEqual(len(psi_values), 3) + for v in psi_values: + self.assertGreaterEqual(v, 0.0) + + def test_quantile_bucket_type(self): + rng = np.random.RandomState(42) + expected = rng.normal(0, 1, size=500) + actual = rng.normal(0, 1, size=500) + psi = calculate_psi(expected, actual, buckettype="quantiles", buckets=10) + self.assertGreaterEqual(float(psi), 0.0) + + +# --------------------------------------------------------------------------- +# compare_dataframes +# --------------------------------------------------------------------------- +class TestCompareDataframes(unittest.TestCase): + def test_identical_dataframes_score_high(self): + df = pd.DataFrame({"a": range(100), "b": range(100, 200)}) + score = compare_dataframes(df.copy(), df.copy()) + self.assertGreater(score, 0.5) + self.assertLessEqual(score, 
1.0) + + def test_completely_different_columns_returns_zero(self): + df1 = pd.DataFrame({"a": [1, 2, 3]}) + df2 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + score = compare_dataframes(df1, df2) + self.assertEqual(score, 0.0) + + def test_similar_dataframes_score_moderate(self): + rng = np.random.RandomState(42) + df1 = pd.DataFrame({"x": rng.normal(0, 1, 200), "y": rng.normal(5, 2, 200)}) + df2 = pd.DataFrame({"x": rng.normal(0, 1, 200), "y": rng.normal(5, 2, 200)}) + score = compare_dataframes(df1, df2) + self.assertGreater(score, 0.0) + self.assertLessEqual(score, 1.0) + + def test_score_is_between_0_and_1(self): + rng = np.random.RandomState(0) + df1 = pd.DataFrame({"a": rng.randint(0, 100, 50)}) + df2 = pd.DataFrame({"a": rng.randint(0, 100, 50)}) + score = compare_dataframes(df1, df2) + self.assertGreaterEqual(score, 0.0) + self.assertLessEqual(score, 1.0) + + def test_with_only_numeric_columns(self): + rng = np.random.RandomState(10) + df1 = pd.DataFrame({"a": rng.randint(0, 50, 80), "b": rng.randint(0, 50, 80)}) + df2 = pd.DataFrame({"a": rng.randint(0, 50, 80), "b": rng.randint(0, 50, 80)}) + score = compare_dataframes(df1, df2) + self.assertGreaterEqual(score, 0.0) + self.assertLessEqual(score, 1.0) + + def test_with_mixed_types_raises_on_string_psi(self): + """compare_dataframes passes string columns to calculate_psi which + cannot handle them. 
This documents the current behavior.""" + df1 = pd.DataFrame({"num": [1, 2, 3, 4], "cat": ["a", "b", "a", "c"]}) + df2 = pd.DataFrame({"num": [1, 2, 5, 6], "cat": ["a", "b", "x", "y"]}) + with self.assertRaises(TypeError): + compare_dataframes(df1.copy(), df2.copy()) + + +# --------------------------------------------------------------------------- +# collect_dates (round-trip with get_year_mnth_dt_from_date) +# --------------------------------------------------------------------------- +class TestCollectDates(unittest.TestCase): + def test_round_trip(self): + """Decompose dates then reassemble — result should match originals.""" + dates = pd.to_datetime(["2022-01-15", "2023-06-01", "2021-12-31"]) + df = pd.DataFrame({"Date": dates, "value": [10, 20, 30]}) + decomposed = get_year_mnth_dt_from_date(df.copy(), "Date") + + # Drop the original Date column to simulate the generation pipeline + decomposed = decomposed.drop("Date", axis=1) + reassembled = collect_dates(decomposed) + + self.assertIn("Date", reassembled.columns) + self.assertNotIn("year", reassembled.columns) + self.assertNotIn("month", reassembled.columns) + self.assertNotIn("day", reassembled.columns) + + expected_dates = ["2022-01-15", "2023-06-01", "2021-12-31"] + self.assertEqual(list(reassembled["Date"]), expected_dates) + + def test_single_digit_month_day_padded(self): + """Months and days < 10 should be zero-padded.""" + df = pd.DataFrame({"year": [2020], "month": [3], "day": [5], "x": [1]}) + result = collect_dates(df) + self.assertEqual(result["Date"].iloc[0], "2020-03-05") + + +# --------------------------------------------------------------------------- +# Time-series generation workflow (integration-like) +# --------------------------------------------------------------------------- +class TestTimeSeriesWorkflow(unittest.TestCase): + def test_date_decompose_generate_collect(self): + """Full round-trip: decompose dates → OriginalGenerator → collect dates.""" + from tabgan.sampler import 
OriginalGenerator + + rng = np.random.RandomState(42) + train = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + min_date = pd.to_datetime("2020-01-01") + max_date = pd.to_datetime("2021-12-31") + d = (max_date - min_date).days + 1 + train["Date"] = min_date + pd.to_timedelta(rng.randint(d, size=60), unit="D") + train = get_year_mnth_dt_from_date(train, "Date") + + train_no_date = train.drop("Date", axis=1) + + new_train, _ = OriginalGenerator( + gen_x_times=1.1, + cat_cols=["year"], + is_post_process=True, + pregeneration_frac=2, + ).generate_data_pipe(train_no_date, None, train_no_date) + + self.assertIn("year", new_train.columns) + self.assertIn("month", new_train.columns) + self.assertIn("day", new_train.columns) + + new_train = collect_dates(new_train) + self.assertIn("Date", new_train.columns) + self.assertNotIn("year", new_train.columns)