diff --git a/.gitignore b/.gitignore
index 787cbdf..6dd64c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,4 @@ dmypy.json
 .DS_Store
 /src/tabgan/trainer_great/
+/tests/trainer_great/
diff --git a/README.md b/README.md
index d520b20..1c92d6d 100644
--- a/README.md
+++ b/README.md
@@ -1,100 +1,132 @@
 [![CodeFactor](https://www.codefactor.io/repository/github/diyago/tabular-data-generation/badge)](https://www.codefactor.io/repository/github/diyago/tabular-data-generation)
-[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+[![Python Version](https://img.shields.io/pypi/pyversions/tabgan)](https://pypi.org/project/tabgan/)
 [![Downloads](https://pepy.tech/badge/tabgan)](https://pepy.tech/project/tabgan)
 
-# GANs and TimeGANs, Diffusions, LLM for tabular data
+# TabGAN - Synthetic Tabular Data Generation
+
+A powerful library for generating high-quality synthetic tabular data using state-of-the-art generative models including GANs, Diffusion models, and Large Language Models.
 
-Generative Networks are well-known for their success in realistic image generation. However, they can also be applied to generate tabular data. This library introduces major improvements for generating high-fidelity tabular data by offering a diverse suite of cutting-edge models, including Generative Adversarial Networks (GANs), specialized TimeGANs for time-series data, Denoising Diffusion Probabilistic Models (DDPM), and Large Language Model (LLM) based approaches. These enhancements allow for robust data generation across various dataset complexities and distributions, giving an opportunity to try GANs, TimeGANs, Diffusions, and LLMs for tabular data generation.
-* Arxiv article: ["Tabular GANs for uneven distribution"](https://arxiv.org/abs/2010.00638)
-* Medium post: [GANs for tabular data](https://medium.com/data-science/review-of-gans-for-tabular-data-a30a2199342)
+## Overview
 
-## How to use library
+TabGAN provides a unified interface for generating synthetic tabular data using multiple generative approaches:
 
-* Installation: `pip install tabgan`
-* To generate new data to train by sampling and then filtering by adversarial training - call `GANGenerator().generate_data_pipe`.
+- **GANs**: Conditional Tabular GAN (CTGAN) for modeling complex tabular distributions
+- **Diffusion Models**: Forest Diffusion for high-quality synthetic data generation
+- **LLMs**: GReaT framework for generating realistic tabular data using language models
+- **Time-Series**: TimeGAN support for temporal data generation
 
-### Data Format
+*Related Research: [Tabular GANs for uneven distribution (arXiv)](https://arxiv.org/abs/2010.00638)*
 
-TabGAN accepts data as a `numpy.ndarray` or `pandas.DataFrame` with columns categorized as:
+## Installation
 
-* **Continuous Columns**: Numerical columns with any possible value.
-* **Discrete Columns**: Columns with a limited set of values (e.g., categorical data).
+```bash
+pip install tabgan
+```
 
-Note: TabGAN does not differentiate between floats and integers, so all values are treated as floats. For integer requirements, round the output outside of TabGAN.
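+After installing, a minimal sanity check (nothing is assumed here beyond the package name itself):
+
+```python
+import tabgan  # should import without errors after installation
+```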
+## Quick Start
+
+```python
+import pandas as pd
+import numpy as np
+from tabgan.sampler import GANGenerator
+
+# Create sample data
+train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("ABCD"))
+target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y"))
+test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
+
+# Generate synthetic data
+new_train, new_target = GANGenerator().generate_data_pipe(train, target, test)
+```
+
+## Available Generators
+
+| Generator | Description | Best For |
+|-----------|-------------|----------|
+| `GANGenerator` | CTGAN-based generation | General tabular data with mixed types |
+| `ForestDiffusionGenerator` | Diffusion models + tree-based methods | Complex tabular structures |
+| `LLMGenerator` | Large Language Model based | Capturing complex dependencies |
+| `OriginalGenerator` | Baseline sampler | Baseline comparisons |
+
+## API Reference
 
 ### Sampler Parameters
 
-All samplers (`OriginalGenerator`, `GANGenerator`, `ForestDiffusionGenerator`, `LLMGenerator`) share the following input parameters:
+All generators accept these common parameters:
 
-* **gen_x_times**: `float` (default: `1.1`) - How much data to generate. The output might be less due to postprocessing and adversarial filtering.
-* **cat_cols**: `list` (default: `None`) - A list of column names to be treated as categorical.
-* **bot_filter_quantile**: `float` (default: `0.001`) - The bottom quantile for postprocess filtering. Values below this quantile will be filtered out.
-* **top_filter_quantile**: `float` (default: `0.999`) - The top quantile for postprocess filtering. Values above this quantile will be filtered out.
-* **is_post_process**: `bool` (default: `True`) - Whether to perform post-filtering. If `False`, `bot_filter_quantile` and `top_filter_quantile` are ignored.
-* **adversarial_model_params**: `dict` (default: see below) - Parameters for the adversarial filtering model. Default values are optimized for binary classification tasks.
-  ```python
-  {
-      "metrics": "AUC", "max_depth": 2, "max_bin": 100,
-      "learning_rate": 0.02, "random_state": 42, "n_estimators": 100,
-  }
-  ```
-* **pregeneration_frac**: `float` (default: `2`) - For the generation step, `gen_x_times * pregeneration_frac` amount of data will be generated. However, after postprocessing, the aim is to return an amount of data equivalent to `(1 + gen_x_times)` times the size of the original dataset (if `only_generated_data` is `False`, otherwise `gen_x_times` times the size of the original dataset).
-* **only_generated_data**: `bool` (default: `False`) - If `True`, only the newly generated data is returned, without concatenating the input training dataframe.
-* **gen_params**: `dict` (default: see below) - Parameters for the underlying generative model training. Specific to `GANGenerator` and `LLMGenerator`.
-  * For `GANGenerator`:
-    ```python
-    {"batch_size": 500, "patience": 25, "epochs" : 500}
-    ```
-  * For `LLMGenerator`:
-    ```python
-    {"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}
-    ```
-
-The available samplers are:
-1. **`GANGenerator`**: Utilizes the Conditional Tabular GAN (CTGAN) architecture, known for effectively modeling tabular data distributions and handling mixed data types (continuous and discrete). It learns the data distribution and generates synthetic samples that mimic the original data.
-2. **`ForestDiffusionGenerator`**: Implements a novel approach using diffusion models guided by tree-based methods (Forest Diffusion). This technique is capable of generating high-quality synthetic data, particularly for complex tabular structures, by gradually adding noise to data and then learning to reverse the process.
-3. **`LLMGenerator`**: Leverages Large Language Models (LLMs) using the GReaT (Generative Realistic Tabular data) framework. It transforms tabular data into a text format, fine-tunes an LLM on this representation, and then uses the LLM to generate new tabular instances by sampling from it. This approach is particularly promising for capturing complex dependencies and can generate diverse synthetic data.
-4. **`OriginalGenerator`**: Acts as a baseline sampler. It typically returns the original training data or a direct sample from it. This is useful for comparison purposes to evaluate the effectiveness of more complex generative models.
-
-
-### `generate_data_pipe` Method Parameters
-
-The `generate_data_pipe` method, available for all samplers, uses the following parameters:
-
-* **train_df**: `pd.DataFrame` - The training dataframe (features only, without the target variable).
-* **target**: `pd.DataFrame` - The input target variable for the training dataset.
-* **test_df**: `pd.DataFrame` - The test dataframe. The newly generated training dataframe should be statistically similar to this.
-* **deep_copy**: `bool` (default: `True`) - Whether to make a copy of the input dataframes. If `False`, input dataframes will be modified in place.
-* **only_adversarial**: `bool` (default: `False`) - If `True`, only adversarial filtering will be performed on the training dataframe; no new data will be generated.
-* **use_adversarial**: `bool` (default: `True`) - Whether to perform adversarial filtering.
-* **@return**: `Tuple[pd.DataFrame, pd.DataFrame]` - A tuple containing the newly generated/processed training dataframe and the corresponding target.
-
-
-### Example Code
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `gen_x_times` | float | 1.1 | Multiplier for data generation amount |
+| `cat_cols` | list | None | Column names to treat as categorical |
+| `bot_filter_quantile` | float | 0.001 | Bottom quantile for post-process filtering |
+| `top_filter_quantile` | float | 0.999 | Top quantile for post-process filtering |
+| `is_post_process` | bool | True | Enable post-filtering |
+| `adversarial_model_params` | dict | Binary-task defaults | Parameters for the adversarial filtering model |
+| `pregeneration_frac` | float | 2 | Pre-generation data multiplier |
+| `only_generated_data` | bool | False | Return only generated data (no original) |
+| `gen_params` | dict | See below | Generator-specific parameters |
+
+### Generator-Specific Parameters
+
+**GANGenerator:**
+```python
+{"batch_size": 500, "patience": 25, "epochs": 500}
+```
+
+**LLMGenerator:**
+```python
+{"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}
+```
+
+### generate_data_pipe Method
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `train_df` | pd.DataFrame | Training features |
+| `target` | pd.DataFrame | Target variable |
+| `test_df` | pd.DataFrame | Test features |
+| `deep_copy` | bool | Make a copy of the input dataframes |
+| `only_adversarial` | bool | Only perform adversarial filtering |
+| `use_adversarial` | bool | Enable adversarial filtering |
+
+**Returns:** `Tuple[pd.DataFrame, pd.DataFrame]` - (new_train, new_target)
+
+## Data Format
+
+TabGAN accepts `numpy.ndarray` or `pandas.DataFrame` with:
+
+- **Continuous Columns**: Numerical columns with any possible value
+- **Discrete Columns**: Columns with a limited set of values (categorical)
+
+> **Note:** TabGAN treats all values as floats. For integers, round the output after generation.
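+For example, a minimal sketch of that post-generation rounding (assuming every generated column should be integer-valued, as in the Quick Start frames):
+
+```python
+new_train, new_target = GANGenerator().generate_data_pipe(train, target, test)
+# TabGAN returns floats; cast back to integer dtypes outside the library
+new_train = new_train.round().astype(int)
+```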
+
+## Examples
+
+### Basic Usage
 
 ```python
 from tabgan.sampler import OriginalGenerator, GANGenerator, ForestDiffusionGenerator, LLMGenerator
 import pandas as pd
 import numpy as np
-
-# random input data
 train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("ABCD"))
 target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y"))
 test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
 
-# generate data
-new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test, )
-new_train2, new_target2 = GANGenerator(gen_params={"batch_size": 500, "epochs": 10, "patience": 5 }).generate_data_pipe(train, target, test, )
-new_train3, new_target3 = ForestDiffusionGenerator().generate_data_pipe(train, target, test, )
-new_train4, new_target4 = LLMGenerator(gen_params={"batch_size": 32,
-                                       "epochs": 4, "llm": "distilgpt2", "max_length": 500}).generate_data_pipe(train, target, test, )
+# Different generators
+new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test)
+new_train2, new_target2 = GANGenerator(gen_params={"batch_size": 500, "epochs": 10, "patience": 5}).generate_data_pipe(train, target, test)
+new_train3, new_target3 = ForestDiffusionGenerator().generate_data_pipe(train, target, test)
+new_train4, new_target4 = LLMGenerator(gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500}).generate_data_pipe(train, target, test)
+```
+
+### Full Parameter Example
 
-# example with all params defined
-new_train_gan_all_params, new_target_gan_all_params = GANGenerator(
+```python
+new_train, new_target = GANGenerator(
     gen_x_times=1.1,
     cat_cols=None,
     bot_filter_quantile=0.001,
@@ -113,108 +145,133 @@ new_train_gan_all_params, new_target_gan_all_params = GANGenerator(
     only_adversarial=False,
     use_adversarial=True
 )
+```
 
+### LLM Conditional Text Generation
+
+Generate synthetic data with LLMs while controlling text generation based on specific conditions. This uses the internal `_generate_via_prompt` method for novel text generation.
+
+```python
+import pandas as pd
+from tabgan.sampler import LLMGenerator
+
+# Create sample data with text and categorical columns
+train = pd.DataFrame({
+    "Name": ["Anna", "Maria", "Ivan", "Sergey", "Olga", "Boris"],
+    "Gender": ["F", "F", "M", "M", "F", "M"],
+    "Age": [25, 30, 35, 40, 28, 32],
+    "Occupation": ["Engineer", "Doctor", "Artist", "Teacher", "Manager", "Pilot"]
+})
+
+# Generate new names conditioned on Gender, with other features imputed
+new_train, _ = LLMGenerator(
+    gen_x_times=1.5,  # Generate 1.5x the original data
+    text_generating_columns=["Name"],  # Generate novel names
+    conditional_columns=["Gender"],  # Condition on Gender column
+    gen_params={"batch_size": 32, "epochs": 4, "llm": "distilgpt2", "max_length": 500},
+    is_post_process=False  # Disable post-processing for this example
+).generate_data_pipe(
+    train,
+    target=None,
+    test_df=None,
+    only_generated_data=True  # Return only generated data
+)
+
+print(new_train)
 ```
-Thus, you may use this library to improve your dataset quality:
 
+**Parameters for conditional generation:**
+- `text_generating_columns`: List of column names to generate novel text for
+- `conditional_columns`: List of column names that condition the text generation
 
-``` python
-def fit_predict(clf, X_train, y_train, X_test, y_test):
-    clf.fit(X_train, y_train)
-    return sklearn.metrics.roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
+The model will:
+1. Sample values for conditional columns from their distributions
+2. Impute remaining non-text columns using the LLM
+3. Generate novel text for text columns via prompt-based generation (using `_generate_via_prompt`)
+4. Ensure generated text values are unique (not present in original data)
 
+### Improving Model Performance
 
+```python
+import pandas as pd
+import sklearn.datasets
+import sklearn.ensemble
+import sklearn.metrics
+import sklearn.model_selection
+from tabgan.sampler import GANGenerator
+
+def evaluate(clf, X_train, y_train, X_test, y_test):
+    clf.fit(X_train, y_train)
+    return sklearn.metrics.roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
+
+# Load dataset
 dataset = sklearn.datasets.load_breast_cancer()
 clf = sklearn.ensemble.RandomForestClassifier(n_estimators=25, max_depth=6)
 X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
-    pd.DataFrame(dataset.data), pd.DataFrame(dataset.target, columns=["target"]), test_size=0.33, random_state=42)
-print("initial metric", fit_predict(clf, X_train, y_train, X_test, y_test))
+    pd.DataFrame(dataset.data), pd.DataFrame(dataset.target, columns=["target"]),
+    test_size=0.33, random_state=42)
 
-new_train1, new_target1 = OriginalGenerator().generate_data_pipe(X_train, y_train, X_test, )
-print("OriginalGenerator metric", fit_predict(clf, new_train1, new_target1, X_test, y_test))
+# Compare performance
+print("Baseline:", evaluate(clf, X_train, y_train, X_test, y_test))
 
-new_train1, new_target1 = GANGenerator().generate_data_pipe(X_train, y_train, X_test, )
-print("GANGenerator metric", fit_predict(clf, new_train2, new_target2, X_test, y_test))  # Corrected variable name
+# Generate and evaluate
+new_train, new_target = GANGenerator().generate_data_pipe(X_train, y_train, X_test)
+print("With GAN:", evaluate(clf, new_train, new_target, X_test, y_test))
 ```
 
-### Advanced Usage: Generating Time-Series Data with TimeGAN
-
-You can easily adjust the code to generate multidimensional time-series data. This approach primarily involves extracting day, month, and year components from a date column to be used as features in the generation process. Below is a demonstration:
+### Time-Series Data Generation
 
 ```python
 import pandas as pd
 import numpy as np
-from tabgan.utils import get_year_mnth_dt_from_date,make_two_digit,collect_dates
-from tabgan.sampler import OriginalGenerator, GANGenerator
+from tabgan.utils import get_year_mnth_dt_from_date, collect_dates
+from tabgan.sampler import GANGenerator
 
-
-train_size = 100
-train = pd.DataFrame(
-    np.random.randint(-10, 150, size=(train_size, 4)), columns=list("ABCD")
-    )
-min_date = pd.to_datetime('2019-01-01')
-max_date = pd.to_datetime('2021-12-31')
+train = pd.DataFrame(np.random.randint(-10, 150, size=(100, 4)), columns=list("ABCD"))
+min_date, max_date = pd.to_datetime('2019-01-01'), pd.to_datetime('2021-12-31')
 d = (max_date - min_date).days + 1
-
-train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d')
+train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=100), unit='d')
 train = get_year_mnth_dt_from_date(train, 'Date')
 
-new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001,
-                                     top_filter_quantile=0.999,
-                                     is_post_process=True, pregeneration_frac=2, only_generated_data=False).\
-    generate_data_pipe(train.drop('Date', axis=1), None,
-                       train.drop('Date', axis=1)
-                       )
+new_train, _ = GANGenerator(
+    gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001,
+    top_filter_quantile=0.999, is_post_process=True, pregeneration_frac=2,
+    only_generated_data=False
+).generate_data_pipe(
+    train.drop('Date', axis=1), None, train.drop('Date', axis=1)
+)
 new_train = collect_dates(new_train)
 ```
 
-## Experiments
-### Datasets and experiment design
-
-**Check for data generation quality**
-Just use built-in function
-```
-compare_dataframes(original_df, generated_df) # return between 0 and 1
-```
-**Running experiment**
-
-To run experiment follow these steps:
+## Data Quality Validation
 
-1. Clone the repository. All required datasets are stored in `./Research/data` folder.
-2. Install requirements: `pip install -r requirements.txt`
-3. Run experiments using `python ./Research/run_experiment.py`. You may
-   add more datasets, adjust validation type, and categorical encoders.
-4. Observe metrics across all experiments in the console or in `./Research/results/fit_predict_scores.txt`.
+Check generated data quality using the built-in function:
 
+```python
+from tabgan.utils import compare_dataframes
 
-**Experiment design**
+quality_score = compare_dataframes(original_df, generated_df)  # Returns value between 0 and 1
+```
 
+### Experiment Workflow
 
 ![Experiment design and workflow](images/workflow.png)
-**Picture 1.1** Experiment design and workflow
 
+## Benchmark Results
 
-## Results
-The table below (Table 1.2) shows ROC AUC scores for different sampling strategies. To facilitate comparison across datasets with potentially different baseline AUC scores, the ROC AUC scores for each dataset were scaled using min-max normalization (where the maximum score achieved by any method on that dataset becomes 1, and the minimum becomes 0). These scaled scores were then averaged across all datasets for each sampling strategy. Therefore, a higher value in the table indicates better relative performance in generating data that is difficult for a classifier to distinguish from the original data, when compared to other methods on the same set of datasets.
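+Scores are min-max scaled per dataset (the best method on a dataset maps to 1, the worst to 0) and then averaged across datasets. A small worked illustration of that scaling, with made-up numbers rather than the experiment's actual values:
+
+```python
+import pandas as pd
+
+# hypothetical raw ROC AUC values for one dataset
+raw = pd.Series({"None": 0.80, "gan": 0.83, "sample_original": 0.81})
+scaled = (raw - raw.min()) / (raw.max() - raw.min())  # gan -> 1.0, None -> 0.0
+```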
+The following table shows normalized ROC AUC scores (higher is better):
 
-**Table 1.2** Averaged Min-Max Scaled ROC AUC scores for different sampling strategies across datasets. Higher is better (closer to 1 indicates performance similar to the best method on each dataset).
-
-| dataset_name | None | gan | sample_original |
-|:-----------------------|-------------------:|------------------:|------------------------------:|
-| credit | 0.997 | **0.998** | 0.997 |
-| employee | **0.986** | 0.966 | 0.972 |
-| mortgages | 0.984 | 0.964 | **0.988** |
-| poverty_A | 0.937 | **0.950** | 0.933 |
-| taxi | 0.966 | 0.938 | **0.987** |
-| adult | 0.995 | 0.967 | **0.998** |
+| Dataset | None | GAN | Sample Original |
+|---------|------|-----|-----------------|
+| credit | 0.997 | **0.998** | 0.997 |
+| employee | **0.986** | 0.966 | 0.972 |
+| mortgages | 0.984 | 0.964 | **0.988** |
+| poverty_A | 0.937 | **0.950** | 0.933 |
+| taxi | 0.966 | 0.938 | **0.987** |
+| adult | 0.995 | 0.967 | **0.998** |
 
 ## Citation
 
-If you use **tabgan** in a scientific publication, we would appreciate references to the following BibTex entry:
-arxiv publication:
+If you use TabGAN in your research, please cite:
+
 ```bibtex
 @misc{ashrapov2020tabular,
-      title={Tabular GANs for uneven distribution}, 
+      title={Tabular GANs for uneven distribution},
       author={Insaf Ashrapov},
       year={2020},
       eprint={2010.00638},
@@ -227,8 +284,12 @@ arxiv publication:
 
 [1] Xu, L., & Veeramachaneni, K. (2018). *Synthesizing Tabular Data using Generative Adversarial Networks*. arXiv:1811.11264 [cs.LG].
 
-[2] Jolicoeur-Martineau, A., Fatras, K., & Kachman, T. (2023). *Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees*. Retrieved from https://github.com/SamsungSAILMontreal/ForestDiffusion.
+[2] Jolicoeur-Martineau, A., Fatras, K., & Kachman, T. (2023). *Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees*. SamsungSAILMontreal/ForestDiffusion.
 
 [3] Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019). *Modeling Tabular data using Conditional GAN*. NeurIPS.
 
 [4] Borisov, V., Sessler, K., Leemann, T., Pawelczyk, M., & Kasneci, G. (2023). *Language Models are Realistic Tabular Data Generators*. ICLR.
+
+## License
+
+Apache License 2.0 - See [LICENSE](LICENSE) file for details.
diff --git a/images/workflow.png b/images/workflow.png
index c1d75b1..06d0081 100644
Binary files a/images/workflow.png and b/images/workflow.png differ
diff --git a/requirements.txt b/requirements.txt
index 42c07bb..4ac3b01 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,12 @@
-scipy>=1.4.1
+scipy>=1.12.0
 category_encoders>=2.6.3
-numpy>=1.22.0
 torch>=1.6.0
-pandas>=1.2.2
+pandas>=2.2.0
 lightgbm>=2.2.3
 scikit_learn>=1.5.2
 torchvision>=0.4.2
-numpy>=2.0
-python-dateutil==2.8.1
+numpy>=1.23.1
+python-dateutil>=2.8.2
 tqdm>=4.61.1
 xgboost>=2.0.0
-be-great==0.0.8
\ No newline at end of file
+be-great==0.0.13
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index db13bba..0ad8e13 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -47,7 +47,7 @@ install_requires =
     python-dateutil
    tqdm
     xgboost
-    be-great
+    be-great>=0.0.13
 
 [options.packages.find]
 where = src
@@ -66,15 +66,8 @@ testing =
     pytest-cov
 
 [options.entry_points]
-# Add here console scripts like:
-# console_scripts =
-#     script_name = tabgan.module:function
-# For example:
-# console_scripts =
-#     fibonacci = tabgan.skeleton:run
-# And any other entry points, for example:
-# pyscaffold.cli =
-#     awesome = pyscaffoldext.awesome.extension:AwesomeExtension
+console_scripts =
+    tabgan-generate = tabgan.cli:main
 
 [test]
 # py.test options when running `python setup.py test`
diff --git a/src/tabgan/abc_sampler.py b/src/tabgan/abc_sampler.py
index b44a5ed..f8ae341 100644
--- a/src/tabgan/abc_sampler.py
+++ b/src/tabgan/abc_sampler.py
@@ -83,7 +83,35 @@ def generate_data_pipe(
 
 
 class Sampler(ABC):
-    """Interface for each sampling strategy"""
+    """Interface for each sampling strategy.
+
+    Concrete sampler implementations share a common configuration interface
+    (generation factor, categorical columns, post-processing flags, etc.).
+    This base ``__init__`` stores those shared parameters so that subclasses
+    can call ``super().__init__(...)`` and focus on strategy-specific logic.
+    """
+
+    def __init__(
+        self,
+        gen_x_times: float,
+        cat_cols: list | None,
+        bot_filter_quantile: float,
+        top_filter_quantile: float,
+        is_post_process: bool,
+        adversarial_model_params: dict,
+        pregeneration_frac: float,
+        only_generated_data: bool,
+        gen_params: dict | None = None,
+    ) -> None:
+        self.gen_x_times = gen_x_times
+        self.cat_cols = cat_cols
+        self.bot_filter_quantile = bot_filter_quantile
+        self.top_filter_quantile = top_filter_quantile
+        self.is_post_process = is_post_process
+        self.adversarial_model_params = adversarial_model_params
+        self.pregeneration_frac = pregeneration_frac
+        self.only_generated_data = only_generated_data
+        self.gen_params = gen_params or {}
 
     def get_generated_shape(self, input_df):
         """Calculates final output shape"""
diff --git a/src/tabgan/cli.py b/src/tabgan/cli.py
new file mode 100644
index 0000000..a2f6025
--- /dev/null
+++ b/src/tabgan/cli.py
@@ -0,0 +1,129 @@
+import argparse
+import logging
+from typing import List, Optional
+
+import pandas as pd
+
+from tabgan.sampler import (
+    OriginalGenerator,
+    GANGenerator,
+    ForestDiffusionGenerator,
+    LLMGenerator,
+)
+
+
+def _parse_cat_cols(raw: Optional[str]) -> Optional[List[str]]:
+    if not raw:
+        return None
+    return [c.strip() for c in raw.split(",") if c.strip()]
+
+
+def main() -> None:
+    """
+    Command-line interface for generating synthetic tabular data with tabgan.
+
+    Example:
+        tabgan-generate \\
+            --input-csv train.csv \\
+            --target-col target \\
+            --generator gan \\
+            --gen-x-times 1.5 \\
+            --cat-cols year,gender \\
+            --output-csv synthetic_train.csv
+    """
+    parser = argparse.ArgumentParser(
+        description="Generate synthetic tabular data using tabgan samplers."
+    )
+    parser.add_argument(
+        "--input-csv",
+        required=True,
+        help="Path to input CSV file containing training data (with or without target column).",
+    )
+    parser.add_argument(
+        "--target-col",
+        default=None,
+        help="Name of the target column in the CSV (optional).",
+    )
+    parser.add_argument(
+        "--output-csv",
+        required=True,
+        help="Path to write the generated synthetic dataset as CSV.",
+    )
+    parser.add_argument(
+        "--generator",
+        choices=["original", "gan", "diffusion", "llm"],
+        default="gan",
+        help="Which sampler to use for generation.",
+    )
+    parser.add_argument(
+        "--gen-x-times",
+        type=float,
+        default=1.1,
+        help="Factor controlling how many synthetic samples to generate relative to the training size.",
+    )
+    parser.add_argument(
+        "--cat-cols",
+        default=None,
+        help="Comma-separated list of categorical column names (e.g. 'year,gender').",
+    )
+    parser.add_argument(
+        "--only-generated",
+        action="store_true",
+        help="If set, output only synthetic rows instead of original + synthetic.",
+    )
+
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO)
+
+    logging.info("Reading input CSV from %s", args.input_csv)
+    df = pd.read_csv(args.input_csv)
+
+    target_df = None
+    train_df = df
+    if args.target_col is not None:
+        if args.target_col not in df.columns:
+            raise ValueError(f"Target column '{args.target_col}' not found in input CSV.")
+        target_df = df[[args.target_col]]
+        train_df = df.drop(columns=[args.target_col])
+
+    cat_cols = _parse_cat_cols(args.cat_cols)
+
+    generator_map = {
+        "original": OriginalGenerator,
+        "gan": GANGenerator,
+        "diffusion": ForestDiffusionGenerator,
+        "llm": LLMGenerator,
+    }
+    generator_cls = generator_map[args.generator]
+
+    logging.info("Initializing %s generator", generator_cls.__name__)
+    generator = generator_cls(
+        gen_x_times=args.gen_x_times,
+        cat_cols=cat_cols,
+        only_generated_data=bool(args.only_generated),
+    )
+
+    # Use train_df itself as test_df when a dedicated hold-out set is not provided.
+    logging.info("Generating synthetic data...")
+    new_train, new_target = generator.generate_data_pipe(
+        train_df, target_df, train_df
+    )
+
+    if new_target is not None and args.target_col is not None:
+        out_df = new_train.copy()
+        # new_target can be DataFrame or Series; align to a 1D array
+        if hasattr(new_target, "values") and new_target.ndim > 1:
+            out_df[args.target_col] = new_target.values.ravel()
+        else:
+            out_df[args.target_col] = new_target
+    else:
+        out_df = new_train
+
+    logging.info("Writing synthetic data to %s", args.output_csv)
+    out_df.to_csv(args.output_csv, index=False)
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/src/tabgan/sampler.py b/src/tabgan/sampler.py
index 9f59a73..09a50f6 100644
--- a/src/tabgan/sampler.py
+++ b/src/tabgan/sampler.py
@@ -1,10 +1,10 @@
 # -*- coding: utf-8 -*-
 import logging
+import numpy as np
 import warnings
 from typing import Tuple
 
-import numpy as np
 import pandas as pd
 import torch
 from be_great import GReaT
@@ -80,34 +80,55 @@ def __init__(
         pregeneration_frac: float = 2,
         only_generated_data: bool = False,
         gen_params: dict = {"batch_size": 45, 'patience': 25, "epochs": 50, "llm": "distilgpt2"},
+        text_generating_columns: list = None,
+        conditional_columns: list = None,
     ):
         """
+        Initialize an original sampler configuration.
 
-        @param gen_x_times: float = 1.1 - how much data to generate, output might be less because of postprocessing and
-        adversarial filtering
-        @param cat_cols: list = None - categorical columns
-        @param bot_filter_quantile: float = 0.001 - bottom quantile for postprocess filtering
-        @param top_filter_quantile: float = 0.999 - top quantile for postprocess filtering
-        @param is_post_process: bool = True - perform or not postfiltering, if false bot_filter_quantile
-        and top_filter_quantile ignored
-        @param adversarial_model_params: dict params for adversarial filtering model, default values for binary task
-        @param pregeneration_frac: float = 2 - for generation step gen_x_times * pregeneration_frac amount of data
-        will be generated. However, in postprocessing (1 + gen_x_times) % of original data will be returned
-        @param only_generated_data: bool = False If True after generation get only newly generated, without
-        concatenating input train dataframe.
-        @param gen_params: dict params for GAN training. Only works for SamplerGAN, ForestDiffusionGenerator,
-        LLMGenerator.
+        Args:
+            gen_x_times (float): Factor controlling how many synthetic samples
+                to generate relative to the training size. The final amount
+                can be smaller after post-processing and adversarial filtering.
+            cat_cols (list | None): Names of categorical columns in the
+                training data.
+            bot_filter_quantile (float): Lower quantile used for numeric
+                post-processing filters.
+            top_filter_quantile (float): Upper quantile used for numeric
+                post-processing filters.
+            is_post_process (bool): Whether to apply post-processing filters
+                based on the distribution of `test_df`. If False, the
+                quantile-based filters are skipped.
+            adversarial_model_params (dict): Parameters for the adversarial
+                filtering model used to keep generated samples close to the
+                test distribution.
+            pregeneration_frac (float): Oversampling factor applied before
+                post-processing. The final number of rows is derived from
+                `gen_x_times`.
+            only_generated_data (bool): If True, return only synthetic rows.
+                If False, append generated rows to the original training data.
+            gen_params (dict): Model-specific generation parameters shared by
+                subclasses (GAN, ForestDiffusion, LLM).
+            text_generating_columns (list | None): Column names for which new
+                text values should be generated (used by `SamplerLLM`).
+            conditional_columns (list | None): Column names that condition
+                text generation for `text_generating_columns`.
         """
-        self.gen_x_times = gen_x_times
-        self.cat_cols = cat_cols
-        self.is_post_process = is_post_process
-        self.bot_filter_quantile = bot_filter_quantile
-        self.top_filter_quantile = top_filter_quantile
-        self.adversarial_model_params = adversarial_model_params
-        self.pregeneration_frac = pregeneration_frac
-        self.only_generated_data = only_generated_data
-        self.gen_params = gen_params
-        self.TEMP_TARGET = "TEMP_TARGET"
+        super().__init__(
+            gen_x_times=gen_x_times,
+            cat_cols=cat_cols,
+            bot_filter_quantile=bot_filter_quantile,
+            top_filter_quantile=top_filter_quantile,
+            is_post_process=is_post_process,
+            adversarial_model_params=adversarial_model_params,
+            pregeneration_frac=pregeneration_frac,
+            only_generated_data=only_generated_data,
+            gen_params=gen_params,
+        )
+        self.text_generating_columns = text_generating_columns
+        self.conditional_columns = conditional_columns
+        if not hasattr(self, "TEMP_TARGET"):
+            self.TEMP_TARGET = "TEMP_TARGET"
 
     @staticmethod
     def preprocess_data_df(df) -> pd.DataFrame:
@@ -244,27 +265,26 @@ def _validate_data(train_df, target, test_df)
 
     def handle_generated_data(self, train_df, generated_df, only_generated_data):
         """
-        Integrates synthetic data with the original dataset by preserving data types
-        and structural alignment.
+        Align and optionally merge generated rows with the original training data.
 
-        This method transforms generated data to match the original dataset's structure
-        and types. It can either combine synthetic with original data or return only
-        the synthetic data.
+        The generated data is cast to the dtypes and column order of `train_df`
+        so that downstream models receive data with a consistent schema.
 
         Args:
-            train_df: The original dataset that defines the expected structure
-            generated_df: The synthetic data to be processed
-            only_generated_data: Boolean flag to return only synthetic data
+            train_df (pd.DataFrame): Original training data used to infer the
+                schema and target column.
+            generated_df (pd.DataFrame or array-like): Newly generated
+                samples to be aligned with `train_df`.
+            only_generated_data (bool): If True, return only synthetic rows;
+                otherwise, append them to `train_df` before returning.
 
         Returns:
-            A tuple containing:
-            - Feature matrix (with or without original data)
-            - Corresponding target vector
+            Tuple[pd.DataFrame, pd.Series | pd.DataFrame]: Features and
+                corresponding target values.
         """
         generated_df = pd.DataFrame(generated_df)
         generated_df.columns = train_df.columns
 
-        # Preserve original data types
        for column_index in range(len(generated_df.columns)):
             target_column = generated_df.columns[column_index]
             generated_df[target_column] = generated_df[target_column].astype(
@@ -373,25 +393,231 @@
                 self.gen_params["epochs"]))
             self.gen_params["epochs"] = 3
 
+    def _build_training_frame(self, train_df: pd.DataFrame, target: pd.DataFrame | None) -> pd.DataFrame:
+        """
+        Return a copy of the training frame with TEMP_TARGET attached when a target is provided.
+        """
+        current_train_df = train_df.copy()
+        if target is not None:
+            current_train_df[self.TEMP_TARGET] = target
+        return current_train_df
+
+    def _fit_great_model(self, current_train_df: pd.DataFrame):
+        """
+        Fit a GReaT model on the provided training frame and return the instance and inference device.
+        """
+        logging.info("Fitting LLM model")
+        is_fp16 = torch.cuda.is_available()
+        try:
+            from be_great import GReaT
+        except ImportError:
+            raise ImportError("be_great library is not installed. Please install it to use LLMGenerator.")
+
+        great_model_instance = GReaT(
+            llm=self.gen_params["llm"],
+            batch_size=self.gen_params["batch_size"],
+            epochs=self.gen_params["epochs"],
+            fp16=is_fp16,
+        )
+        great_model_instance.fit(current_train_df)
+        logging.info("Finished training LLM model")
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        return great_model_instance, device
+
+    def _conditional_text_generation(
+        self,
+        great_model_instance,
+        current_train_df: pd.DataFrame,
+        train_df: pd.DataFrame,
+        target: pd.DataFrame | None,
+        device: str,
+    ) -> pd.DataFrame:
+        """
+        Generate rows when text and conditional columns are specified.
+        """
+        logging.info("Starting conditional generation of text columns.")
+        num_samples_to_generate = int(self.gen_x_times * train_df.shape[0])
+
+        original_unique_text_values: dict[str, set] = {}
+        for col in self.text_generating_columns:
+            if col not in current_train_df.columns:
+                raise ValueError(f"Text generating column '{col}' not found in training data.")
+            original_unique_text_values[col] = set(current_train_df[col].unique())
+
+        attribute_distributions: dict[str, pd.Series] = {}
+        for col in self.conditional_columns:
+            if col not in current_train_df.columns:
+                raise ValueError(f"Conditional column '{col}' not found in training data.")
+            attribute_distributions[col] = current_train_df[col].value_counts(normalize=True)
+
+        generated_rows: list[dict] = []
+        all_train_columns = current_train_df.columns.tolist()
+
+        for _ in range(num_samples_to_generate):
+            current_row_data: dict = {}
+
+            for attr_col in self.conditional_columns:
+                dist = attribute_distributions[attr_col]
+                current_row_data[attr_col] = np.random.choice(dist.index, p=dist.values)
+
+            row_template_for_impute = pd.DataFrame(columns=all_train_columns, index=[0])
+            for col in all_train_columns:
+                if col in current_row_data:
+                    row_template_for_impute.loc[0, col] = current_row_data[col]
+                elif col not in self.text_generating_columns:
+                    row_template_for_impute.loc[0, col] = np.nan
+
+            imputed_full_row_df = great_model_instance.impute(
+                row_template_for_impute.copy(),
+                max_length=self.gen_params.get("max_length", 500),
+            )
+
+            for col in all_train_columns:
+                if col not in self.text_generating_columns and col not in current_row_data:
+                    current_row_data[col] = imputed_full_row_df.loc[0, col]
+
+            for text_col in self.text_generating_columns:
+                prompt_parts: list[str] = []
+                for cond_col in self.conditional_columns:
+                    prompt_parts.append(f"{cond_col}: {current_row_data[cond_col]}")
+                for other_col in all_train_columns:
+                    if (
+                        other_col not in self.text_generating_columns
+                        and other_col not in self.conditional_columns
+                        and other_col in current_row_data
+                    ):
+                        val_str = str(current_row_data[other_col])
+                        if len(val_str) > 30:
+                            val_str = val_str[:27] + "..."
+                        prompt_parts.append(f"{other_col}: {val_str}")
+
+                prompt = ", ".join(prompt_parts) + f", Generate {text_col}: "
+
+                generated_text_candidate = None
+                max_retries = 10
+                for _retry_attempt in range(max_retries):
+                    generated_text_candidate = self._generate_via_prompt(
+                        prompt,
+                        great_model_instance,
+                        device=device,
+                    )
+                    if generated_text_candidate not in original_unique_text_values[text_col]:
+                        break
+                else:
+                    logging.warning(
+                        f"Max retries reached for generating novel text for {text_col}. Using last candidate."
+                    )
+                current_row_data[text_col] = generated_text_candidate
+
+            ordered_row = {col: current_row_data.get(col) for col in train_df.columns}
+            if target is not None and self.TEMP_TARGET in current_row_data:
+                ordered_row[self.TEMP_TARGET] = current_row_data[self.TEMP_TARGET]
+
+            generated_rows.append(ordered_row)
+
+        generated_df = pd.DataFrame(generated_rows)
+        return generated_df.reindex(columns=current_train_df.columns)
+
+    def _standard_llm_sampling(
+        self,
+        great_model_instance,
+        current_train_df: pd.DataFrame,
+        device: str,
+    ) -> pd.DataFrame:
+        """
+        Fallback sampling when no explicit text/conditional columns are provided.
+        """
+        logging.info("Starting standard LLM sampling.")
+        return great_model_instance.sample(
+            int(self.gen_x_times * current_train_df.shape[0]),
+            device=device,
+            max_length=self.gen_params["max_length"],
+        )
+
     def generate_data(
         self, train_df, target, test_df, only_generated_data: bool
     ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         self._validate_data(train_df, target, test_df)
         self.check_params()
-        if target is not None:
-            train_df[self.TEMP_TARGET] = target
-        logging.info("Fitting LLM model")
-        is_fp16 = torch.cuda.is_available()
-        model = GReaT(llm=self.gen_params["llm"], batch_size=self.gen_params["batch_size"],
-                      epochs=self.gen_params["epochs"], fp16=is_fp16)
-        model.fit(train_df)
-        logging.info("Finished training ForestDiffusionModel")
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        current_train_df = self._build_training_frame(train_df, target)
+        great_model_instance, device = self._fit_great_model(current_train_df)
 
-        generated_df = model.sample(int(self.gen_x_times * train_df.shape[0]), device=device,
-                                    max_length=self.gen_params["max_length"])
-        return self.handle_generated_data(train_df, generated_df, only_generated_data)
+        if self.text_generating_columns and self.conditional_columns:
+            generated_df = self._conditional_text_generation(
+                great_model_instance,
+                current_train_df=current_train_df,
+                train_df=train_df,
+                target=target,
+                device=device,
+            )
+        else:
+            generated_df = self._standard_llm_sampling(
+                great_model_instance,
+                current_train_df=current_train_df,
+                device=device,
+            )
+
+        # When a target is provided, ``current_train_df`` already includes the
+        # TEMP_TARGET column and represents the true training frame used for
+        # generation. Passing it to ``handle_generated_data`` keeps feature and
+        # target alignment consistent for both conditional and standard LLM
+        # sampling paths.
+        base_train_for_handling = current_train_df if target is not None else train_df
+        return self.handle_generated_data(base_train_for_handling, generated_df, only_generated_data)
+
+    def _generate_via_prompt(self, prompt: str, great_model_instance, device: str, max_tokens_to_generate=50) -> str:
+        """
+        Generate a short text completion from the underlying GReaT LLM.
+
+        Args:
+            prompt (str): Serialized row description used as generation context.
+            great_model_instance: Fitted GReaT instance providing `model` and
+                `tokenizer` attributes.
+            device (str): Target device for inference (for example, ``"cpu"``
+                or ``"cuda"``).
+            max_tokens_to_generate (int): Maximum number of new tokens to
+                sample from the model.
+
+        Returns:
+            str: Post-processed generated text. Returns an empty string if
+                generation fails.
+        """
+        llm_model = great_model_instance.model
+        tokenizer = great_model_instance.tokenizer
+
+        if llm_model is None or tokenizer is None:
+            logging.error("LLM model or tokenizer not available in GReaT instance.")
+            return ""  # Or raise an error
+
+        llm_model.to(device)
+
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                           max_length=tokenizer.model_max_length - max_tokens_to_generate)
+        input_ids = inputs.input_ids.to(device)
+        attention_mask = inputs.attention_mask.to(device)
+
+        try:
+            outputs = llm_model.generate(
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=max_tokens_to_generate,
+                pad_token_id=tokenizer.eos_token_id,
+                do_sample=True,  # Enable sampling for more diverse outputs
+                temperature=0.7,  # Default temperature, can be tuned
+                top_k=50,  # Default top_k, can be tuned
+                top_p=0.95  # Default top_p, can be tuned
+            )
+            generated_text = tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True)
+
+            generated_text = generated_text.split('\n')[0].split('|')[0].strip()
+
+            return generated_text
+
+        except Exception as e:
+            logging.error(f"Error during text generation via prompt: {e}")
+            return ""  # Fallback or re-raise
 
 
 if __name__ == "__main__":
@@ -411,7 +637,7 @@ def generate_data(
         GANGenerator(cat_cols=["A"], gen_x_times=20, only_generated_data=True),
         ForestDiffusionGenerator(cat_cols=["A"], gen_x_times=10, only_generated_data=True),
         ForestDiffusionGenerator(gen_x_times=15, only_generated_data=False,
-                                  gen_params={"batch_size": 500, "patience": 25, "epochs": 500})
+                                 gen_params={"batch_size": 500, "patience": 25, "epochs": 500})
     ]
 
     for gen in generators:
diff --git a/tests/conftest.py b/tests/conftest.py
index ffa0de2..d97df36 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,21 @@
 # -*- coding: utf-8 -*-
 """
-    Dummy conftest.py for tabgan.
+Test configuration for tabgan.
 
-    If you don't know what this is for, just leave it empty.
-    Read more about conftest.py under:
-    - https://docs.pytest.org/en/stable/fixture.html
-    - https://docs.pytest.org/en/stable/writing_plugins.html
+We ensure that the project `src` directory is on ``sys.path`` so that both
+``src.tabgan`` and sibling top-level packages such as ``_ForestDiffusion``
+are importable when running tests from the repository root.
""" -# import pytest +import os +import sys +from pathlib import Path + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +SRC_PATH = PROJECT_ROOT / "src" + +src_str = os.fspath(SRC_PATH) +if src_str not in sys.path: + sys.path.insert(0, src_str) + diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..b4886d9 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,56 @@ +import os +import subprocess +import sys +import tempfile + +import numpy as np +import pandas as pd + + +def test_tabgan_generate_cli_creates_output_with_target(): + # Prepare temporary input and output CSV paths + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, "train.csv") + output_path = os.path.join(tmpdir, "synthetic.csv") + + # Small dummy dataset with a target column + df = pd.DataFrame( + np.random.randint(0, 10, size=(20, 3)), + columns=["A", "B", "target"], + ) + df.to_csv(input_path, index=False) + + # Set up environment with PYTHONPATH pointing to src directory + env = os.environ.copy() + src_path = os.path.join(os.path.dirname(__file__), "..", "src") + env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "") + + # Invoke the installed console script through Python -m to avoid PATH issues in CI + cmd = [ + sys.executable, + "-m", + "tabgan.cli", + "--input-csv", + input_path, + "--target-col", + "target", + "--generator", + "original", + "--gen-x-times", + "1.0", + "--output-csv", + output_path, + ] + + result = subprocess.run(cmd, check=False, capture_output=True, text=True, env=env) + + assert result.returncode == 0, f"CLI failed: {result.stderr}" + assert os.path.exists(output_path), "Output CSV was not created by CLI" + + out_df = pd.read_csv(output_path) + + # Target column must be present + assert "target" in out_df.columns + # At least as many rows as the original (OriginalGenerator samples with replacement) + assert len(out_df) >= len(df) + diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 34ce968..2ee45c7 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -5,10 +5,11 @@ __license__ = "Apache 2.0" from unittest import TestCase +from unittest.mock import patch, MagicMock, call import numpy as np import pandas as pd -from src.tabgan.sampler import OriginalGenerator, Sampler, GANGenerator, ForestDiffusionGenerator, LLMGenerator +from src.tabgan.sampler import OriginalGenerator, Sampler, GANGenerator, ForestDiffusionGenerator, LLMGenerator, SamplerLLM class TestOriginalGenerator(TestCase): @@ -94,6 +95,337 @@ def test_generate_data(self): self.assertTrue(gen_train.shape[0] > new_train.shape[0]) self.assertEqual(np.max(self.target.nunique()), np.max(new_target.nunique())) + +class TestSamplerLLMConditional(TestCase): + def setUp(self): + self.train_df = pd.DataFrame({ + "Name": ["Anna", "Maria", "Ivan", "Sergey"], + "Gender": ["F", "F", "M", "M"], + "Age": [25, 30, 35, 40], + "Occupation": ["Engineer", "Doctor", "Artist", "Teacher"] + }) + self.target_df = pd.DataFrame({"Y": [0, 1, 0, 1]}) # Can be None if not used for LLM + # test_df is used for postprocessing, might not be strictly needed for all LLM tests if postprocessing is off + self.test_df = pd.DataFrame({ + "Name": ["Olga", "Boris", "Svetlana"], + "Gender": ["F", "M", "F"], + "Age": [28, 32, 45], + "Occupation": ["Manager", "Pilot", "Scientist"] + }) + + # Default gen_params for LLMGenerator + self.gen_params = {"batch_size": 32, "epochs": 1, "llm": "distilgpt2", "max_length": 50} + + + @patch.object(SamplerLLM, "_fit_great_model") + 
@patch.object(SamplerLLM, "_generate_via_prompt") + def test_conditional_generation_basic(self, mock_generate_prompt, mock_fit_great): + # --- Mock GReaT setup (via patched _fit_great_model) --- + mock_great_instance = MagicMock() + mock_great_instance.fit.return_value = None + mock_great_instance.model = MagicMock() # mock the underlying llm + mock_great_instance.tokenizer = MagicMock() # mock the tokenizer + + # Configure mock_great_instance.impute + # It should take a DataFrame and fill NaNs in 'Age' and 'Occupation' for this test + def mock_impute_logic(df_to_impute, max_length): + df_imputed = df_to_impute.copy() + if "Age" in df_imputed.columns and pd.isna(df_imputed.loc[0, "Age"]): + df_imputed.loc[0, "Age"] = 33 # Predictable age + if "Occupation" in df_imputed.columns and pd.isna(df_imputed.loc[0, "Occupation"]): + # Based on Gender if available + gender = df_imputed.loc[0, "Gender"] + df_imputed.loc[0, "Occupation"] = "MockOccupationF" if gender == "F" else "MockOccupationM" + return df_imputed + mock_great_instance.impute.side_effect = mock_impute_logic + mock_fit_great.return_value = (mock_great_instance, "cpu") + + # Configure mock_generate_prompt for "Name" + # It needs to return different names based on gender and ensure novelty + # Store original names to check against for novelty + original_names = set(self.train_df["Name"].unique()) + + # Use a dict to track generated names per gender to ensure novelty within test + generated_names_by_gender = {"F": [], "M": []} + + def mock_prompt_logic(prompt_text, great_model_inst, device): + # Simplified logic: just check for gender in prompt + if "Gender: F" in prompt_text: + candidate = "Laura" + if candidate in original_names or candidate in generated_names_by_gender["F"]: + candidate = "Sophia" # Next novel female name + generated_names_by_gender["F"].append(candidate) + return candidate + elif "Gender: M" in prompt_text: + candidate = "Peter" + if candidate in original_names or candidate in generated_names_by_gender["M"]: + candidate = "David" # Next novel male name + generated_names_by_gender["M"].append(candidate) + return candidate + return "UnknownName" + mock_generate_prompt.side_effect = mock_prompt_logic + + # --- LLMGenerator setup --- + llm_generator = LLMGenerator( + gen_x_times=0.5, # Generate 2 new samples (0.5 * 4 original) + text_generating_columns=["Name"], + conditional_columns=["Gender"], + gen_params=self.gen_params, + # Disable post_process and adversarial for simpler focused test + is_post_process=False + ) + # --- Run generation --- + # For this test, target can be None as LLMGenerator handles it internally if provided + # and we are mostly concerned with feature generation. 
+ # test_df is also not strictly necessary if is_post_process=False + new_train_df, _ = llm_generator.generate_data_pipe( + self.train_df.copy(), + target=None, # Or self.target_df.copy() if testing with target + test_df=None, # test_df not needed when postprocessing is disabled + only_generated_data=True, # Focus on generated samples + ) + + # --- Assertions --- + self.assertEqual(len(new_train_df), 2) # 0.5 * 4 samples + + # Check that _generate_via_prompt was called for each new sample + self.assertEqual(mock_generate_prompt.call_count, 2) + + # Check generated names and their novelty + for index, row in new_train_df.iterrows(): + name = row["Name"] + gender = row["Gender"] + self.assertNotIn(name, original_names) + if gender == "F": + self.assertIn(name, ["Laura", "Sophia"]) + elif gender == "M": + self.assertIn(name, ["Peter", "David"]) + + # Check imputed values + self.assertEqual(row["Age"], 33) + expected_occupation = "MockOccupationF" if gender == "F" else "MockOccupationM" + self.assertEqual(row["Occupation"], expected_occupation) + + # Check gender distribution (simple check for this small sample size) + # Ensure that the generated names align with their conditioned gender from the input. + # The actual distribution preservation is statistical over many samples. + # Here, we mainly check if the conditioning worked for each sample. + generated_F = new_train_df[new_train_df["Gender"] == "F"] + generated_M = new_train_df[new_train_df["Gender"] == "M"] + + # Depending on how attributes are sampled, we might get 1F/1M or 2F/0M or 0F/2M for 2 samples. + # The mock_prompt_logic ensures Name matches Gender. + # The attribute_distributions sampling in generate_data should pick F/M with 0.5 prob each. + + for _, row in generated_F.iterrows(): + self.assertIn(row["Name"], ["Laura", "Sophia"]) + for _, row in generated_M.iterrows(): + self.assertIn(row["Name"], ["Peter", "David"]) + + @patch.object(SamplerLLM, "_fit_great_model") + @patch.object(SamplerLLM, "_generate_via_prompt") + def test_llm_generator_fallback_behavior(self, mock_generate_prompt, mock_fit_great): + # --- Mock GReaT setup for standard sampling (via patched _fit_great_model) --- + mock_great_instance = MagicMock() + mock_great_instance.fit.return_value = None + + # Expected columns for the dummy generated data by great_model_instance.sample + # This should match self.train_df columns + self.target_df column if target is used + # For this test, assuming target is None for simplicity in LLMGenerator call + sample_columns = self.train_df.columns.tolist() + + # Create dummy data that model.sample() would return + dummy_sampled_data = pd.DataFrame([ + ["SampledName1", "F", 50, "SampledOccupation1"], + ["SampledName2", "M", 55, "SampledOccupation2"] + ], columns=sample_columns) + mock_great_instance.sample.return_value = dummy_sampled_data + mock_fit_great.return_value = (mock_great_instance, "cpu") + + # --- LLMGenerator setup (no text_generating_columns) --- + llm_generator = LLMGenerator( + gen_x_times=0.5, # Generate 2 samples + gen_params=self.gen_params, + is_post_process=False + ) + # --- Run generation --- + new_train_df, _ = llm_generator.generate_data_pipe( + self.train_df.copy(), + target=None, + test_df=None, + only_generated_data=True, + ) + + # --- Assertions --- + self.assertEqual(len(new_train_df), 2) + mock_great_instance.sample.assert_called_once() + mock_generate_prompt.assert_not_called() + + # Check if the output matches the dummy_sampled_data + 
pd.testing.assert_frame_equal(new_train_df.reset_index(drop=True), dummy_sampled_data.reset_index(drop=True)) + +class TestSamplerLLMWithTarget(TestCase): + def setUp(self): + self.train_df = pd.DataFrame({ + "Name": ["Anna", "Maria", "Ivan", "Sergey"], + "Gender": ["F", "F", "M", "M"], + "Age": [25, 30, 35, 40], + "Occupation": ["Engineer", "Doctor", "Artist", "Teacher"], + }) + self.target_df = pd.DataFrame({"Y": [0, 1, 0, 1]}) + self.gen_params = {"batch_size": 32, "epochs": 3, "llm": "distilgpt2", "max_length": 50} + + @patch.object(SamplerLLM, "_fit_great_model") + @patch.object(SamplerLLM, "_generate_via_prompt") + def test_conditional_generation_with_target(self, mock_generate_prompt, mock_fit_great): + # Configure mocked GReaT instance + mock_great_instance = MagicMock() + + def mock_impute_logic(df_to_impute, max_length): + df_imputed = df_to_impute.copy() + # Ensure TEMP_TARGET is imputed so that generated targets are not all NaN + if "TEMP_TARGET" in df_imputed.columns and pd.isna(df_imputed.loc[0, "TEMP_TARGET"]): + df_imputed.loc[0, "TEMP_TARGET"] = 1 + # Fill numeric/text fields to avoid NaNs in generated features + for col in ["Age", "Occupation"]: + if col in df_imputed.columns and pd.isna(df_imputed.loc[0, col]): + df_imputed.loc[0, col] = {"Age": 33, "Occupation": "MockOccupation"}.get(col) + return df_imputed + + mock_great_instance.impute.side_effect = mock_impute_logic + mock_fit_great.return_value = (mock_great_instance, "cpu") + + # Simple prompt generation just to avoid calling the real model + mock_generate_prompt.return_value = "GeneratedName" + + llm_generator = LLMGenerator( + gen_x_times=0.5, + text_generating_columns=["Name"], + conditional_columns=["Gender"], + gen_params=self.gen_params, + is_post_process=False, + ) + llm_sampler = llm_generator.get_object_generator() + + # Call SamplerLLM.generate_data directly with a non-None target + new_train_df, new_target = llm_sampler.generate_data( + self.train_df.copy(), + self.target_df.copy(), + test_df=None, + only_generated_data=True, + ) + + # We expect 0.5 * 4 = 2 generated rows and aligned target + self.assertEqual(len(new_train_df), 2) + self.assertIsNotNone(new_target) + self.assertEqual(len(new_target), 2) + + # TEMP_TARGET should be present in the frame passed to _fit_great_model + passed_train_df = mock_fit_great.call_args[0][0] + self.assertIn("TEMP_TARGET", passed_train_df.columns) + # Original target values should be copied into TEMP_TARGET for training + self.assertTrue((passed_train_df["TEMP_TARGET"].reset_index(drop=True) == self.target_df["Y"]).all()) + + @patch.object(SamplerLLM, "_fit_great_model") + @patch.object(SamplerLLM, "_generate_via_prompt") + def test_novelty_retry_logic(self, mock_generate_prompt, mock_fit_great): + # Train data with a single unique name that should be treated as "non-novel" + train_df = pd.DataFrame({ + "Name": ["Anna", "Anna"], + "Gender": ["F", "F"], + "Age": [25, 30], + "Occupation": ["Engineer", "Doctor"], + }) + + mock_great_instance = MagicMock() + + def mock_impute_logic(df_to_impute, max_length): + # Just return the same frame; we only care about the text column here + return df_to_impute.fillna({"Age": 33, "Occupation": "MockOccupation"}) + + mock_great_instance.impute.side_effect = mock_impute_logic + mock_fit_great.return_value = (mock_great_instance, "cpu") + + # First call returns a non-novel name (present in original data), + # second call returns a novel one to exercise retry logic. 
+ mock_generate_prompt.side_effect = ["Anna", "NewAnna"] + + llm_generator = LLMGenerator( + gen_x_times=0.5, # 1 new sample from 2 original rows + text_generating_columns=["Name"], + conditional_columns=["Gender"], + gen_params=self.gen_params, + is_post_process=False, + ) + llm_sampler = llm_generator.get_object_generator() + + new_train_df, _ = llm_sampler.generate_data( + train_df.copy(), + target=None, + test_df=None, + only_generated_data=True, + ) + + # Retry logic should cause at least two calls to _generate_via_prompt + self.assertGreaterEqual(mock_generate_prompt.call_count, 2) + + # The finally stored name must be the novel one, not the original "Anna" + self.assertEqual(len(new_train_df), 1) + self.assertEqual(new_train_df.iloc[0]["Name"], "NewAnna") + + def test_empty_text_or_conditional_columns_use_fallback_sampling(self): + llm_generator = LLMGenerator( + gen_x_times=0.5, + text_generating_columns=["Name"], + conditional_columns=["Gender"], + gen_params=self.gen_params, + is_post_process=False, + ) + llm_sampler = llm_generator.get_object_generator() + + dummy_generated = pd.DataFrame( + [ + ["SampledName1", "F", 50, "SampledOccupation1"], + ["SampledName2", "M", 55, "SampledOccupation2"], + ], + columns=self.train_df.columns, + ) + + # Case 1: text_generating_columns cleared to empty list -> fallback to standard sampling + llm_sampler.text_generating_columns = [] + with patch.object(SamplerLLM, "_fit_great_model", return_value=(MagicMock(), "cpu")) as mock_fit, \ + patch.object(SamplerLLM, "_conditional_text_generation") as mock_conditional, \ + patch.object(SamplerLLM, "_standard_llm_sampling", return_value=dummy_generated) as mock_standard: + new_train_df, _ = llm_sampler.generate_data( + self.train_df.copy(), + target=None, + test_df=None, + only_generated_data=True, + ) + + mock_fit.assert_called_once() + mock_standard.assert_called_once() + mock_conditional.assert_not_called() + pd.testing.assert_frame_equal(new_train_df.reset_index(drop=True), dummy_generated.reset_index(drop=True)) + + # Case 2: conditional_columns cleared to None -> fallback to standard sampling again + llm_sampler = llm_generator.get_object_generator() + llm_sampler.conditional_columns = None + with patch.object(SamplerLLM, "_fit_great_model", return_value=(MagicMock(), "cpu")) as mock_fit, \ + patch.object(SamplerLLM, "_conditional_text_generation") as mock_conditional, \ + patch.object(SamplerLLM, "_standard_llm_sampling", return_value=dummy_generated) as mock_standard: + new_train_df, _ = llm_sampler.generate_data( + self.train_df.copy(), + target=None, + test_df=None, + only_generated_data=True, + ) + + mock_fit.assert_called_once() + mock_standard.assert_called_once() + mock_conditional.assert_not_called() + pd.testing.assert_frame_equal(new_train_df.reset_index(drop=True), dummy_generated.reset_index(drop=True)) + class TestSamplerSamplerDiffusion(TestCase): def setUp(self): self.train = pd.DataFrame(np.random.randint(-10, 150, size=(50, 4)), columns=list('ABCD'))