From 4653fe27e19c5ac678a782085d636b6c8d832642 Mon Sep 17 00:00:00 2001 From: Insaf Ashrapov Date: Sat, 28 Mar 2026 05:05:41 +0000 Subject: [PATCH 1/2] Refactor codebase, add 4 new features, expand test coverage to 115 tests Refactoring: - Fix mutable default arguments in SamplerOriginal - Fix Warning() -> warnings.warn(), standardize f-strings - Fix wrong log message in abc_sampler.py (deep copy branch) - Fix nested/duplicate test classes (TestSamplerGAN, TestSamplerDiffusion) - Fix duplicate TestUtils class, FrequencyEncoder.fit signature - Fix make_two_digit() bug for 3+ char strings - Replace deprecated pkg_resources with importlib.metadata - DRY generator factories via _BaseGenerator base class - Replace print() with logging.debug() in compare_dataframes New features: - Constraints system (RangeConstraint, UniqueConstraint, FormulaConstraint, RegexConstraint) - Privacy metrics (DCR, NNDR, membership inference risk) - Quality report with HTML export (stats, PSI, correlations, ML utility) - sklearn Pipeline integration (TabGANTransformer) Other: - Professional README with centered badges, pipeline diagram, CLI docs - Add Python version classifiers, update python_requires to >= 3.9 - Add matplotlib, requests to dependencies - 115 tests passing (was 39) Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 337 ++++++++++++------------- setup.cfg | 10 +- src/tabgan/__init__.py | 30 ++- src/tabgan/abc_sampler.py | 19 +- src/tabgan/constraints.py | 157 ++++++++++++ src/tabgan/encoders.py | 2 +- src/tabgan/privacy_metrics.py | 225 +++++++++++++++++ src/tabgan/quality_report.py | 406 ++++++++++++++++++++++++++++++ src/tabgan/sampler.py | 116 ++++----- src/tabgan/sklearn_transformer.py | 149 +++++++++++ src/tabgan/utils.py | 11 +- tests/test_cli.py | 168 ++++++++++--- tests/test_constraints.py | 151 +++++++++++ tests/test_generate_data_pipe.py | 210 ++++++++++++++++ tests/test_privacy_metrics.py | 114 +++++++++ tests/test_quality_report.py | 117 +++++++++ tests/test_sampler.py | 103 ++++---- tests/test_sklearn_transformer.py | 116 +++++++++ tests/test_utils.py | 4 +- tests/test_utils_extended.py | 163 ++++++++++++ 20 files changed, 2277 insertions(+), 331 deletions(-) create mode 100644 src/tabgan/constraints.py create mode 100644 src/tabgan/privacy_metrics.py create mode 100644 src/tabgan/quality_report.py create mode 100644 src/tabgan/sklearn_transformer.py create mode 100644 tests/test_constraints.py create mode 100644 tests/test_generate_data_pipe.py create mode 100644 tests/test_privacy_metrics.py create mode 100644 tests/test_quality_report.py create mode 100644 tests/test_sklearn_transformer.py create mode 100644 tests/test_utils_extended.py diff --git a/README.md b/README.md index 5bc2be9..3ce678d 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,45 @@ -[![CodeFactor](https://www.codefactor.io/repository/github/diyago/tabular-data-generation/badge)](https://www.codefactor.io/repository/github/diyago/tabular-data-generation) -[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -[![Python Version](https://img.shields.io/pypi/pyversions/tabgan)](https://pypi.org/project/tabgan/) -[![PyPI Version](https://img.shields.io/pypi/v/tabgan.svg)](https://pypi.org/project/tabgan/) -[![Downloads](https://pepy.tech/badge/tabgan)](https://pepy.tech/project/tabgan) 
-[![CodeQL](https://github.com/diyago/Tabular-data-generation/workflows/CodeQL/badge.svg)](https://github.com/diyago/Tabular-data-generation/actions/workflows/codeql.yml)
+<div align="center">
+
+  <img src="..." alt="TabGAN logo">
+
-# TabGAN - Synthetic Tabular Data Generation
+  <h1>TabGAN</h1>
+
+  <p><strong>High-quality synthetic tabular data generation</strong></p>
+
-A powerful library for generating high-quality synthetic tabular data using state-of-the-art generative models including GANs, Diffusion models, and Large Language Models.
+
+  <p>
+    <a href="https://pypi.org/project/tabgan/"><img src="https://img.shields.io/pypi/v/tabgan.svg" alt="PyPI Version"></a>
+    <a href="https://pypi.org/project/tabgan/"><img src="https://img.shields.io/pypi/pyversions/tabgan" alt="Python Version"></a>
+    <a href="https://pepy.tech/project/tabgan"><img src="https://pepy.tech/badge/tabgan" alt="Downloads"></a>
+    <a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License"></a>
+    <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="Code style: black"></a>
+    <a href="https://www.codefactor.io/repository/github/diyago/tabular-data-generation"><img src="https://www.codefactor.io/repository/github/diyago/tabular-data-generation/badge" alt="CodeFactor"></a>
+    <a href="https://github.com/diyago/Tabular-data-generation/actions/workflows/codeql.yml"><img src="https://github.com/diyago/Tabular-data-generation/workflows/CodeQL/badge.svg" alt="CodeQL"></a>
+  </p>
+
+</div>
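+
+## Constraints, Privacy Metrics, and Quality Report
+
+A minimal sketch of the new constraint, privacy-metric, and quality-report APIs
+introduced alongside the generators. Class and method names follow the new modules
+under `src/tabgan/`; the data and parameter values below are illustrative only.
+
+```python
+import numpy as np
+import pandas as pd
+from tabgan.sampler import GANGenerator
+from tabgan.constraints import RangeConstraint
+from tabgan.privacy_metrics import PrivacyMetrics
+from tabgan.quality_report import QualityReport
+
+train = pd.DataFrame(np.random.randint(0, 100, size=(150, 4)), columns=list("ABCD"))
+target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=["Y"])
+test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
+
+# Generate synthetic rows and enforce a business rule on the result
+new_train, new_target = GANGenerator().generate_data_pipe(
+    train, target, test,
+    constraints=[RangeConstraint("A", min_val=0, max_val=100)],
+)
+
+# Re-identification risk of the synthetic rows relative to the originals
+print(PrivacyMetrics(train, new_train).summary()["overall_privacy_score"])
+
+# Self-contained HTML report comparing original and synthetic data
+QualityReport(train, new_train).compute().to_html("quality_report.html")
+```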

- +--- ## Overview -TabGAN is a comprehensive Python library that provides a unified interface for generating high-quality synthetic tabular data. It integrates multiple state-of-the-art generative approaches to address diverse data synthesis requirements: +TabGAN provides a unified Python interface for generating synthetic tabular data using multiple state-of-the-art generative approaches: -- **GANs**: Conditional Tabular GAN (CTGAN) for modeling complex multivariate distributions with mixed data types -- **Diffusion Models**: Forest Diffusion for high-fidelity synthetic data generation with tree-based gradient boosting -- **LLMs**: GReaT (Generative Realistic Tabular data) framework leveraging language models for realistic tabular data synthesis -- **Time-Series**: TimeGAN support for temporal data generation preserving sequential dependencies +| Approach | Backend | Strengths | +|----------|---------|-----------| +| **GANs** | Conditional Tabular GAN (CTGAN) | Mixed data types, complex multivariate distributions | +| **Diffusion Models** | ForestDiffusion (tree-based gradient boosting) | High-fidelity generation for structured data | +| **Large Language Models** | GReaT framework | Capturing semantic dependencies, conditional text generation | +| **Baseline** | Random sampling with replacement | Quick benchmarking and comparison | -*Related Research: [Tabular GANs for uneven distribution (arXiv:2010.00638)](https://arxiv.org/abs/2010.00638)* +All generators share a common pipeline: **generate → post-process → adversarial filter**, ensuring synthetic data stays close to the real data distribution. + +*Based on the paper: [Tabular GANs for uneven distribution](https://arxiv.org/abs/2010.00638) (arXiv:2010.00638)* ## Key Features -- **Multiple Generative Architectures**: Seamlessly switch between GANs, Diffusion Models, and LLMs via a unified API -- **Adversarial Filtering**: Built-in adversarial validation to ensure synthetic data preserves predictive utility -- **Mixed Data Type Support**: Native handling of continuous, categorical, and text columns -- **Conditional Generation**: Generate data conditioned on specific column values or distributions -- **Scalable Processing**: Efficient batch processing for large-scale datasets -- **Quality Validation**: Integrated metrics for comparing synthetic against original data distributions +- **Unified API** — switch between GANs, diffusion models, and LLMs with a single parameter change +- **Adversarial filtering** — built-in LightGBM-based validation keeps synthetic samples distribution-consistent +- **Mixed data types** — native handling of continuous, categorical, and free-text columns +- **Conditional generation** — generate text conditioned on categorical attributes via LLM prompting +- **LLM API support** — integrate with LM Studio, OpenAI, Ollama, or any OpenAI-compatible endpoint +- **Quality validation** — compare original and synthetic distributions with a single function call ## Installation @@ -45,12 +54,10 @@ import pandas as pd import numpy as np from tabgan.sampler import GANGenerator -# Create sample data train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("ABCD")) target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y")) test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD")) -# Generate synthetic data new_train, new_target = GANGenerator().generate_data_pipe(train, target, test) ``` @@ -59,28 +66,28 @@ new_train, new_target = 
GANGenerator().generate_data_pipe(train, target, test) | Generator | Description | Best For | |-----------|-------------|----------| | `GANGenerator` | CTGAN-based generation | General tabular data with mixed types | -| `ForestDiffusionGenerator` | Diffusion models + tree-based methods | Complex tabular structures | -| `LLMGenerator` | Large Language Model based | Capturing complex dependencies | -| `OriginalGenerator` | Baseline sampler | Baseline comparisons | +| `ForestDiffusionGenerator` | Diffusion models with tree-based methods | Complex tabular structures | +| `LLMGenerator` | Large Language Model based | Semantic dependencies, text columns | +| `OriginalGenerator` | Baseline random sampler | Benchmarking and comparison | ## API Reference -### Sampler Parameters +### Common Parameters -All generators accept these common parameters: +All generators accept the following parameters: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `gen_x_times` | float | 1.1 | Multiplier for data generation amount | -| `cat_cols` | list | None | Column names to treat as categorical | -| `bot_filter_quantile` | float | 0.001 | Bottom quantile for post-process filtering | -| `top_filter_quantile` | float | 0.999 | Top quantile for post-process filtering | -| `is_post_process` | bool | True | Enable post-filtering | -| `pregeneration_frac` | float | 2 | Pre-generation data multiplier | -| `only_generated_data` | bool | False | Return only generated data (no original) | -| `gen_params` | dict | See below | Generator-specific parameters | +| `gen_x_times` | `float` | `1.1` | Multiplier for synthetic sample count relative to training size | +| `cat_cols` | `list` | `None` | Column names to treat as categorical | +| `bot_filter_quantile` | `float` | `0.001` | Lower quantile for post-processing filters | +| `top_filter_quantile` | `float` | `0.999` | Upper quantile for post-processing filters | +| `is_post_process` | `bool` | `True` | Enable quantile-based post-filtering | +| `pregeneration_frac` | `float` | `2` | Oversampling factor before filtering | +| `only_generated_data` | `bool` | `False` | Return only synthetic rows (exclude originals) | +| `gen_params` | `dict` | See below | Generator-specific hyperparameters | -### Generator-Specific Parameters +### Generator-Specific Parameters (`gen_params`) **GANGenerator:** ```python @@ -92,31 +99,33 @@ All generators accept these common parameters: {"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500} ``` -### generate_data_pipe Method +### `generate_data_pipe` Method -| Parameter | Type | Description | -|-----------|------|-------------| -| `train_df` | pd.DataFrame | Training features | -| `target` | pd.DataFrame | Target variable | -| `test_df` | pd.DataFrame | Test features | -| `deep_copy` | bool | Make copy of input dataframes | -| `only_adversarial` | bool | Only perform adversarial filtering | -| `use_adversarial` | bool | Enable adversarial filtering | +```python +new_train, new_target = generator.generate_data_pipe( + train_df, # pd.DataFrame - training features + target, # pd.DataFrame - target variable (or None) + test_df, # pd.DataFrame - test features for distribution alignment + deep_copy=True, # bool - copy input DataFrames + only_adversarial=False, # bool - skip generation, only filter + use_adversarial=True, # bool - enable adversarial filtering +) +``` -**Returns:** `Tuple[pd.DataFrame, pd.DataFrame]` - (new_train, new_target) +**Returns:** `Tuple[pd.DataFrame, pd.DataFrame]` — 
`(new_train, new_target)` ## Data Format -TabGAN accepts both `numpy.ndarray` and `pandas.DataFrame` inputs, supporting: +TabGAN accepts `pandas.DataFrame` inputs with: -- **Continuous Variables**: Numerical columns with any real-valued data -- **Categorical Variables**: Discrete columns with a finite set of possible values +- **Continuous columns** — any real-valued numerical data +- **Categorical columns** — discrete columns with a finite set of values -> **Note:** TabGAN internally processes all values as floating-point numbers. For integer-valued outputs, apply rounding after generation. +> **Note:** TabGAN processes values as floating-point internally. Apply rounding after generation for integer-valued outputs. ## Examples -### Basic Usage +### Basic Usage with All Generators ```python from tabgan.sampler import OriginalGenerator, GANGenerator, ForestDiffusionGenerator, LLMGenerator @@ -127,11 +136,14 @@ train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("A target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y")) test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD")) -# Different generators new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test) -new_train2, new_target2 = GANGenerator(gen_params={"batch_size": 500, "epochs": 10, "patience": 5}).generate_data_pipe(train, target, test) +new_train2, new_target2 = GANGenerator( + gen_params={"batch_size": 500, "epochs": 10, "patience": 5} +).generate_data_pipe(train, target, test) new_train3, new_target3 = ForestDiffusionGenerator().generate_data_pipe(train, target, test) -new_train4, new_target4 = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}).generate_data_pipe(train, target, test) +new_train4, new_target4 = LLMGenerator( + gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500} +).generate_data_pipe(train, target, test) ``` ### Full Parameter Example @@ -149,168 +161,129 @@ new_train, new_target = GANGenerator( }, pregeneration_frac=2, only_generated_data=False, - gen_params={"batch_size": 500, "patience": 25, "epochs": 500} + gen_params={"batch_size": 500, "patience": 25, "epochs": 500}, ).generate_data_pipe( train, target, test, deep_copy=True, only_adversarial=False, - use_adversarial=True + use_adversarial=True, ) ``` ### LLM Conditional Text Generation -Generate synthetic data with LLMs while controlling text generation based on specific conditions. This uses the internal `_generate_via_prompt` method for novel text generation. 
+Generate synthetic rows with novel text values conditioned on categorical attributes: ```python import pandas as pd from tabgan.sampler import LLMGenerator -# Create sample data with text and categorical columns train = pd.DataFrame({ "Name": ["Anna", "Maria", "Ivan", "Sergey", "Olga", "Boris"], "Gender": ["F", "F", "M", "M", "F", "M"], "Age": [25, 30, 35, 40, 28, 32], - "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"] + "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"], }) -# Generate new names conditioned on Gender, with other features imputed new_train, _ = LLMGenerator( - gen_x_times=1.5, # Generate 1.5x the original data - text_generating_columns=["Name"], # Generate novel names - conditional_columns=["Gender"], # Condition on Gender column + gen_x_times=1.5, + text_generating_columns=["Name"], # columns to generate novel text for + conditional_columns=["Gender"], # columns that condition text generation gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}, - is_post_process=False # Disable post-processing for this example -).generate_data_pipe( - train, - target=None, - test_df=None, - only_generated_data=True # Return only generated data -) - -print(new_train) + is_post_process=False, +).generate_data_pipe(train, target=None, test_df=None, only_generated_data=True) ``` -**Parameters for conditional generation:** -- `text_generating_columns`: List of column names to generate novel text for -- `conditional_columns`: List of column names that condition the text generation - -The model will: -1. Sample values for conditional columns from their distributions -2. Impute remaining non-text columns using the LLM -3. Generate novel text for text columns via prompt-based generation (using `_generate_via_prompt`) -4. Ensure generated text values are unique (not present in original data) +**How it works:** +1. Sample conditional column values from their empirical distributions +2. Impute remaining non-text columns using the fitted GReaT model +3. Generate novel text via prompt-based generation +4. Ensure generated text values differ from the original data ### LLM API-Based Text Generation -Use external LLM APIs (LM Studio, OpenAI, Ollama) for text generation instead of local models. This allows you to leverage powerful models running on remote servers or local API endpoints. 
+Use external LLM APIs (LM Studio, OpenAI, Ollama) instead of local models: ```python import pandas as pd from tabgan.sampler import LLMGenerator from tabgan.llm_config import LLMAPIConfig -# Create sample data train = pd.DataFrame({ "Name": ["Anna", "Maria", "Ivan", "Sergey", "Olga", "Boris"], "Gender": ["F", "F", "M", "M", "F", "M"], "Age": [25, 30, 35, 40, 28, 32], - "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"] + "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"], }) -# Configure API connection (LM Studio example) +# LM Studio api_config = LLMAPIConfig.from_lm_studio( base_url="http://localhost:1234", model="google/gemma-3-12b", - timeout=90 + timeout=90, ) -# Or use OpenAI -# api_config = LLMAPIConfig.from_openai( -# api_key="your-api-key", -# model="gpt-4" -# ) - -# Or use Ollama -# api_config = LLMAPIConfig.from_ollama( -# base_url="http://localhost:11434", -# model="llama3" -# ) +# Or OpenAI: LLMAPIConfig.from_openai(api_key="...", model="gpt-4") +# Or Ollama: LLMAPIConfig.from_ollama(model="llama3") -# Generate with API-based text generation new_train, _ = LLMGenerator( gen_x_times=1.5, text_generating_columns=["Name"], conditional_columns=["Gender"], gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}, - llm_api_config=api_config, # Use external API for text generation - is_post_process=False -).generate_data_pipe( - train, - target=None, - test_df=None, - only_generated_data=True -) - -print(new_train) + llm_api_config=api_config, + is_post_process=False, +).generate_data_pipe(train, target=None, test_df=None, only_generated_data=True) ``` -**Configuration Options:** +
+LLM API Configuration Options | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `base_url` | str | `"http://localhost:1234"` | Base URL for the API server | -| `model` | str | `"google/gemma-3-12b"` | Model identifier to use | -| `api_key` | str | None | API key for authentication (OpenAI, etc.) | -| `timeout` | int | 90 | Request timeout in seconds | -| `max_tokens` | int | 256 | Maximum tokens to generate | -| `temperature` | float | 0.7 | Sampling temperature (0.0-2.0) | -| `system_prompt` | str | None | System prompt to guide generation | - -**Supported API Providers:** -- **LM Studio**: Local LLM server with OpenAI-compatible API -- **OpenAI**: GPT-4, GPT-3.5, and other OpenAI models -- **Ollama**: Local LLM server for running open-source models -- **Any OpenAI-compatible API**: Custom endpoints with compatible schema - -**Testing API Connection:** +| `base_url` | `str` | `"http://localhost:1234"` | API server base URL | +| `model` | `str` | `"google/gemma-3-12b"` | Model identifier | +| `api_key` | `str` | `None` | API key for authentication | +| `timeout` | `int` | `90` | Request timeout in seconds | +| `max_tokens` | `int` | `256` | Maximum tokens to generate | +| `temperature` | `float` | `0.7` | Sampling temperature | +| `system_prompt` | `str` | `None` | System prompt for generation | + +**Testing the connection:** ```python from tabgan.llm_config import LLMAPIConfig from tabgan.llm_api_client import LLMAPIClient -# Test if API is accessible config = LLMAPIConfig.from_lm_studio() with LLMAPIClient(config) as client: - is_connected = client.check_connection() - print(f"API available: {is_connected}") - - # Generate text directly - text = client.generate("Generate a female name: ") - print(f"Generated: {text}") + print(f"API available: {client.check_connection()}") + print(f"Generated: {client.generate('Generate a female name: ')}") ``` +
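+
+A minimal sketch of configuring any other OpenAI-compatible endpoint, assuming
+`LLMAPIConfig` also accepts the options from the table above as keyword arguments
+(the documented entry points are the `from_lm_studio`, `from_openai`, and
+`from_ollama` helpers):
+
+```python
+from tabgan.llm_config import LLMAPIConfig
+
+# Hypothetical direct construction; field names follow the options table above.
+config = LLMAPIConfig(
+    base_url="http://localhost:8000",   # any OpenAI-compatible server
+    model="my-local-model",             # placeholder model identifier
+    api_key=None,                       # set for providers that require authentication
+    timeout=90,
+    max_tokens=256,
+    temperature=0.7,
+    system_prompt="Return a single realistic value, nothing else.",
+)
+```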
+ ### Improving Model Performance ```python import sklearn +import pandas as pd from tabgan.sampler import GANGenerator def evaluate(clf, X_train, y_train, X_test, y_test): clf.fit(X_train, y_train) return sklearn.metrics.roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]) -# Load dataset dataset = sklearn.datasets.load_breast_cancer() clf = sklearn.ensemble.RandomForestClassifier(n_estimators=25, max_depth=6) X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - pd.DataFrame(dataset.data), pd.DataFrame(dataset.target, columns=["target"]), - test_size=0.33, random_state=42) + pd.DataFrame(dataset.data), + pd.DataFrame(dataset.target, columns=["target"]), + test_size=0.33, random_state=42, +) -# Compare performance print("Baseline:", evaluate(clf, X_train, y_train, X_test, y_test)) -# Generate and evaluate new_train, new_target = GANGenerator().generate_data_pipe(X_train, y_train, X_test) print("With GAN:", evaluate(clf, new_train, new_target, X_test, y_test)) ``` @@ -324,42 +297,69 @@ from tabgan.utils import get_year_mnth_dt_from_date, collect_dates from tabgan.sampler import GANGenerator train = pd.DataFrame(np.random.randint(-10, 150, size=(100, 4)), columns=list("ABCD")) -min_date, max_date = pd.to_datetime('2019-01-01'), pd.to_datetime('2021-12-31') +min_date, max_date = pd.to_datetime("2019-01-01"), pd.to_datetime("2021-12-31") d = (max_date - min_date).days + 1 -train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=100), unit='d') -train = get_year_mnth_dt_from_date(train, 'Date') +train["Date"] = min_date + pd.to_timedelta(np.random.randint(d, size=100), unit="d") +train = get_year_mnth_dt_from_date(train, "Date") new_train, _ = GANGenerator( - gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001, - top_filter_quantile=0.999, is_post_process=True, pregeneration_frac=2, - only_generated_data=False -).generate_data_pipe( - train.drop('Date', axis=1), None, train.drop('Date', axis=1) -) + gen_x_times=1.1, cat_cols=["year"], + bot_filter_quantile=0.001, top_filter_quantile=0.999, + is_post_process=True, pregeneration_frac=2, +).generate_data_pipe(train.drop("Date", axis=1), None, train.drop("Date", axis=1)) + new_train = collect_dates(new_train) ``` ## Data Quality Validation -Validate the statistical fidelity of generated data using the built-in evaluation utilities: +Evaluate synthetic data fidelity with the built-in comparison utility: ```python from tabgan.utils import compare_dataframes -# Returns a quality score between 0 (low fidelity) and 1 (high fidelity) -quality_score = compare_dataframes(original_df, generated_df) +score = compare_dataframes(original_df, generated_df) # 0.0 (poor) to 1.0 (excellent) ``` -### Experimental Workflow +## Command-Line Interface + +```bash +tabgan-generate \ + --input-csv train.csv \ + --target-col target \ + --generator gan \ + --gen-x-times 1.5 \ + --cat-cols year,gender \ + --output-csv synthetic_train.csv +``` -![Experiment design and workflow](images/workflow.png) +## Pipeline Architecture + +``` +Input (train_df, target, test_df) + | + v +[Preprocess] --> Validate DataFrames, prepare columns + | + v +[Generate] --> CTGAN / ForestDiffusion / GReaT LLM / Random sampling + | + v +[Post-process] --> Quantile-based filtering against test distribution + | + v +[Adversarial Filter] --> LightGBM classifier removes dissimilar samples + | + v +Output (synthetic_df, synthetic_target) +``` ## Benchmark Results -The following table shows normalized ROC AUC scores (higher is better): +Normalized ROC AUC 
scores (higher is better): -| Dataset | None | GAN | Sample Original | -|---------|------|-----|-----------------| +| Dataset | No augmentation | GAN | Sample Original | +|---------|:-:|:-:|:-:| | credit | 0.997 | **0.998** | 0.997 | | employee | **0.986** | 0.966 | 0.972 | | mortgages | 0.984 | 0.964 | **0.988** | @@ -369,29 +369,24 @@ The following table shows normalized ROC AUC scores (higher is better): ## Citation -If you use TabGAN in your research, please cite: - ```bibtex @misc{ashrapov2020tabular, - title={Tabular GANs for uneven distribution}, - author={Insaf Ashrapov}, - year={2020}, - eprint={2010.00638}, - archivePrefix={arXiv}, - primaryClass={cs.LG} + title={Tabular GANs for uneven distribution}, + author={Insaf Ashrapov}, + year={2020}, + eprint={2010.00638}, + archivePrefix={arXiv}, + primaryClass={cs.LG} } ``` ## References -[1] Xu, L., & Veeramachaneni, K. (2018). *Synthesizing Tabular Data using Generative Adversarial Networks*. arXiv:1811.11264 [cs.LG]. - -[2] Jolicoeur-Martineau, A., Fatras, K., & Kachman, T. (2023). *Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees*. SamsungSAILMontreal/ForestDiffusion. - -[3] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). *Modeling Tabular data using Conditional GAN*. NeurIPS. - -[4] Borisov, V., Sessler, K., Leemann, T., Pawelczyk, M., & Kasneci, G. (2023). *Language Models are Realistic Tabular Data Generators*. ICLR. +1. Xu, L., & Veeramachaneni, K. (2018). *Synthesizing Tabular Data using Generative Adversarial Networks*. arXiv:1811.11264. +2. Jolicoeur-Martineau, A., Fatras, K., & Kachman, T. (2023). *Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees*. SamsungSAILMontreal/ForestDiffusion. +3. Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). *Modeling Tabular data using Conditional GAN*. NeurIPS. +4. Borisov, V., Sessler, K., Leemann, T., Pawelczyk, M., & Kasneci, G. (2023). *Language Models are Realistic Tabular Data Generators*. ICLR. ## License -Apache License 2.0 - See [LICENSE](LICENSE) file for details. +Apache License 2.0 — see [LICENSE](LICENSE) for details. diff --git a/setup.cfg b/setup.cfg index eb3e55c..351b1e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,12 @@ platforms = any classifiers = Development Status :: 5 - Production/Stable Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 [options] zip_safe = False @@ -35,7 +41,7 @@ setup_requires = # tests_require = pytest; pytest-cov # Require a specific Python version, e.g. 
Python 2.7 or >= 3.4 # python_requires = >= 3.4 -python_requires = >= 3.5 +python_requires = >= 3.9 install_requires = pandas numpy>=2.0 @@ -48,6 +54,8 @@ install_requires = tqdm xgboost be-great>=0.0.13 + matplotlib>=3.5 + requests [options.packages.find] where = src diff --git a/src/tabgan/__init__.py b/src/tabgan/__init__.py index 7970967..7a39968 100644 --- a/src/tabgan/__init__.py +++ b/src/tabgan/__init__.py @@ -1,8 +1,19 @@ # -*- coding: utf-8 -*- -from pkg_resources import DistributionNotFound, get_distribution +from importlib.metadata import version, PackageNotFoundError from .sampler import OriginalGenerator, Sampler, GANGenerator, ForestDiffusionGenerator, LLMGenerator from .llm_config import LLMAPIConfig from .llm_api_client import LLMAPIClient +from .constraints import ( + Constraint, + RangeConstraint, + UniqueConstraint, + FormulaConstraint, + RegexConstraint, + ConstraintEngine, +) +from .privacy_metrics import PrivacyMetrics +from .quality_report import QualityReport +from .sklearn_transformer import TabGANTransformer __all__ = [ "OriginalGenerator", @@ -12,13 +23,18 @@ "LLMGenerator", "LLMAPIConfig", "LLMAPIClient", + "Constraint", + "RangeConstraint", + "UniqueConstraint", + "FormulaConstraint", + "RegexConstraint", + "ConstraintEngine", + "PrivacyMetrics", + "QualityReport", + "TabGANTransformer", ] try: - # Change here if project is renamed and does not equal the package name - dist_name = __name__ - __version__ = get_distribution(dist_name).version -except DistributionNotFound: + __version__ = version(__name__) +except PackageNotFoundError: __version__ = "unknown" -finally: - del get_distribution, DistributionNotFound diff --git a/src/tabgan/abc_sampler.py b/src/tabgan/abc_sampler.py index f8ae341..fd27bd5 100644 --- a/src/tabgan/abc_sampler.py +++ b/src/tabgan/abc_sampler.py @@ -1,7 +1,7 @@ import gc import logging from abc import ABC, abstractmethod -from typing import Tuple +from typing import List, Optional, Tuple from .utils import seed_everything import pandas as pd @@ -30,6 +30,7 @@ def generate_data_pipe( only_adversarial: bool = False, use_adversarial: bool = True, only_generated_data: bool = False, + constraints: Optional[List] = None, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Defines logic for sampling @@ -41,6 +42,7 @@ def generate_data_pipe( @param use_adversarial: perform or not adversarial filtering @param only_generated_data: After generation get only newly generated, without concating input train dataframe. Only works for SamplerGAN or ForestDiffusionGenerator. + @param constraints: Optional list of Constraint instances to enforce on generated data. 
@return: Newly generated train dataframe and test data """ seed_everything() @@ -55,7 +57,7 @@ def generate_data_pipe( train_df.copy(), target.copy(), test_df ) else: - logging.info("Preprocessing input data with deep copying input data.") + logging.info("Preprocessing input data without deep copying.") new_train, new_target, test_df = generator.preprocess_data( train_df, target, test_df ) @@ -76,6 +78,19 @@ def generate_data_pipe( new_train, new_target = generator.adversarial_filtering( new_train, new_target, test_df ) + if constraints: + from .constraints import ConstraintEngine + logging.info("Applying constraints") + engine = ConstraintEngine(constraints, strategy="fix") + # Temporarily attach target to keep rows aligned + target_col = "__constraint_target__" + if new_target is not None: + new_train[target_col] = new_target.values if hasattr(new_target, 'values') else new_target + new_train = engine.apply(new_train) + if new_target is not None: + new_target = new_train[target_col].reset_index(drop=True) + new_train = new_train.drop(columns=[target_col]).reset_index(drop=True) + gc.collect() logging.info("Total finishing, returning data") diff --git a/src/tabgan/constraints.py b/src/tabgan/constraints.py new file mode 100644 index 0000000..8cb7a2c --- /dev/null +++ b/src/tabgan/constraints.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +""" +Constraint system for enforcing business rules on generated data. + +Constraints are applied as a post-generation step — after the main +generation pipeline produces synthetic rows, the ConstraintEngine filters +or repairs rows that violate the declared rules. +""" + +import logging +import re +from abc import ABC, abstractmethod +from typing import List, Optional + +import pandas as pd + +__all__ = [ + "Constraint", + "RangeConstraint", + "UniqueConstraint", + "FormulaConstraint", + "RegexConstraint", + "ConstraintEngine", +] + + +class Constraint(ABC): + """Base class for data constraints.""" + + @abstractmethod + def is_satisfied(self, df: pd.DataFrame) -> pd.Series: + """Return a boolean Series — True for rows that satisfy the constraint.""" + raise NotImplementedError + + @abstractmethod + def fix(self, df: pd.DataFrame) -> pd.DataFrame: + """Attempt to repair violating rows in-place and return the DataFrame.""" + raise NotImplementedError + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class RangeConstraint(Constraint): + """Enforce numeric column values within [min_val, max_val].""" + + def __init__(self, column: str, min_val: float = None, max_val: float = None): + if min_val is None and max_val is None: + raise ValueError("At least one of min_val or max_val must be specified") + self.column = column + self.min_val = min_val + self.max_val = max_val + + def is_satisfied(self, df: pd.DataFrame) -> pd.Series: + col = df[self.column] + mask = pd.Series(True, index=df.index) + if self.min_val is not None: + mask &= col >= self.min_val + if self.max_val is not None: + mask &= col <= self.max_val + return mask + + def fix(self, df: pd.DataFrame) -> pd.DataFrame: + df = df.copy() + df[self.column] = df[self.column].clip(lower=self.min_val, upper=self.max_val) + return df + + def __repr__(self) -> str: + return f"RangeConstraint(column={self.column!r}, min={self.min_val}, max={self.max_val})" + + +class UniqueConstraint(Constraint): + """Enforce uniqueness of values in a column (drop duplicate rows).""" + + def __init__(self, column: str): + self.column = column + + def is_satisfied(self, df: pd.DataFrame) -> pd.Series: + 
return ~df[self.column].duplicated(keep="first") + + def fix(self, df: pd.DataFrame) -> pd.DataFrame: + return df.drop_duplicates(subset=[self.column], keep="first").reset_index(drop=True) + + def __repr__(self) -> str: + return f"UniqueConstraint(column={self.column!r})" + + +class FormulaConstraint(Constraint): + """Enforce a boolean expression evaluated via ``pd.DataFrame.eval``. + + Example expressions: + - ``"end_date > start_date"`` + - ``"price * quantity == total"`` + - ``"age >= 0"`` + """ + + def __init__(self, expression: str): + self.expression = expression + + def is_satisfied(self, df: pd.DataFrame) -> pd.Series: + return df.eval(self.expression) + + def fix(self, df: pd.DataFrame) -> pd.DataFrame: + mask = self.is_satisfied(df) + return df[mask].reset_index(drop=True) + + def __repr__(self) -> str: + return f"FormulaConstraint({self.expression!r})" + + +class RegexConstraint(Constraint): + """Enforce that string values in a column match a regular expression.""" + + def __init__(self, column: str, pattern: str): + self.column = column + self.pattern = pattern + self._compiled = re.compile(pattern) + + def is_satisfied(self, df: pd.DataFrame) -> pd.Series: + return df[self.column].astype(str).str.fullmatch(self.pattern).fillna(False) + + def fix(self, df: pd.DataFrame) -> pd.DataFrame: + mask = self.is_satisfied(df) + return df[mask].reset_index(drop=True) + + def __repr__(self) -> str: + return f"RegexConstraint(column={self.column!r}, pattern={self.pattern!r})" + + +class ConstraintEngine: + """Apply a list of constraints to a DataFrame. + + Args: + constraints: List of ``Constraint`` instances to enforce. + strategy: ``"filter"`` drops violating rows; ``"fix"`` attempts + repair first, then filters remaining violations. + """ + + def __init__(self, constraints: List[Constraint], strategy: str = "filter"): + if strategy not in ("filter", "fix"): + raise ValueError(f"strategy must be 'filter' or 'fix', got {strategy!r}") + self.constraints = constraints + self.strategy = strategy + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + initial_len = len(df) + for constraint in self.constraints: + if self.strategy == "fix": + df = constraint.fix(df) + # After fix (or directly if filter), drop remaining violations + mask = constraint.is_satisfied(df) + df = df[mask].reset_index(drop=True) + + dropped = initial_len - len(df) + if dropped > 0: + logging.info(f"ConstraintEngine: dropped {dropped} rows ({dropped / initial_len:.1%})") + return df diff --git a/src/tabgan/encoders.py b/src/tabgan/encoders.py index 2c209f2..626515a 100644 --- a/src/tabgan/encoders.py +++ b/src/tabgan/encoders.py @@ -228,7 +228,7 @@ def __init__(self, cols): self.cols = cols self.counts_dict = None - def fit(self, X: pd.DataFrame): + def fit(self, X: pd.DataFrame, y=None): counts_dict = {} for col in self.cols: values, counts = np.unique(X[col], return_counts=True) diff --git a/src/tabgan/privacy_metrics.py b/src/tabgan/privacy_metrics.py new file mode 100644 index 0000000..a4d6b93 --- /dev/null +++ b/src/tabgan/privacy_metrics.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +""" +Privacy metrics for assessing re-identification risk in synthetic data. + +Provides Distance to Closest Record (DCR), Nearest Neighbor Distance Ratio +(NNDR), and a membership inference risk score. 
+""" + +import logging +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import OrdinalEncoder, StandardScaler + +__all__ = ["PrivacyMetrics"] + + +def _encode_for_distance( + original: pd.DataFrame, + synthetic: pd.DataFrame, + cat_cols: Optional[List[str]] = None, +) -> tuple: + """Encode and scale DataFrames for distance computation.""" + original = original.copy() + synthetic = synthetic.copy() + + if cat_cols: + encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1) + original[cat_cols] = encoder.fit_transform(original[cat_cols].astype(str)) + synthetic[cat_cols] = encoder.transform(synthetic[cat_cols].astype(str)) + + # Fill NaN with column medians + for col in original.columns: + if original[col].isna().any(): + med = original[col].median() + original[col] = original[col].fillna(med) + synthetic[col] = synthetic[col].fillna(med) + + scaler = StandardScaler() + orig_scaled = scaler.fit_transform(original.select_dtypes(include=[np.number])) + synth_scaled = scaler.transform(synthetic.select_dtypes(include=[np.number])) + + return orig_scaled, synth_scaled + + +class PrivacyMetrics: + """Evaluate privacy risk of synthetic data relative to original data. + + Args: + original_df: The real / training DataFrame. + synthetic_df: The generated / synthetic DataFrame. + cat_cols: Names of categorical columns (encoded before distance computation). + + Example:: + + from tabgan.privacy_metrics import PrivacyMetrics + pm = PrivacyMetrics(original_df, synthetic_df, cat_cols=["gender"]) + print(pm.summary()) + """ + + def __init__( + self, + original_df: pd.DataFrame, + synthetic_df: pd.DataFrame, + cat_cols: Optional[List[str]] = None, + ): + shared_cols = [c for c in original_df.columns if c in synthetic_df.columns] + self.original_df = original_df[shared_cols].copy() + self.synthetic_df = synthetic_df[shared_cols].copy() + self.cat_cols = [c for c in (cat_cols or []) if c in shared_cols] + self._orig_scaled, self._synth_scaled = _encode_for_distance( + self.original_df, self.synthetic_df, self.cat_cols + ) + + # ------------------------------------------------------------------ + # DCR — Distance to Closest Record + # ------------------------------------------------------------------ + def dcr(self, sample_size: Optional[int] = None) -> Dict: + """Compute the distance from each synthetic row to the nearest original row. + + Higher distances indicate better privacy (synthetic rows are not + trivially close to any real record). + + Returns: + dict with ``mean``, ``median``, ``5th_percentile``, and ``distances``. + """ + synth = self._synth_scaled + if sample_size and sample_size < len(synth): + idx = np.random.choice(len(synth), sample_size, replace=False) + synth = synth[idx] + + nn = NearestNeighbors(n_neighbors=1, algorithm="auto") + nn.fit(self._orig_scaled) + distances, _ = nn.kneighbors(synth) + distances = distances.ravel() + + return { + "mean": float(np.mean(distances)), + "median": float(np.median(distances)), + "5th_percentile": float(np.percentile(distances, 5)), + "distances": distances, + } + + # ------------------------------------------------------------------ + # NNDR — Nearest Neighbor Distance Ratio + # ------------------------------------------------------------------ + def nndr(self, sample_size: Optional[int] = None) -> Dict: + """Nearest-neighbor distance ratio for each synthetic row. 
+ + Ratio = dist(nearest_original) / dist(2nd_nearest_original). + A ratio close to 1 means the synthetic row is equidistant to + multiple originals (lower risk); a ratio near 0 means it is + suspiciously close to exactly one real record. + + Returns: + dict with ``mean``, ``median``, and ``ratios``. + """ + synth = self._synth_scaled + if sample_size and sample_size < len(synth): + idx = np.random.choice(len(synth), sample_size, replace=False) + synth = synth[idx] + + k = min(2, len(self._orig_scaled)) + nn = NearestNeighbors(n_neighbors=k, algorithm="auto") + nn.fit(self._orig_scaled) + distances, _ = nn.kneighbors(synth) + + if k < 2: + ratios = np.ones(len(synth)) + else: + d1 = distances[:, 0] + d2 = np.where(distances[:, 1] == 0, 1e-10, distances[:, 1]) + ratios = d1 / d2 + + return { + "mean": float(np.mean(ratios)), + "median": float(np.median(ratios)), + "ratios": ratios, + } + + # ------------------------------------------------------------------ + # Membership Inference Risk + # ------------------------------------------------------------------ + def membership_inference_risk(self, holdout_frac: float = 0.3) -> Dict: + """Estimate membership inference risk. + + Splits the original data into a *member* set (simulating the training + data the generator saw) and a *holdout* set. If the generator + memorised the members, synthetic rows will be closer to members than + to holdout rows. The risk is quantified as the AUC of a simple + classifier trained on this distance signal. + + Returns: + dict with ``auc`` (0.5 = good privacy, 1.0 = full memorisation) + and ``accuracy``. + """ + from sklearn.metrics import roc_auc_score + from sklearn.model_selection import cross_val_predict + from sklearn.linear_model import LogisticRegression + + n = len(self._orig_scaled) + n_holdout = max(int(n * holdout_frac), 1) + perm = np.random.permutation(n) + member_idx = perm[n_holdout:] + holdout_idx = perm[:n_holdout] + + members = self._orig_scaled[member_idx] + holdout = self._orig_scaled[holdout_idx] + + # For each original row, compute distance to nearest synthetic + nn = NearestNeighbors(n_neighbors=1, algorithm="auto") + nn.fit(self._synth_scaled) + + d_members, _ = nn.kneighbors(members) + d_holdout, _ = nn.kneighbors(holdout) + + X = np.concatenate([d_members.ravel(), d_holdout.ravel()]).reshape(-1, 1) + y = np.concatenate([np.ones(len(d_members)), np.zeros(len(d_holdout))]) + + if len(np.unique(y)) < 2: + return {"auc": 0.5, "accuracy": 0.5} + + clf = LogisticRegression(solver="lbfgs", max_iter=200) + try: + proba = cross_val_predict(clf, X, y, cv=min(3, len(y)), method="predict_proba")[:, 1] + auc = float(roc_auc_score(y, proba)) + except Exception: + auc = 0.5 + + accuracy = float(np.mean((proba > 0.5) == y)) if 'proba' in dir() else 0.5 + + return {"auc": auc, "accuracy": accuracy} + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + def summary(self) -> Dict: + """Aggregate all privacy metrics into a single report. + + The ``overall_privacy_score`` ranges from 0 (high risk) to 1 (private). 
+ """ + dcr_res = self.dcr() + nndr_res = self.nndr() + mi_res = self.membership_inference_risk() + + # Score components (each normalised to 0-1, higher = more private) + # DCR: 5th percentile > 0 is good; cap contribution at 1 + dcr_score = min(dcr_res["5th_percentile"], 1.0) + + # NNDR: mean closer to 1 is better + nndr_score = min(nndr_res["mean"], 1.0) + + # MI: AUC closer to 0.5 is better → score = 1 - 2*|AUC - 0.5| + mi_score = max(1.0 - 2.0 * abs(mi_res["auc"] - 0.5), 0.0) + + overall = 0.4 * dcr_score + 0.3 * nndr_score + 0.3 * mi_score + + return { + "dcr": {k: v for k, v in dcr_res.items() if k != "distances"}, + "nndr": {k: v for k, v in nndr_res.items() if k != "ratios"}, + "membership_inference": mi_res, + "overall_privacy_score": round(overall, 4), + } diff --git a/src/tabgan/quality_report.py b/src/tabgan/quality_report.py new file mode 100644 index 0000000..b12a7f5 --- /dev/null +++ b/src/tabgan/quality_report.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- +""" +Quality report for comparing original and synthetic DataFrames. + +Generates a self-contained HTML file with distribution comparisons, +correlation analysis, PSI scores, and an ML utility benchmark. +""" + +import base64 +import io +import logging +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score, accuracy_score +from sklearn.model_selection import cross_val_predict + +from tabgan.utils import calculate_psi + +__all__ = ["QualityReport"] + +_HTML_TEMPLATE = """\ + + + + +TabGAN Quality Report + + + +

+<h1>TabGAN Quality Report</h1>
+
+<p class="meta">
+  <b>Original:</b> {n_orig} rows × {n_cols} cols |
+  <b>Synthetic:</b> {n_synth} rows × {n_cols} cols
+</p>
+
+<div class="cards">
+  <div class="card">
+    <div class="card-title">Overall Score</div>
+    <div class="card-value">{overall_score:.2f}</div>
+  </div>
+  <div class="card">
+    <div class="card-title">ML Utility</div>
+    <div class="card-value">{ml_utility:.2f}</div>
+  </div>
+  <div class="card">
+    <div class="card-title">Mean PSI</div>
+    <div class="card-value">{mean_psi:.3f}</div>
+  </div>
+</div>
+
+<h2>Column Statistics</h2>
+{column_stats_html}
+
+<h2>PSI per Column</h2>
+{psi_html}
+
+<h2>Correlation Comparison</h2>
+<div class="gallery">
+{corr_images}
+</div>
+
+<h2>Distribution Comparison</h2>
+<div class="gallery">
+{dist_images}
+</div>
+
+<h2>ML Utility Detail</h2>
+{ml_detail_html}
+
+<footer>Generated by TabGAN QualityReport</footer>

+ + +""" + + +def _fig_to_base64(fig) -> str: + """Render a matplotlib figure to a base64-encoded PNG data URI.""" + buf = io.BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight", dpi=100) + buf.seek(0) + b64 = base64.b64encode(buf.read()).decode("utf-8") + buf.close() + return f"data:image/png;base64,{b64}" + + +class QualityReport: + """Compare original and synthetic DataFrames across multiple quality axes. + + Args: + original_df: Real data. + synthetic_df: Synthetic data (must have the same columns). + cat_cols: Names of categorical columns (affects stats display). + target_col: Optional target column name for ML utility evaluation. + + Example:: + + report = QualityReport(orig, synth, cat_cols=["gender"]).compute() + report.to_html("report.html") + print(report.summary()) + """ + + def __init__( + self, + original_df: pd.DataFrame, + synthetic_df: pd.DataFrame, + cat_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + ): + shared = [c for c in original_df.columns if c in synthetic_df.columns] + self.original = original_df[shared].copy() + self.synthetic = synthetic_df[shared].copy() + self.cat_cols = [c for c in (cat_cols or []) if c in shared] + self.target_col = target_col + self._results: Dict = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def compute(self) -> "QualityReport": + """Run all metrics and store results internally.""" + self._results["column_stats"] = self._column_stats() + self._results["psi"] = self._psi_per_column() + self._results["correlation"] = self._correlation_comparison() + self._results["ml_utility"] = self._ml_utility_score() + self._results["overall"] = self._overall_score() + return self + + def summary(self) -> Dict: + """Return a dict of all metric results (without images).""" + if not self._results: + self.compute() + return { + "column_stats": self._results["column_stats"], + "psi": {k: v for k, v in self._results["psi"].items()}, + "ml_utility": self._results["ml_utility"], + "overall_score": self._results["overall"], + } + + def to_html(self, path: str) -> None: + """Write a self-contained HTML report to *path*.""" + if not self._results: + self.compute() + + psi = self._results["psi"] + ml = self._results["ml_utility"] + + # Column stats table + stats = self._results["column_stats"] + col_rows = [] + for col, s in stats.items(): + if s["dtype"] == "numeric": + col_rows.append( + f"{col}numeric" + f"{s['orig_mean']:.3f}{s['synth_mean']:.3f}" + f"{s['orig_std']:.3f}{s['synth_std']:.3f}" + ) + else: + col_rows.append( + f"{col}categorical" + f"unique: {s['orig_nunique']}" + f"unique: {s['synth_nunique']}" + ) + col_stats_html = ( + "" + "" + "" + + "\n".join(col_rows) + "
ColumnTypeOrig Mean/InfoSynth Mean/InfoOrig Std/InfoSynth Std/Info
" + ) + + # PSI table + psi_rows = "".join( + f"{col}{val:.4f}" + for col, val in psi.items() if col != "mean" + ) + psi_html = ( + f"{psi_rows}
ColumnPSI
" + ) + + # Correlation images + corr_imgs = self._render_correlation_plots() + corr_html = "\n".join( + f'{label}' + for label, src in corr_imgs.items() + ) + + # Distribution images + dist_imgs = self._render_distribution_plots() + dist_html = "\n".join( + f'{col}' + for col, src in dist_imgs.items() + ) + + # ML utility detail + ml_html = ( + f"

Train-on-Synthetic, Test-on-Real AUC: {ml.get('tstr_auc', 'N/A')}

" + f"

Train-on-Real, Test-on-Real AUC (baseline): {ml.get('trtr_auc', 'N/A')}

" + f"

Utility ratio: {ml.get('utility_ratio', 'N/A')}

" + ) + + numeric_psi = [v for k, v in psi.items() if k != "mean" and isinstance(v, (int, float))] + mean_psi = float(np.mean(numeric_psi)) if numeric_psi else 0.0 + + html = _HTML_TEMPLATE.format( + n_orig=len(self.original), + n_synth=len(self.synthetic), + n_cols=len(self.original.columns), + overall_score=self._results["overall"], + ml_utility=ml.get("utility_ratio", 0.0), + mean_psi=mean_psi, + column_stats_html=col_stats_html, + psi_html=psi_html, + corr_images=corr_html, + dist_images=dist_html, + ml_detail_html=ml_html, + ) + + with open(path, "w", encoding="utf-8") as fh: + fh.write(html) + logging.info(f"Quality report written to {path}") + + # ------------------------------------------------------------------ + # Private metric methods + # ------------------------------------------------------------------ + def _column_stats(self) -> Dict: + stats = {} + for col in self.original.columns: + if col in self.cat_cols or not pd.api.types.is_numeric_dtype(self.original[col]): + stats[col] = { + "dtype": "categorical", + "orig_nunique": int(self.original[col].nunique()), + "synth_nunique": int(self.synthetic[col].nunique()), + } + else: + stats[col] = { + "dtype": "numeric", + "orig_mean": float(self.original[col].mean()), + "synth_mean": float(self.synthetic[col].mean()), + "orig_std": float(self.original[col].std()), + "synth_std": float(self.synthetic[col].std()), + "orig_min": float(self.original[col].min()), + "synth_min": float(self.synthetic[col].min()), + "orig_max": float(self.original[col].max()), + "synth_max": float(self.synthetic[col].max()), + } + return stats + + def _psi_per_column(self) -> Dict: + psi_dict = {} + numeric_cols = self.original.select_dtypes(include=[np.number]).columns + for col in numeric_cols: + try: + val = float(calculate_psi( + self.original[col].values, + self.synthetic[col].values, + buckets=10, + )) + except Exception: + val = float("nan") + psi_dict[col] = val + vals = [v for v in psi_dict.values() if np.isfinite(v)] + psi_dict["mean"] = float(np.mean(vals)) if vals else 0.0 + return psi_dict + + def _correlation_comparison(self) -> Dict: + numeric_orig = self.original.select_dtypes(include=[np.number]) + numeric_synth = self.synthetic.select_dtypes(include=[np.number]) + if numeric_orig.shape[1] < 2: + return {"orig_corr": None, "synth_corr": None, "diff_mean": 0.0} + orig_corr = numeric_orig.corr() + synth_corr = numeric_synth.corr() + diff = (orig_corr - synth_corr).abs() + return { + "orig_corr": orig_corr, + "synth_corr": synth_corr, + "diff_mean": float(diff.mean().mean()), + } + + def _ml_utility_score(self) -> Dict: + """Train-on-Synthetic, Test-on-Real (TSTR) evaluation.""" + from sklearn.ensemble import GradientBoostingClassifier + + target = self.target_col + if target is None or target not in self.original.columns or target not in self.synthetic.columns: + return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0} + + orig_num = self.original.select_dtypes(include=[np.number]) + synth_num = self.synthetic.select_dtypes(include=[np.number]) + + if target not in orig_num.columns: + return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0} + + X_real = orig_num.drop(columns=[target]).values + y_real = orig_num[target].values + X_synth = synth_num.drop(columns=[target]).values + y_synth = synth_num[target].values + + if len(np.unique(y_real)) < 2 or len(np.unique(y_synth)) < 2: + return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0} + + try: + clf_real = GradientBoostingClassifier(n_estimators=50, max_depth=3, 
random_state=42) + proba_real = cross_val_predict(clf_real, X_real, y_real, cv=3, method="predict_proba")[:, 1] + trtr_auc = float(roc_auc_score(y_real, proba_real)) + + clf_synth = GradientBoostingClassifier(n_estimators=50, max_depth=3, random_state=42) + clf_synth.fit(X_synth, y_synth) + proba_synth = clf_synth.predict_proba(X_real)[:, 1] + tstr_auc = float(roc_auc_score(y_real, proba_synth)) + + ratio = tstr_auc / trtr_auc if trtr_auc > 0 else 0.0 + except Exception as e: + logging.warning(f"ML utility evaluation failed: {e}") + return {"tstr_auc": None, "trtr_auc": None, "utility_ratio": 0.0} + + return { + "tstr_auc": round(tstr_auc, 4), + "trtr_auc": round(trtr_auc, 4), + "utility_ratio": round(min(ratio, 1.0), 4), + } + + def _overall_score(self) -> float: + psi_vals = [v for k, v in self._results["psi"].items() + if k != "mean" and isinstance(v, (int, float)) and np.isfinite(v)] + # PSI component: lower is better → score = 1/(1+mean_psi) + mean_psi = float(np.mean(psi_vals)) if psi_vals else 0.0 + psi_score = 1.0 / (1.0 + mean_psi) + + # Correlation component + corr_diff = self._results["correlation"]["diff_mean"] + corr_score = max(1.0 - corr_diff, 0.0) + + # ML utility component + ml_ratio = self._results["ml_utility"].get("utility_ratio", 0.0) or 0.0 + + overall = 0.35 * psi_score + 0.30 * corr_score + 0.35 * ml_ratio + return round(min(max(overall, 0.0), 1.0), 4) + + # ------------------------------------------------------------------ + # Chart renderers (require matplotlib — lazy import) + # ------------------------------------------------------------------ + def _render_correlation_plots(self) -> Dict[str, str]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + logging.warning("matplotlib not installed — skipping correlation plots") + return {} + + corr = self._results["correlation"] + if corr["orig_corr"] is None: + return {} + + images = {} + for label, matrix in [("Original", corr["orig_corr"]), ("Synthetic", corr["synth_corr"])]: + fig, ax = plt.subplots(figsize=(5, 4)) + im = ax.imshow(matrix.values, cmap="RdBu_r", vmin=-1, vmax=1) + ax.set_xticks(range(len(matrix.columns))) + ax.set_yticks(range(len(matrix.columns))) + ax.set_xticklabels(matrix.columns, rotation=45, ha="right", fontsize=7) + ax.set_yticklabels(matrix.columns, fontsize=7) + ax.set_title(f"{label} Correlation") + fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + images[label] = _fig_to_base64(fig) + plt.close(fig) + + return images + + def _render_distribution_plots(self) -> Dict[str, str]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + logging.warning("matplotlib not installed — skipping distribution plots") + return {} + + images = {} + for col in self.original.columns: + if not pd.api.types.is_numeric_dtype(self.original[col]): + continue + fig, ax = plt.subplots(figsize=(4, 3)) + ax.hist(self.original[col].dropna(), bins=30, alpha=0.5, label="Original", density=True) + ax.hist(self.synthetic[col].dropna(), bins=30, alpha=0.5, label="Synthetic", density=True) + ax.set_title(col, fontsize=10) + ax.legend(fontsize=7) + images[col] = _fig_to_base64(fig) + plt.close(fig) + + return images diff --git a/src/tabgan/sampler.py b/src/tabgan/sampler.py index a634d21..ce123cf 100644 --- a/src/tabgan/sampler.py +++ b/src/tabgan/sampler.py @@ -27,40 +27,32 @@ __all__ = ["OriginalGenerator", "GANGenerator", "ForestDiffusionGenerator", "LLMGenerator"] -class OriginalGenerator(SampleData): +class 
_BaseGenerator(SampleData): + """Base factory that stores constructor arguments for the concrete sampler.""" + _sampler_class = None + def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs def get_object_generator(self) -> Sampler: - return SamplerOriginal(*self.args, **self.kwargs) + return self._sampler_class(*self.args, **self.kwargs) -class GANGenerator(SampleData): - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs +class OriginalGenerator(_BaseGenerator): + _sampler_class = None # set after SamplerOriginal is defined - def get_object_generator(self) -> Sampler: - return SamplerGAN(*self.args, **self.kwargs) +class GANGenerator(_BaseGenerator): + _sampler_class = None -class ForestDiffusionGenerator(SampleData): - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - def get_object_generator(self) -> Sampler: - return SamplerDiffusion(*self.args, **self.kwargs) +class ForestDiffusionGenerator(_BaseGenerator): + _sampler_class = None -class LLMGenerator(SampleData): - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - def get_object_generator(self) -> Sampler: - return SamplerLLM(*self.args, **self.kwargs) +class LLMGenerator(_BaseGenerator): + _sampler_class = None class SamplerOriginal(Sampler): @@ -71,17 +63,10 @@ def __init__( bot_filter_quantile: float = 0.001, top_filter_quantile: float = 0.999, is_post_process: bool = True, - adversarial_model_params: dict = { - "metrics": "AUC", - "max_depth": 2, - "max_bin": 100, - "n_estimators": 150, - "learning_rate": 0.02, - "random_state": 42, - }, + adversarial_model_params: dict = None, pregeneration_frac: float = 2, only_generated_data: bool = False, - gen_params: dict = {"batch_size": 45, 'patience': 25, "epochs": 50, "llm": "distilgpt2"}, + gen_params: dict = None, text_generating_columns: list = None, conditional_columns: list = None, llm_api_config: LLMAPIConfig = None, @@ -121,6 +106,17 @@ def __init__( the API instead of the local model. Useful for LM Studio, Ollama, OpenAI, etc. 
""" + if adversarial_model_params is None: + adversarial_model_params = { + "metrics": "AUC", + "max_depth": 2, + "max_bin": 100, + "n_estimators": 150, + "learning_rate": 0.02, + "random_state": 42, + } + if gen_params is None: + gen_params = {"batch_size": 45, "patience": 25, "epochs": 50, "llm": "distilgpt2"} super().__init__( gen_x_times=gen_x_times, cat_cols=cat_cols, @@ -140,10 +136,10 @@ def __init__( @staticmethod def preprocess_data_df(df) -> pd.DataFrame: - logging.info("Input shape: {}".format(df.shape)) - if isinstance(df, pd.DataFrame) is False: + logging.info(f"Input shape: {df.shape}") + if not isinstance(df, pd.DataFrame): raise ValueError( - "Input dataframe aren't pandas dataframes: df is {}".format(type(df)) + f"Input dataframe is not a pandas DataFrame: got {type(df)}" ) return df @@ -156,9 +152,7 @@ def preprocess_data( self.TEMP_TARGET = target.columns[0] if self.TEMP_TARGET in train.columns: raise ValueError( - "Input train dataframe already have {} column, consider removing it".format( - self.TEMP_TARGET - ) + f"Input train dataframe already has '{self.TEMP_TARGET}' column, consider removing it" ) if "test_similarity" in train.columns: raise ValueError( @@ -171,7 +165,7 @@ def generate_data( self, train_df, target, test_df, only_generated_data ) -> Tuple[pd.DataFrame, pd.DataFrame]: if only_generated_data: - Warning( + warnings.warn( "For SamplerOriginal setting only_generated_data doesn't change anything, " "because generated data sampled from the train!" ) @@ -183,10 +177,8 @@ def generate_data( generated_df = generated_df.reset_index(drop=True) logging.info( - "Generated shape: {} and {}".format( - generated_df.drop(self.TEMP_TARGET, axis=1).shape, - generated_df[self.TEMP_TARGET].shape, - ) + f"Generated shape: {generated_df.drop(self.TEMP_TARGET, axis=1).shape} " + f"and {generated_df[self.TEMP_TARGET].shape}" ) return ( generated_df.drop(self.TEMP_TARGET, axis=1), @@ -258,17 +250,14 @@ def _validate_data(train_df, target, test_df): if test_df is not None: if train_df.shape[0] < 10 or test_df.shape[0] < 10: raise ValueError( - "Shape of train is {} and test is {}. Both should at least 10! " - "Consider disabling adversarial filtering".format( - train_df.shape[0], test_df.shape[0] - ) + f"Shape of train is {train_df.shape[0]} and test is {test_df.shape[0]}. " + f"Both should be at least 10! 
Consider disabling adversarial filtering" ) if target is not None: if train_df.shape[0] != target.shape[0]: raise ValueError( - "Something gone wrong: shape of train_df = {} is not equal to target = {} shape".format( - train_df.shape[0], target.shape[0] - ) + f"Shape mismatch: train_df has {train_df.shape[0]} rows " + f"but target has {target.shape[0]} rows" ) def handle_generated_data(self, train_df, generated_df, only_generated_data): @@ -302,9 +291,7 @@ def handle_generated_data(self, train_df, generated_df, only_generated_data): if not only_generated_data: train_df = pd.concat([train_df, generated_df]).reset_index(drop=True) logging.info( - "Generated shapes: {} plus target".format( - _drop_col_if_exist(train_df, self.TEMP_TARGET).shape - ) + f"Generated shapes: {_drop_col_if_exist(train_df, self.TEMP_TARGET).shape} plus target" ) return ( _drop_col_if_exist(train_df, self.TEMP_TARGET), @@ -312,9 +299,7 @@ def handle_generated_data(self, train_df, generated_df, only_generated_data): ) else: logging.info( - "Generated shapes: {} plus target".format( - _drop_col_if_exist(generated_df, self.TEMP_TARGET).shape - ) + f"Generated shapes: {_drop_col_if_exist(generated_df, self.TEMP_TARGET).shape} plus target" ) return ( _drop_col_if_exist(generated_df, self.TEMP_TARGET), @@ -326,16 +311,15 @@ class SamplerGAN(SamplerOriginal): def check_params(self): if self.gen_params["batch_size"] % 10 != 0: logging.warning( - "Batch size should be divisible to 10, but provided {}. Fixing it".format( - self.gen_params["batch_size"])) + f"Batch size should be divisible by 10, but got {self.gen_params['batch_size']}. Fixing it") self.gen_params["batch_size"] += 10 - (self.gen_params["batch_size"] % 10) if "patience" not in self.gen_params: - logging.warning("patience param is not set for GAN params, so setting it to default ""25""") + logging.warning("patience param is not set for GAN params, setting default to 25") self.gen_params["patience"] = 25 if "epochs" not in self.gen_params: - logging.warning("patience param is not set for GAN params, so setting it to default ""50""") + logging.warning("epochs param is not set for GAN params, setting default to 50") self.gen_params["epochs"] = 50 def generate_data( @@ -389,16 +373,15 @@ def get_column_indexes(df, column_names): class SamplerLLM(SamplerOriginal): def check_params(self): if "llm" not in self.gen_params: - logging.warning("llm param is not set for LLM params, so setting it to default ""distilgpt2""") + logging.warning("llm param is not set for LLM params, setting default to 'distilgpt2'") self.gen_params["llm"] = "distilgpt2" if "max_length" not in self.gen_params: - logging.warning("max_length param is not set for LLM params, so setting it to default ""500""") - self.gen_params["max_length"] = "500" + logging.warning("max_length param is not set for LLM params, setting default to 500") + self.gen_params["max_length"] = 500 if self.gen_params["epochs"] < 3: logging.warning( - "Current set epoch = {} for llm training is too low, setting to 3!""".format( - self.gen_params["epochs"])) + f"Current epoch={self.gen_params['epochs']} for LLM training is too low, setting to 3") self.gen_params["epochs"] = 3 def _build_training_frame(self, train_df: pd.DataFrame, target: pd.DataFrame | None) -> pd.DataFrame: @@ -628,6 +611,13 @@ def _generate_via_prompt(self, prompt: str, great_model_instance, device: str, m return "" # Fallback or re-raise +# Wire up factory classes to their concrete sampler implementations +OriginalGenerator._sampler_class = SamplerOriginal 
+GANGenerator._sampler_class = SamplerGAN
+ForestDiffusionGenerator._sampler_class = SamplerDiffusion
+LLMGenerator._sampler_class = SamplerLLM
+
+
 if __name__ == "__main__":
     setup_logging(logging.DEBUG)
     train_size = 75
diff --git a/src/tabgan/sklearn_transformer.py b/src/tabgan/sklearn_transformer.py
new file mode 100644
index 0000000..4c6d5f6
--- /dev/null
+++ b/src/tabgan/sklearn_transformer.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+"""
+sklearn-compatible transformer for TabGAN data augmentation.
+
+Allows inserting synthetic data generation into a ``sklearn.pipeline.Pipeline``.
+"""
+
+import logging
+from typing import List, Optional, Type
+
+import numpy as np
+import pandas as pd
+from sklearn.base import BaseEstimator, TransformerMixin
+
+from tabgan.sampler import GANGenerator
+
+__all__ = ["TabGANTransformer"]
+
+
+class TabGANTransformer(BaseEstimator, TransformerMixin):
+    """Augment training data with TabGAN synthetic rows inside an sklearn Pipeline.
+
+    During ``fit`` the generator is trained and synthetic data is produced.
+    ``transform`` returns the augmented DataFrame (original + synthetic).
+
+    Because sklearn's ``transform`` only returns X, the augmented target
+    is available via :meth:`get_augmented_target` after ``fit_transform``.
+
+    Args:
+        generator_class: A TabGAN generator class; defaults to ``GANGenerator``.
+        gen_x_times: Multiplier for synthetic sample count.
+        cat_cols: Categorical column names.
+        gen_params: Generator-specific hyperparameters.
+        only_generated_data: If True, return only synthetic rows.
+        constraints: Optional list of ``Constraint`` instances.
+        use_adversarial: Whether to use adversarial filtering.
+        **generator_kwargs: Extra keyword arguments forwarded to the generator.
+
+    Example::
+
+        from sklearn.ensemble import RandomForestClassifier
+        from tabgan.sklearn_transformer import TabGANTransformer
+
+        # Pipeline.fit would pass the original-length y to the final model,
+        # so fit the augmenter directly and fetch the augmented target:
+        augmenter = TabGANTransformer(gen_x_times=1.5)
+        X_aug = augmenter.fit_transform(X_train, y_train)
+        y_aug = augmenter.get_augmented_target()
+        RandomForestClassifier().fit(X_aug, y_aug)
+    """
+
+    def __init__(
+        self,
+        generator_class: Type = None,
+        gen_x_times: float = 1.1,
+        cat_cols: Optional[List[str]] = None,
+        gen_params: Optional[dict] = None,
+        only_generated_data: bool = False,
+        constraints: Optional[list] = None,
+        use_adversarial: bool = True,
+        **generator_kwargs,
+    ):
+        self.generator_class = generator_class
+        self.gen_x_times = gen_x_times
+        self.cat_cols = cat_cols
+        self.gen_params = gen_params
+        self.only_generated_data = only_generated_data
+        self.constraints = constraints
+        self.use_adversarial = use_adversarial
+        self.generator_kwargs = generator_kwargs
+
+        # Internal state (set after fit)
+        self._augmented_X: Optional[pd.DataFrame] = None
+        self._augmented_y: Optional[pd.Series] = None
+
+    def fit(self, X, y=None):
+        """Train the generator and produce synthetic data.
+
+        Args:
+            X: Training features (DataFrame or ndarray).
+            y: Target variable (Series, DataFrame, ndarray, or None).
+ """ + gen_cls = self.generator_class or GANGenerator + + X_df = pd.DataFrame(X).copy() if not isinstance(X, pd.DataFrame) else X.copy() + + target_df = None + if y is not None: + if isinstance(y, pd.DataFrame): + target_df = y.copy() + elif isinstance(y, pd.Series): + target_df = y.to_frame().copy() + else: + target_df = pd.DataFrame(y, columns=["target"]) + + gen_kwargs = dict( + gen_x_times=self.gen_x_times, + cat_cols=self.cat_cols, + only_generated_data=self.only_generated_data, + ) + if self.gen_params is not None: + gen_kwargs["gen_params"] = self.gen_params + gen_kwargs.update(self.generator_kwargs) + + generator = gen_cls(**gen_kwargs) + + new_train, new_target = generator.generate_data_pipe( + X_df, + target_df, + X_df, # use train as test for distribution alignment + use_adversarial=self.use_adversarial, + constraints=self.constraints, + ) + + self._augmented_X = new_train + if new_target is not None and not new_target.isna().all(): + self._augmented_y = ( + new_target.iloc[:, 0] if isinstance(new_target, pd.DataFrame) else new_target + ) + else: + self._augmented_y = None + + return self + + def transform(self, X, y=None): + """Return the augmented training data. + + During training (when ``_augmented_X`` is available), returns the + augmented data. At inference time, returns X unchanged. + """ + if self._augmented_X is not None: + result = self._augmented_X + # Clear after first transform to avoid leaking into predict + self._augmented_X = None + return result + return X + + def fit_transform(self, X, y=None, **fit_params): + """Fit and return augmented data in one step.""" + self.fit(X, y) + return self.transform(X, y) + + def get_augmented_target(self) -> Optional[pd.Series]: + """Return the augmented target produced during ``fit``. + + Call this after ``fit`` or ``fit_transform`` to get the target + values corresponding to the augmented training data. 
+ """ + return self._augmented_y diff --git a/src/tabgan/utils.py b/src/tabgan/utils.py index df8b904..a0f9b19 100644 --- a/src/tabgan/utils.py +++ b/src/tabgan/utils.py @@ -24,11 +24,10 @@ def setup_logging(loglevel): ) -def make_two_digit(num_as_str: str) -> pd.DataFrame: - if len(num_as_str) == 2: - return num_as_str - else: +def make_two_digit(num_as_str: str) -> str: + if len(num_as_str) < 2: return "0" + num_as_str + return num_as_str def get_year_mnth_dt_from_date(df: pd.DataFrame, date_col="Date") -> pd.DataFrame: @@ -66,7 +65,7 @@ def seed_everything(seed=1234): torch.backends.cudnn.deterministic = True -def _sampler(creator, in_train, in_target, in_test) -> None: +def _sampler(creator, in_train, in_target, in_test) -> tuple: _logger = logging.getLogger(__name__) _logger.info("Starting generating data") train, test = creator.generate_data_pipe(in_train, in_target, in_test) @@ -230,7 +229,7 @@ def compare_dataframes(df_original, df_generated): # Combine uniqueness, data quality, and PSI scores (weighted) similarity_score = 0.1 * uniqueness_score + 0.45 * data_quality_score + 0.45 * (1/psi_similarity) - print(uniqueness_score, data_quality_score, psi_similarity) + logging.debug(f"Similarity components: uniqueness={uniqueness_score}, quality={data_quality_score}, psi={psi_similarity}") # Ensure score is between 0 and 1 similarity_score = min(max(similarity_score, 0), 1) diff --git a/tests/test_cli.py b/tests/test_cli.py index b4886d9..6dfe593 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,50 +7,162 @@ import pandas as pd +def _make_cli_env(): + """Return an env dict with PYTHONPATH pointing at the src directory.""" + env = os.environ.copy() + src_path = os.path.join(os.path.dirname(__file__), "..", "src") + env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "") + return env + + +def _run_cli(args, env): + """Run `python -m tabgan.cli` with the given extra arguments.""" + cmd = [sys.executable, "-m", "tabgan.cli"] + args + return subprocess.run(cmd, check=False, capture_output=True, text=True, env=env) + + def test_tabgan_generate_cli_creates_output_with_target(): - # Prepare temporary input and output CSV paths with tempfile.TemporaryDirectory() as tmpdir: input_path = os.path.join(tmpdir, "train.csv") output_path = os.path.join(tmpdir, "synthetic.csv") - # Small dummy dataset with a target column df = pd.DataFrame( np.random.randint(0, 10, size=(20, 3)), columns=["A", "B", "target"], ) df.to_csv(input_path, index=False) - # Set up environment with PYTHONPATH pointing to src directory - env = os.environ.copy() - src_path = os.path.join(os.path.dirname(__file__), "..", "src") - env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "") - - # Invoke the installed console script through Python -m to avoid PATH issues in CI - cmd = [ - sys.executable, - "-m", - "tabgan.cli", - "--input-csv", - input_path, - "--target-col", - "target", - "--generator", - "original", - "--gen-x-times", - "1.0", - "--output-csv", - output_path, - ] - - result = subprocess.run(cmd, check=False, capture_output=True, text=True, env=env) + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "original", + "--gen-x-times", "1.0", + "--output-csv", output_path, + ], env) assert result.returncode == 0, f"CLI failed: {result.stderr}" assert os.path.exists(output_path), "Output CSV was not created by CLI" out_df = pd.read_csv(output_path) - - # Target column must be present assert "target" in out_df.columns - # 
At least as many rows as the original (OriginalGenerator samples with replacement) assert len(out_df) >= len(df) + +def test_cli_gan_generator(): + """CLI with --generator gan produces valid output.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + df = pd.DataFrame( + np.random.randint(0, 50, size=(30, 3)), + columns=["A", "B", "target"], + ) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "gan", + "--gen-x-times", "1.0", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (gan) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert "target" in out_df.columns + assert len(out_df) > 0 + + +def test_cli_only_generated_flag(): + """CLI with --only-generated returns output without original rows.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + df = pd.DataFrame( + np.random.randint(0, 10, size=(20, 3)), + columns=["A", "B", "target"], + ) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "original", + "--gen-x-times", "1.0", + "--only-generated", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (only-generated) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert "target" in out_df.columns + assert len(out_df) > 0 + + +def test_cli_without_target(): + """CLI without --target-col should work (target=None path).""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + df = pd.DataFrame( + np.random.randint(0, 10, size=(20, 3)), + columns=["A", "B", "C"], + ) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--generator", "original", + "--gen-x-times", "1.0", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (no target) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert len(out_df) > 0 + + +def test_cli_with_cat_cols(): + """CLI with --cat-cols passes categorical column names correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + rng = np.random.RandomState(42) + df = pd.DataFrame({ + "num": rng.randint(0, 100, 30), + "cat": rng.choice(["X", "Y", "Z"], 30), + "target": rng.randint(0, 2, 30), + }) + df.to_csv(input_path, index=False) + + env = _make_cli_env() + result = _run_cli([ + "--input-csv", input_path, + "--target-col", "target", + "--generator", "original", + "--gen-x-times", "1.0", + "--cat-cols", "cat", + "--output-csv", output_path, + ], env) + + assert result.returncode == 0, f"CLI (cat-cols) failed: {result.stderr}" + assert os.path.exists(output_path) + + out_df = pd.read_csv(output_path) + assert "cat" in out_df.columns + assert "target" in out_df.columns + diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100644 index 0000000..ee7b435 --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 
+1,151 @@ +# -*- coding: utf-8 -*- +"""Tests for the constraint system.""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.constraints import ( + RangeConstraint, + UniqueConstraint, + FormulaConstraint, + RegexConstraint, + ConstraintEngine, +) +from src.tabgan.sampler import OriginalGenerator + + +class TestRangeConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"age": [5, 25, 150, -1]}) + c = RangeConstraint("age", min_val=0, max_val=120) + mask = c.is_satisfied(df) + self.assertEqual(list(mask), [True, True, False, False]) + + def test_fix_clips_values(self): + df = pd.DataFrame({"age": [5, 25, 150, -1]}) + c = RangeConstraint("age", min_val=0, max_val=120) + fixed = c.fix(df) + self.assertEqual(list(fixed["age"]), [5, 25, 120, 0]) + + def test_min_only(self): + df = pd.DataFrame({"x": [-5, 0, 10]}) + c = RangeConstraint("x", min_val=0) + self.assertEqual(list(c.is_satisfied(df)), [False, True, True]) + + def test_max_only(self): + df = pd.DataFrame({"x": [-5, 0, 10]}) + c = RangeConstraint("x", max_val=5) + self.assertEqual(list(c.is_satisfied(df)), [True, True, False]) + + def test_requires_at_least_one_bound(self): + with self.assertRaises(ValueError): + RangeConstraint("x") + + +class TestUniqueConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"id": [1, 2, 3, 2, 1]}) + c = UniqueConstraint("id") + mask = c.is_satisfied(df) + # First occurrences are True, duplicates are False + self.assertEqual(list(mask), [True, True, True, False, False]) + + def test_fix_drops_duplicates(self): + df = pd.DataFrame({"id": [1, 2, 3, 2, 1], "val": [10, 20, 30, 40, 50]}) + c = UniqueConstraint("id") + fixed = c.fix(df) + self.assertEqual(len(fixed), 3) + self.assertEqual(list(fixed["id"]), [1, 2, 3]) + + +class TestFormulaConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"start": [1, 5, 10], "end": [10, 3, 15]}) + c = FormulaConstraint("end > start") + mask = c.is_satisfied(df) + self.assertEqual(list(mask), [True, False, True]) + + def test_fix_filters_violations(self): + df = pd.DataFrame({"start": [1, 5, 10], "end": [10, 3, 15]}) + c = FormulaConstraint("end > start") + fixed = c.fix(df) + self.assertEqual(len(fixed), 2) + self.assertEqual(list(fixed["start"]), [1, 10]) + + +class TestRegexConstraint(unittest.TestCase): + def test_is_satisfied(self): + df = pd.DataFrame({"email": ["a@b.com", "invalid", "x@y.org"]}) + c = RegexConstraint("email", r".+@.+\..+") + mask = c.is_satisfied(df) + self.assertEqual(list(mask), [True, False, True]) + + def test_fix_filters(self): + df = pd.DataFrame({"code": ["AB12", "XY34", "bad!"]}) + c = RegexConstraint("code", r"[A-Z]{2}\d{2}") + fixed = c.fix(df) + self.assertEqual(len(fixed), 2) + self.assertEqual(list(fixed["code"]), ["AB12", "XY34"]) + + +class TestConstraintEngine(unittest.TestCase): + def test_filter_strategy(self): + df = pd.DataFrame({"age": [5, 150, 30], "id": [1, 2, 3]}) + engine = ConstraintEngine( + [RangeConstraint("age", min_val=0, max_val=120)], + strategy="filter", + ) + result = engine.apply(df) + self.assertEqual(len(result), 2) + + def test_fix_strategy(self): + df = pd.DataFrame({"age": [5, 150, 30], "id": [1, 2, 3]}) + engine = ConstraintEngine( + [RangeConstraint("age", min_val=0, max_val=120)], + strategy="fix", + ) + result = engine.apply(df) + self.assertEqual(len(result), 3) # All rows kept after clipping + self.assertEqual(result["age"].max(), 120) + + def test_multiple_constraints(self): + df = 
pd.DataFrame({ + "age": [5, 150, 30, 25, 25], + "id": [1, 2, 3, 4, 4], + }) + engine = ConstraintEngine([ + RangeConstraint("age", min_val=0, max_val=120), + UniqueConstraint("id"), + ], strategy="fix") + result = engine.apply(df) + # After fix: age clipped, then unique id kept + self.assertTrue(result["age"].max() <= 120) + self.assertEqual(result["id"].nunique(), len(result)) + + def test_invalid_strategy_raises(self): + with self.assertRaises(ValueError): + ConstraintEngine([], strategy="invalid") + + +class TestConstraintsInPipeline(unittest.TestCase): + def test_generate_data_pipe_with_constraints(self): + rng = np.random.RandomState(42) + train = pd.DataFrame(rng.randint(-10, 200, size=(60, 3)), columns=list("ABC")) + target = pd.DataFrame(rng.randint(0, 2, size=(60, 1)), columns=["Y"]) + test = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + + constraints = [RangeConstraint("A", min_val=0, max_val=100)] + + new_train, new_target = OriginalGenerator(gen_x_times=1.5).generate_data_pipe( + train, target, test, constraints=constraints, + ) + + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreaterEqual(new_train["A"].min(), 0) + self.assertLessEqual(new_train["A"].max(), 100) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_data_pipe.py b/tests/test_generate_data_pipe.py new file mode 100644 index 0000000..f08ae76 --- /dev/null +++ b/tests/test_generate_data_pipe.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +""" +Tests for generate_data_pipe parameter combinations and cat_cols handling +in postprocess / adversarial filtering. +""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.sampler import ( + OriginalGenerator, + GANGenerator, + ForestDiffusionGenerator, + SamplerOriginal, +) + + +def _make_data(n_train=80, n_test=80, seed=42): + """Create reproducible train / target / test DataFrames.""" + rng = np.random.RandomState(seed) + train = pd.DataFrame(rng.randint(0, 100, size=(n_train, 4)), columns=list("ABCD")) + target = pd.DataFrame(rng.randint(0, 2, size=(n_train, 1)), columns=["Y"]) + test = pd.DataFrame(rng.randint(0, 100, size=(n_test, 4)), columns=list("ABCD")) + return train, target, test + + +def _make_data_with_cat(n_train=80, n_test=80, seed=42): + """Create data with an explicit categorical column.""" + rng = np.random.RandomState(seed) + train = pd.DataFrame({ + "num1": rng.randint(0, 100, n_train), + "num2": rng.randint(0, 100, n_train), + "cat": rng.choice(["X", "Y", "Z"], n_train), + }) + target = pd.DataFrame({"Y": rng.randint(0, 2, n_train)}) + test = pd.DataFrame({ + "num1": rng.randint(0, 100, n_test), + "num2": rng.randint(0, 100, n_test), + "cat": rng.choice(["X", "Y", "Z"], n_test), + }) + return train, target, test + + +# --------------------------------------------------------------------------- +# generate_data_pipe parameter combinations +# --------------------------------------------------------------------------- +class TestGenerateDataPipeParams(unittest.TestCase): + """Test various parameter combinations of generate_data_pipe.""" + + def test_only_adversarial_true(self): + """only_adversarial=True should skip generation and only filter.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.1).generate_data_pipe( + train, target, test, + only_adversarial=True, + use_adversarial=True, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + # With only adversarial filtering on original data, 
output rows <= input rows + self.assertLessEqual(new_train.shape[0], train.shape[0]) + + def test_use_adversarial_false(self): + """use_adversarial=False should skip adversarial filtering entirely.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.5).generate_data_pipe( + train, target, test, + use_adversarial=False, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + # Without adversarial filtering, we should have more rows than original + self.assertGreater(new_train.shape[0], train.shape[0]) + + def test_deep_copy_false(self): + """deep_copy=False should still produce valid output.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.1).generate_data_pipe( + train.copy(), target.copy(), test.copy(), + deep_copy=False, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + + def test_only_generated_data_true_original(self): + """only_generated_data=True with OriginalGenerator — data is sampled from train.""" + train, target, test = _make_data() + new_train, new_target = OriginalGenerator( + gen_x_times=1.1, + only_generated_data=True, + ).generate_data_pipe( + train, target, test, + only_generated_data=True, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + + def test_only_generated_data_true_gan(self): + """only_generated_data=True with GANGenerator returns purely synthetic rows.""" + train, target, test = _make_data() + new_train, new_target = GANGenerator( + gen_x_times=1.0, + only_generated_data=True, + gen_params={"batch_size": 50, "patience": 5, "epochs": 2}, + ).generate_data_pipe( + train, target, test, + only_generated_data=True, + ) + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + self.assertEqual(new_train.shape[1], train.shape[1]) + + def test_target_none(self): + """Passing target=None should work for all generators.""" + train, _, test = _make_data() + new_train, new_target = OriginalGenerator(gen_x_times=1.1).generate_data_pipe( + train, None, test, + ) + self.assertIsNotNone(new_train) + self.assertGreater(new_train.shape[0], 0) + + def test_test_df_none(self): + """Passing test_df=None should skip postprocess and adversarial.""" + train, _, _ = _make_data() + new_train, new_target = OriginalGenerator( + gen_x_times=1.1, + is_post_process=False, + ).generate_data_pipe( + train, None, None, + ) + self.assertIsNotNone(new_train) + self.assertGreater(new_train.shape[0], 0) + + +# --------------------------------------------------------------------------- +# cat_cols in postprocess and adversarial filtering +# --------------------------------------------------------------------------- +class TestCatColsPostprocessAndAdversarial(unittest.TestCase): + """Test that cat_cols are handled correctly in postprocess and adversarial.""" + + def test_postprocess_with_cat_cols(self): + """Postprocessing with cat_cols should filter by category membership.""" + train, target, test = _make_data_with_cat() + + sampler = OriginalGenerator( + gen_x_times=2.0, + cat_cols=["cat"], + ).get_object_generator() + + new_train, new_target, test_df = sampler.preprocess_data( + train.copy(), target.copy(), test.copy() + ) + gen_train, gen_target = sampler.generate_data( + new_train, new_target, test_df, only_generated_data=False + ) + post_train, post_target = sampler.postprocess_data(gen_train, gen_target, test_df) + + self.assertEqual(post_train.shape[0], 
post_target.shape[0]) + # All categorical values in result should be present in test + result_cats = set(post_train["cat"].unique()) + test_cats = set(test_df["cat"].unique()) + self.assertTrue(result_cats.issubset(test_cats)) + + def test_adversarial_with_cat_cols(self): + """Adversarial filtering with cat_cols should produce valid output.""" + train, target, test = _make_data_with_cat() + + sampler = OriginalGenerator( + gen_x_times=2.0, + cat_cols=["cat"], + ).get_object_generator() + + new_train, new_target, test_df = sampler.preprocess_data( + train.copy(), target.copy(), test.copy() + ) + gen_train, gen_target = sampler.generate_data( + new_train, new_target, test_df, only_generated_data=False + ) + post_train, post_target = sampler.postprocess_data(gen_train, gen_target, test_df) + adv_train, adv_target = sampler.adversarial_filtering(post_train, post_target, test_df) + + self.assertEqual(adv_train.shape[0], adv_target.shape[0]) + self.assertGreater(adv_train.shape[0], 0) + + def test_full_pipeline_with_cat_cols(self): + """End-to-end generate_data_pipe with cat_cols.""" + train, target, test = _make_data_with_cat() + new_train, new_target = OriginalGenerator( + gen_x_times=1.5, + cat_cols=["cat"], + ).generate_data_pipe(train, target, test) + + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + self.assertIn("cat", new_train.columns) + + def test_gan_with_cat_cols(self): + """GANGenerator with cat_cols should train and generate correctly.""" + train, target, test = _make_data_with_cat() + new_train, new_target = GANGenerator( + gen_x_times=1.1, + cat_cols=["cat"], + gen_params={"batch_size": 50, "patience": 5, "epochs": 2}, + ).generate_data_pipe(train, target, test) + + self.assertEqual(new_train.shape[0], new_target.shape[0]) + self.assertGreater(new_train.shape[0], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_privacy_metrics.py b/tests/test_privacy_metrics.py new file mode 100644 index 0000000..8e9e9ce --- /dev/null +++ b/tests/test_privacy_metrics.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +"""Tests for privacy metrics.""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.privacy_metrics import PrivacyMetrics + + +def _make_numeric_data(seed=42): + rng = np.random.RandomState(seed) + original = pd.DataFrame(rng.normal(0, 1, size=(100, 4)), columns=list("ABCD")) + synthetic = pd.DataFrame(rng.normal(0, 1, size=(80, 4)), columns=list("ABCD")) + return original, synthetic + + +class TestDCR(unittest.TestCase): + def test_dcr_returns_expected_keys(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.dcr() + self.assertIn("mean", result) + self.assertIn("median", result) + self.assertIn("5th_percentile", result) + self.assertIn("distances", result) + + def test_identical_data_dcr_near_zero(self): + orig, _ = _make_numeric_data() + pm = PrivacyMetrics(orig, orig.copy()) + result = pm.dcr() + self.assertAlmostEqual(result["mean"], 0.0, places=5) + + def test_distant_data_dcr_positive(self): + rng = np.random.RandomState(42) + orig = pd.DataFrame(rng.normal(0, 1, size=(100, 3)), columns=list("ABC")) + synth = pd.DataFrame(rng.normal(10, 1, size=(80, 3)), columns=list("ABC")) + pm = PrivacyMetrics(orig, synth) + result = pm.dcr() + self.assertGreater(result["mean"], 1.0) + + def test_dcr_with_sample_size(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.dcr(sample_size=20) + 
self.assertEqual(len(result["distances"]), 20) + + +class TestNNDR(unittest.TestCase): + def test_nndr_returns_expected_keys(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.nndr() + self.assertIn("mean", result) + self.assertIn("median", result) + self.assertIn("ratios", result) + + def test_nndr_values_between_0_and_1(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.nndr() + self.assertGreater(result["mean"], 0.0) + self.assertLessEqual(result["mean"], 1.0) + + +class TestMembershipInference(unittest.TestCase): + def test_mi_returns_expected_keys(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.membership_inference_risk() + self.assertIn("auc", result) + self.assertIn("accuracy", result) + + def test_mi_auc_in_range(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + result = pm.membership_inference_risk() + self.assertGreaterEqual(result["auc"], 0.0) + self.assertLessEqual(result["auc"], 1.0) + + +class TestSummary(unittest.TestCase): + def test_summary_returns_overall_score(self): + orig, synth = _make_numeric_data() + pm = PrivacyMetrics(orig, synth) + s = pm.summary() + self.assertIn("overall_privacy_score", s) + self.assertGreaterEqual(s["overall_privacy_score"], 0.0) + self.assertLessEqual(s["overall_privacy_score"], 1.0) + + def test_summary_contains_all_sections(self): + orig, synth = _make_numeric_data() + s = PrivacyMetrics(orig, synth).summary() + self.assertIn("dcr", s) + self.assertIn("nndr", s) + self.assertIn("membership_inference", s) + + def test_with_cat_cols(self): + rng = np.random.RandomState(42) + orig = pd.DataFrame({ + "num": rng.normal(0, 1, 80), + "cat": rng.choice(["A", "B", "C"], 80), + }) + synth = pd.DataFrame({ + "num": rng.normal(0, 1, 60), + "cat": rng.choice(["A", "B", "C"], 60), + }) + s = PrivacyMetrics(orig, synth, cat_cols=["cat"]).summary() + self.assertIn("overall_privacy_score", s) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_quality_report.py b/tests/test_quality_report.py new file mode 100644 index 0000000..53eb00a --- /dev/null +++ b/tests/test_quality_report.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +"""Tests for the quality report.""" + +import os +import tempfile +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.quality_report import QualityReport + + +def _make_data(seed=42): + rng = np.random.RandomState(seed) + original = pd.DataFrame({ + "A": rng.normal(0, 1, 100), + "B": rng.normal(5, 2, 100), + "target": rng.randint(0, 2, 100), + }) + synthetic = pd.DataFrame({ + "A": rng.normal(0, 1, 80), + "B": rng.normal(5, 2, 80), + "target": rng.randint(0, 2, 80), + }) + return original, synthetic + + +class TestQualityReportCompute(unittest.TestCase): + def test_compute_returns_self(self): + orig, synth = _make_data() + report = QualityReport(orig, synth).compute() + self.assertIsInstance(report, QualityReport) + + def test_summary_has_required_keys(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertIn("column_stats", s) + self.assertIn("psi", s) + self.assertIn("ml_utility", s) + self.assertIn("overall_score", s) + + def test_overall_score_range(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertGreaterEqual(s["overall_score"], 0.0) + self.assertLessEqual(s["overall_score"], 1.0) + + def test_column_stats_numeric(self): + 
orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertIn("A", s["column_stats"]) + self.assertEqual(s["column_stats"]["A"]["dtype"], "numeric") + self.assertIn("orig_mean", s["column_stats"]["A"]) + + def test_column_stats_categorical(self): + orig = pd.DataFrame({"cat": ["a", "b", "a", "c"], "num": [1, 2, 3, 4]}) + synth = pd.DataFrame({"cat": ["a", "b", "b", "c"], "num": [1, 2, 3, 4]}) + s = QualityReport(orig, synth, cat_cols=["cat"]).compute().summary() + self.assertEqual(s["column_stats"]["cat"]["dtype"], "categorical") + + def test_psi_per_column(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertIn("A", s["psi"]) + self.assertIn("mean", s["psi"]) + + def test_ml_utility_with_target(self): + orig, synth = _make_data() + s = QualityReport(orig, synth, target_col="target").compute().summary() + ml = s["ml_utility"] + self.assertIn("tstr_auc", ml) + self.assertIn("trtr_auc", ml) + self.assertIn("utility_ratio", ml) + + def test_ml_utility_without_target(self): + orig, synth = _make_data() + s = QualityReport(orig, synth).compute().summary() + self.assertEqual(s["ml_utility"]["utility_ratio"], 0.0) + + +class TestQualityReportHTML(unittest.TestCase): + def test_to_html_creates_file(self): + orig, synth = _make_data() + report = QualityReport(orig, synth, target_col="target").compute() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "report.html") + report.to_html(path) + self.assertTrue(os.path.exists(path)) + with open(path) as f: + content = f.read() + self.assertIn("TabGAN Quality Report", content) + self.assertIn("Overall Score", content) + + def test_html_contains_charts(self): + orig, synth = _make_data() + report = QualityReport(orig, synth).compute() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "report.html") + report.to_html(path) + with open(path) as f: + content = f.read() + # Should contain base64 images + self.assertIn("data:image/png;base64", content) + + def test_auto_compute_on_to_html(self): + """to_html should auto-compute if compute() wasn't called.""" + orig, synth = _make_data() + report = QualityReport(orig, synth) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "report.html") + report.to_html(path) + self.assertTrue(os.path.exists(path)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 2ee45c7..d6ecec0 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -78,22 +78,23 @@ def test__validate_data(self): args = [self.train.head(), self.target.copy(), self.test] self.assertRaises(ValueError, self.sampler._validate_data, *args) - class TestSamplerGAN(TestCase): - def setUp(self): - self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) - self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) - self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) - self.gen = GANGenerator(gen_x_times=15) - self.sampler = self.gen.get_object_generator() - - def test_generate_data(self): - new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), - self.target.copy(), self.test) - gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df) - self.assertEqual(gen_train.shape[0], gen_target.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - 
self.assertTrue(gen_train.shape[0] > new_train.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + +class TestSamplerGAN(TestCase): + def setUp(self): + self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) + self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) + self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) + self.gen = GANGenerator(gen_x_times=15) + self.sampler = self.gen.get_object_generator() + + def test_generate_data(self): + new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), + self.target.copy(), self.test) + gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df, only_generated_data=False) + self.assertEqual(gen_train.shape[0], gen_target.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + self.assertTrue(gen_train.shape[0] > new_train.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) class TestSamplerLLMConditional(TestCase): @@ -426,37 +427,39 @@ def test_empty_text_or_conditional_columns_use_fallback_sampling(self): mock_conditional.assert_not_called() pd.testing.assert_frame_equal(new_train_df.reset_index(drop=True), dummy_generated.reset_index(drop=True)) - class TestSamplerSamplerDiffusion(TestCase): - def setUp(self): - self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) - self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) - self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) - self.gen = ForestDiffusionGenerator(gen_x_times=15) - self.sampler = self.gen.get_object_generator() - - def test_generate_data(self): - new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), - self.target.copy(), self.test) - gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df) - self.assertEqual(gen_train.shape[0], gen_target.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - self.assertTrue(gen_train.shape[0] > new_train.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - - class TestSamplerSamplerDiffusion(TestCase): - def setUp(self): - self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) - self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) - self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) - self.gen = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", - "max_length": 500}) - self.sampler = self.gen.get_object_generator() - - def test_generate_data(self): - new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), - self.target.copy(), self.test) - gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df) - self.assertEqual(gen_train.shape[0], gen_target.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) - self.assertTrue(gen_train.shape[0] > new_train.shape[0]) - self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + +class TestSamplerDiffusion(TestCase): + def setUp(self): + self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) + self.target = pd.DataFrame(np.random.randint(0, 2, 
size=(50, 1)), columns=list('Y')) + self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) + self.gen = ForestDiffusionGenerator(gen_x_times=15) + self.sampler = self.gen.get_object_generator() + + def test_generate_data(self): + new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), + self.target.copy(), self.test) + gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df, only_generated_data=False) + self.assertEqual(gen_train.shape[0], gen_target.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + self.assertTrue(gen_train.shape[0] > new_train.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + + +class TestSamplerLLMDirect(TestCase): + def setUp(self): + self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD')) + self.target = pd.DataFrame(np.random.randint(0, 2, size=(50, 1)), columns=list('Y')) + self.test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')) + self.gen = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", + "max_length": 500}) + self.sampler = self.gen.get_object_generator() + + def test_generate_data(self): + new_train, new_target, test_df = self.sampler.preprocess_data(self.train.copy(), + self.target.copy(), self.test) + gen_train, gen_target = self.sampler.generate_data(new_train, new_target, test_df, only_generated_data=False) + self.assertEqual(gen_train.shape[0], gen_target.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + self.assertTrue(gen_train.shape[0] > new_train.shape[0]) + self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) diff --git a/tests/test_sklearn_transformer.py b/tests/test_sklearn_transformer.py new file mode 100644 index 0000000..2f0449d --- /dev/null +++ b/tests/test_sklearn_transformer.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +"""Tests for the sklearn-compatible TabGANTransformer.""" + +import unittest + +import numpy as np +import pandas as pd + +from src.tabgan.sampler import OriginalGenerator, GANGenerator +from src.tabgan.sklearn_transformer import TabGANTransformer +from src.tabgan.constraints import RangeConstraint + + +class TestTabGANTransformerBasic(unittest.TestCase): + def setUp(self): + rng = np.random.RandomState(42) + self.X = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + self.y = pd.Series(rng.randint(0, 2, 60), name="target") + + def test_fit_transform_returns_dataframe(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + result = t.fit_transform(self.X, self.y) + self.assertIsInstance(result, pd.DataFrame) + self.assertGreater(len(result), 0) + + def test_augmented_target_available(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + t.fit(self.X, self.y) + aug_y = t.get_augmented_target() + self.assertIsNotNone(aug_y) + + def test_augmented_shapes_aligned(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + X_aug = t.fit_transform(self.X, self.y) + y_aug = t.get_augmented_target() + self.assertEqual(len(X_aug), len(y_aug)) + + def test_without_target(self): + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + result = t.fit_transform(self.X) + self.assertIsInstance(result, pd.DataFrame) + self.assertIsNone(t.get_augmented_target()) + + def 
test_transform_at_inference_passthrough(self): + """After first transform, subsequent transforms pass data through.""" + t = TabGANTransformer(generator_class=OriginalGenerator, gen_x_times=1.1) + t.fit(self.X, self.y) + _ = t.transform(self.X) # First transform consumes augmented data + result2 = t.transform(self.X) # Second should pass through + self.assertEqual(len(result2), len(self.X)) + + +class TestTabGANTransformerWithGAN(unittest.TestCase): + def setUp(self): + rng = np.random.RandomState(42) + self.X = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + self.y = pd.Series(rng.randint(0, 2, 60), name="target") + + def test_gan_generator(self): + t = TabGANTransformer( + generator_class=GANGenerator, + gen_x_times=1.0, + gen_params={"batch_size": 50, "patience": 5, "epochs": 2}, + ) + result = t.fit_transform(self.X, self.y) + self.assertGreater(len(result), 0) + self.assertEqual(result.shape[1], self.X.shape[1]) + + +class TestTabGANTransformerWithConstraints(unittest.TestCase): + def test_constraints_applied(self): + rng = np.random.RandomState(42) + X = pd.DataFrame(rng.randint(-50, 200, size=(60, 3)), columns=list("ABC")) + y = pd.Series(rng.randint(0, 2, 60)) + + t = TabGANTransformer( + generator_class=OriginalGenerator, + gen_x_times=1.5, + constraints=[RangeConstraint("A", min_val=0, max_val=100)], + ) + result = t.fit_transform(X, y) + self.assertGreaterEqual(result["A"].min(), 0) + self.assertLessEqual(result["A"].max(), 100) + + +class TestTabGANTransformerSklearnCompat(unittest.TestCase): + def test_get_params(self): + t = TabGANTransformer(gen_x_times=2.0, cat_cols=["A"]) + params = t.get_params() + self.assertEqual(params["gen_x_times"], 2.0) + self.assertEqual(params["cat_cols"], ["A"]) + + def test_set_params(self): + t = TabGANTransformer(gen_x_times=1.0) + t.set_params(gen_x_times=3.0) + self.assertEqual(t.gen_x_times, 3.0) + + def test_in_sklearn_pipeline(self): + """Verify it can be placed in a Pipeline (doesn't run the full pipeline).""" + from sklearn.pipeline import Pipeline + from sklearn.ensemble import RandomForestClassifier + + pipe = Pipeline([ + ("augment", TabGANTransformer( + generator_class=OriginalGenerator, + gen_x_times=1.1, + )), + ("model", RandomForestClassifier(n_estimators=5, random_state=42)), + ]) + # Pipeline should be constructable and have proper steps + self.assertEqual(len(pipe.steps), 2) + self.assertEqual(pipe.steps[0][0], "augment") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index c0f2319..443a918 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,14 +4,14 @@ from tabgan.utils import make_two_digit, get_year_mnth_dt_from_date -class TestUtils(unittest.TestCase): +class TestMakeTwoDigit(unittest.TestCase): def test_make_two_digit(self): self.assertEqual(make_two_digit('1'), '01') self.assertEqual(make_two_digit('12'), '12') self.assertEqual(make_two_digit('123'), '123') -class TestUtils(unittest.TestCase): +class TestGetYearMonthDtFromDate(unittest.TestCase): def test_get_year_month_dt_from_date(self): # create a sample dataframe df = pd.DataFrame({ diff --git a/tests/test_utils_extended.py b/tests/test_utils_extended.py new file mode 100644 index 0000000..82de28a --- /dev/null +++ b/tests/test_utils_extended.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +""" +Extended tests for tabgan.utils — covers compare_dataframes, calculate_psi, +collect_dates, and the time-series round-trip workflow. 
+""" + +import unittest + +import numpy as np +import pandas as pd + +from tabgan.utils import ( + calculate_psi, + collect_dates, + compare_dataframes, + get_year_mnth_dt_from_date, +) + + +# --------------------------------------------------------------------------- +# calculate_psi +# --------------------------------------------------------------------------- +class TestCalculatePSI(unittest.TestCase): + def test_identical_distributions_returns_near_zero(self): + arr = np.random.RandomState(42).normal(0, 1, size=500) + psi = calculate_psi(arr, arr.copy(), buckets=10) + # PSI of identical arrays should be ~0 + self.assertAlmostEqual(float(psi), 0.0, places=3) + + def test_shifted_distribution_returns_positive(self): + rng = np.random.RandomState(42) + expected = rng.normal(0, 1, size=1000) + actual = rng.normal(2, 1, size=1000) # shifted mean + psi = calculate_psi(expected, actual, buckets=10) + self.assertGreater(float(psi), 0.1) + + def test_2d_array_axis0(self): + rng = np.random.RandomState(42) + expected = rng.normal(0, 1, size=(200, 3)) + actual = rng.normal(0, 1, size=(200, 3)) + psi_values = calculate_psi(expected, actual, buckets=10, axis=0) + self.assertEqual(len(psi_values), 3) + for v in psi_values: + self.assertGreaterEqual(v, 0.0) + + def test_quantile_bucket_type(self): + rng = np.random.RandomState(42) + expected = rng.normal(0, 1, size=500) + actual = rng.normal(0, 1, size=500) + psi = calculate_psi(expected, actual, buckettype="quantiles", buckets=10) + self.assertGreaterEqual(float(psi), 0.0) + + +# --------------------------------------------------------------------------- +# compare_dataframes +# --------------------------------------------------------------------------- +class TestCompareDataframes(unittest.TestCase): + def test_identical_dataframes_score_high(self): + df = pd.DataFrame({"a": range(100), "b": range(100, 200)}) + score = compare_dataframes(df.copy(), df.copy()) + self.assertGreater(score, 0.5) + self.assertLessEqual(score, 1.0) + + def test_completely_different_columns_returns_zero(self): + df1 = pd.DataFrame({"a": [1, 2, 3]}) + df2 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + score = compare_dataframes(df1, df2) + self.assertEqual(score, 0.0) + + def test_similar_dataframes_score_moderate(self): + rng = np.random.RandomState(42) + df1 = pd.DataFrame({"x": rng.normal(0, 1, 200), "y": rng.normal(5, 2, 200)}) + df2 = pd.DataFrame({"x": rng.normal(0, 1, 200), "y": rng.normal(5, 2, 200)}) + score = compare_dataframes(df1, df2) + self.assertGreater(score, 0.0) + self.assertLessEqual(score, 1.0) + + def test_score_is_between_0_and_1(self): + rng = np.random.RandomState(0) + df1 = pd.DataFrame({"a": rng.randint(0, 100, 50)}) + df2 = pd.DataFrame({"a": rng.randint(0, 100, 50)}) + score = compare_dataframes(df1, df2) + self.assertGreaterEqual(score, 0.0) + self.assertLessEqual(score, 1.0) + + def test_with_only_numeric_columns(self): + rng = np.random.RandomState(10) + df1 = pd.DataFrame({"a": rng.randint(0, 50, 80), "b": rng.randint(0, 50, 80)}) + df2 = pd.DataFrame({"a": rng.randint(0, 50, 80), "b": rng.randint(0, 50, 80)}) + score = compare_dataframes(df1, df2) + self.assertGreaterEqual(score, 0.0) + self.assertLessEqual(score, 1.0) + + def test_with_mixed_types_raises_on_string_psi(self): + """compare_dataframes passes string columns to calculate_psi which + cannot handle them. 
This documents the current behavior.""" + df1 = pd.DataFrame({"num": [1, 2, 3, 4], "cat": ["a", "b", "a", "c"]}) + df2 = pd.DataFrame({"num": [1, 2, 5, 6], "cat": ["a", "b", "x", "y"]}) + with self.assertRaises(TypeError): + compare_dataframes(df1.copy(), df2.copy()) + + +# --------------------------------------------------------------------------- +# collect_dates (round-trip with get_year_mnth_dt_from_date) +# --------------------------------------------------------------------------- +class TestCollectDates(unittest.TestCase): + def test_round_trip(self): + """Decompose dates then reassemble — result should match originals.""" + dates = pd.to_datetime(["2022-01-15", "2023-06-01", "2021-12-31"]) + df = pd.DataFrame({"Date": dates, "value": [10, 20, 30]}) + decomposed = get_year_mnth_dt_from_date(df.copy(), "Date") + + # Drop the original Date column to simulate the generation pipeline + decomposed = decomposed.drop("Date", axis=1) + reassembled = collect_dates(decomposed) + + self.assertIn("Date", reassembled.columns) + self.assertNotIn("year", reassembled.columns) + self.assertNotIn("month", reassembled.columns) + self.assertNotIn("day", reassembled.columns) + + expected_dates = ["2022-01-15", "2023-06-01", "2021-12-31"] + self.assertEqual(list(reassembled["Date"]), expected_dates) + + def test_single_digit_month_day_padded(self): + """Months and days < 10 should be zero-padded.""" + df = pd.DataFrame({"year": [2020], "month": [3], "day": [5], "x": [1]}) + result = collect_dates(df) + self.assertEqual(result["Date"].iloc[0], "2020-03-05") + + +# --------------------------------------------------------------------------- +# Time-series generation workflow (integration-like) +# --------------------------------------------------------------------------- +class TestTimeSeriesWorkflow(unittest.TestCase): + def test_date_decompose_generate_collect(self): + """Full round-trip: decompose dates → OriginalGenerator → collect dates.""" + from tabgan.sampler import OriginalGenerator + + rng = np.random.RandomState(42) + train = pd.DataFrame(rng.randint(0, 100, size=(60, 3)), columns=list("ABC")) + min_date = pd.to_datetime("2020-01-01") + max_date = pd.to_datetime("2021-12-31") + d = (max_date - min_date).days + 1 + train["Date"] = min_date + pd.to_timedelta(rng.randint(d, size=60), unit="D") + train = get_year_mnth_dt_from_date(train, "Date") + + train_no_date = train.drop("Date", axis=1) + + new_train, _ = OriginalGenerator( + gen_x_times=1.1, + cat_cols=["year"], + is_post_process=True, + pregeneration_frac=2, + ).generate_data_pipe(train_no_date, None, train_no_date) + + self.assertIn("year", new_train.columns) + self.assertIn("month", new_train.columns) + self.assertIn("day", new_train.columns) + + new_train = collect_dates(new_train) + self.assertIn("Date", new_train.columns) + self.assertNotIn("year", new_train.columns) From 0eb0b3a09488ccfe70d5c71711097e6903198b43 Mon Sep 17 00:00:00 2001 From: Insaf Ashrapov Date: Sat, 28 Mar 2026 05:23:29 +0000 Subject: [PATCH 2/2] Add workflow diagram back to README Pipeline Architecture section Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 3ce678d..49576ac 100644 --- a/README.md +++ b/README.md @@ -335,6 +335,8 @@ tabgan-generate \ ## Pipeline Architecture +![Experiment design and workflow](images/workflow.png) + ``` Input (train_df, target, test_df) |