Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,4 @@ dmypy.json

.DS_Store
/src/tabgan/trainer_great/
/tests/trainer_great/
331 changes: 196 additions & 135 deletions README.md

Large diffs are not rendered by default.

Binary file modified images/workflow.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 5 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
scipy>=1.4.1
scipy>=1.12.0
category_encoders>=2.6.3
numpy>=1.22.0
torch>=1.6.0
pandas>=1.2.2
pandas>=2.2.0
lightgbm>=2.2.3
scikit_learn>=1.5.2
torchvision>=0.4.2
numpy>=2.0
python-dateutil==2.8.1
numpy>=1.23.1
python-dateutil>=2.8.2
tqdm>=4.61.1
xgboost>=2.0.0
be-great==0.0.8
be-great==0.0.13
13 changes: 3 additions & 10 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ install_requires =
python-dateutil
tqdm
xgboost
be-great
be-great>=0.0.13

[options.packages.find]
where = src
Expand All @@ -66,15 +66,8 @@ testing =
pytest-cov

[options.entry_points]
# Add here console scripts like:
# console_scripts =
# script_name = tabgan.module:function
# For example:
# console_scripts =
# fibonacci = tabgan.skeleton:run
# And any other entry points, for example:
# pyscaffold.cli =
# awesome = pyscaffoldext.awesome.extension:AwesomeExtension
console_scripts =
tabgan-generate = tabgan.cli:main

[test]
# py.test options when running `python setup.py test`
Expand Down
30 changes: 29 additions & 1 deletion src/tabgan/abc_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,35 @@ def generate_data_pipe(


class Sampler(ABC):
"""Interface for each sampling strategy"""
"""Interface for each sampling strategy.

Concrete sampler implementations share a common configuration interface
(generation factor, categorical columns, post-processing flags, etc.).
This base ``__init__`` stores those shared parameters so that subclasses
can call ``super().__init__(...)`` and focus on strategy-specific logic.
"""

def __init__(
self,
gen_x_times: float,
cat_cols: list | None,
bot_filter_quantile: float,
top_filter_quantile: float,
is_post_process: bool,
adversarial_model_params: dict,
pregeneration_frac: float,
only_generated_data: bool,
gen_params: dict | None = None,
) -> None:
self.gen_x_times = gen_x_times
self.cat_cols = cat_cols
self.bot_filter_quantile = bot_filter_quantile
self.top_filter_quantile = top_filter_quantile
self.is_post_process = is_post_process
self.adversarial_model_params = adversarial_model_params
self.pregeneration_frac = pregeneration_frac
self.only_generated_data = only_generated_data
self.gen_params = gen_params or {}

def get_generated_shape(self, input_df):
"""Calculates final output shape"""
Expand Down
129 changes: 129 additions & 0 deletions src/tabgan/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import argparse
import logging
from typing import List, Optional

import pandas as pd

from tabgan.sampler import (
OriginalGenerator,
GANGenerator,
ForestDiffusionGenerator,
LLMGenerator,
)


def _parse_cat_cols(raw: Optional[str]) -> Optional[List[str]]:
    """Split a comma-separated column list into stripped names.

    Returns ``None`` for ``None``/empty input (meaning "no categorical
    columns"); blank segments such as those produced by trailing commas
    are dropped, so e.g. ``"a, b,"`` yields ``["a", "b"]``.
    """
    if not raw:
        return None
    names: List[str] = []
    for piece in raw.split(","):
        name = piece.strip()
        if name:
            names.append(name)
    return names


def main() -> None:
    """
    Command-line interface for generating synthetic tabular data with tabgan.

    Reads a training CSV, runs the selected sampler, and writes the
    resulting dataset (original + synthetic rows, or synthetic only with
    ``--only-generated``) to the output CSV.

    Raises:
        ValueError: if the input CSV is empty, the target column is missing,
            or a name in ``--cat-cols`` does not exist in the data.

    Example:
        tabgan-generate \\
            --input-csv train.csv \\
            --target-col target \\
            --generator gan \\
            --gen-x-times 1.5 \\
            --cat-cols year,gender \\
            --output-csv synthetic_train.csv
    """
    parser = argparse.ArgumentParser(
        description="Generate synthetic tabular data using tabgan samplers."
    )
    parser.add_argument(
        "--input-csv",
        required=True,
        help="Path to input CSV file containing training data (with or without target column).",
    )
    parser.add_argument(
        "--target-col",
        default=None,
        help="Name of the target column in the CSV (optional).",
    )
    parser.add_argument(
        "--output-csv",
        required=True,
        help="Path to write the generated synthetic dataset as CSV.",
    )
    parser.add_argument(
        "--generator",
        choices=["original", "gan", "diffusion", "llm"],
        default="gan",
        help="Which sampler to use for generation.",
    )
    parser.add_argument(
        "--gen-x-times",
        type=float,
        default=1.1,
        help="Factor controlling how many synthetic samples to generate relative to the training size.",
    )
    parser.add_argument(
        "--cat-cols",
        default=None,
        help="Comma-separated list of categorical column names (e.g. 'year,gender').",
    )
    parser.add_argument(
        "--only-generated",
        action="store_true",
        help="If set, output only synthetic rows instead of original + synthetic.",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    logging.info("Reading input CSV from %s", args.input_csv)
    df = pd.read_csv(args.input_csv)
    # Fail fast on an empty file instead of letting the sampler error out
    # with an opaque message deep inside generation.
    if df.empty:
        raise ValueError(f"Input CSV '{args.input_csv}' contains no rows.")

    target_df = None
    train_df = df
    if args.target_col is not None:
        if args.target_col not in df.columns:
            raise ValueError(f"Target column '{args.target_col}' not found in input CSV.")
        target_df = df[[args.target_col]]
        train_df = df.drop(columns=[args.target_col])

    cat_cols = _parse_cat_cols(args.cat_cols)
    # Validate categorical column names up front so a typo surfaces as a
    # clear error rather than a confusing failure inside the sampler.
    if cat_cols:
        missing = [c for c in cat_cols if c not in train_df.columns]
        if missing:
            raise ValueError(f"Categorical columns not found in input CSV: {missing}")

    generator_map = {
        "original": OriginalGenerator,
        "gan": GANGenerator,
        "diffusion": ForestDiffusionGenerator,
        "llm": LLMGenerator,
    }
    generator_cls = generator_map[args.generator]

    logging.info("Initializing %s generator", generator_cls.__name__)
    generator = generator_cls(
        gen_x_times=args.gen_x_times,
        cat_cols=cat_cols,
        only_generated_data=bool(args.only_generated),
    )

    # Use train_df itself as test_df when a dedicated hold-out set is not provided.
    logging.info("Generating synthetic data...")
    new_train, new_target = generator.generate_data_pipe(
        train_df, target_df, train_df
    )

    if new_target is not None and args.target_col is not None:
        out_df = new_train.copy()
        # Assign the target as a flat positional array.  Assigning a pandas
        # Series directly would align on index, and the generated frames may
        # carry indices that differ from ``out_df`` — which would silently
        # fill the column with NaNs.  ``to_numpy().ravel()`` handles both
        # Series and single-column DataFrame returns.
        if hasattr(new_target, "to_numpy"):
            out_df[args.target_col] = new_target.to_numpy().ravel()
        else:
            out_df[args.target_col] = new_target
    else:
        out_df = new_train

    logging.info("Writing synthetic data to %s", args.output_csv)
    out_df.to_csv(args.output_csv, index=False)


# Allow running the module directly (``python -m tabgan.cli``) in addition
# to the installed console-script entry point.
if __name__ == "__main__":
    main()

Loading
Loading