Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 167 additions & 170 deletions README.md

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ platforms = any
classifiers =
Development Status :: 5 - Production/Stable
Programming Language :: Python
Programming Language :: Python :: 3
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3.13

[options]
zip_safe = False
Expand All @@ -35,7 +41,7 @@ setup_requires =
# tests_require = pytest; pytest-cov
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
# python_requires = >= 3.4
python_requires = >= 3.5
python_requires = >= 3.9
install_requires =
pandas
numpy>=2.0
Expand All @@ -48,6 +54,8 @@ install_requires =
tqdm
xgboost
be-great>=0.0.13
matplotlib>=3.5
requests

[options.packages.find]
where = src
Expand Down
30 changes: 23 additions & 7 deletions src/tabgan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
# -*- coding: utf-8 -*-
from pkg_resources import DistributionNotFound, get_distribution
from importlib.metadata import version, PackageNotFoundError
from .sampler import OriginalGenerator, Sampler, GANGenerator, ForestDiffusionGenerator, LLMGenerator
from .llm_config import LLMAPIConfig
from .llm_api_client import LLMAPIClient
from .constraints import (
Constraint,
RangeConstraint,
UniqueConstraint,
FormulaConstraint,
RegexConstraint,
ConstraintEngine,
)
from .privacy_metrics import PrivacyMetrics
from .quality_report import QualityReport
from .sklearn_transformer import TabGANTransformer

__all__ = [
"OriginalGenerator",
Expand All @@ -12,13 +23,18 @@
"LLMGenerator",
"LLMAPIConfig",
"LLMAPIClient",
"Constraint",
"RangeConstraint",
"UniqueConstraint",
"FormulaConstraint",
"RegexConstraint",
"ConstraintEngine",
"PrivacyMetrics",
"QualityReport",
"TabGANTransformer",
]

try:
# Change here if project is renamed and does not equal the package name
dist_name = __name__
__version__ = get_distribution(dist_name).version
except DistributionNotFound:
__version__ = version(__name__)
except PackageNotFoundError:
__version__ = "unknown"
finally:
del get_distribution, DistributionNotFound
19 changes: 17 additions & 2 deletions src/tabgan/abc_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gc
import logging
from abc import ABC, abstractmethod
from typing import Tuple
from typing import List, Optional, Tuple
from .utils import seed_everything
import pandas as pd

Expand Down Expand Up @@ -30,6 +30,7 @@ def generate_data_pipe(
only_adversarial: bool = False,
use_adversarial: bool = True,
only_generated_data: bool = False,
constraints: Optional[List] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Defines logic for sampling
Expand All @@ -41,6 +42,7 @@ def generate_data_pipe(
@param use_adversarial: perform or not adversarial filtering
@param only_generated_data: After generation get only newly generated, without concating input train dataframe.
Only works for SamplerGAN or ForestDiffusionGenerator.
@param constraints: Optional list of Constraint instances to enforce on generated data.
@return: Newly generated train dataframe and test data
"""
seed_everything()
Expand All @@ -55,7 +57,7 @@ def generate_data_pipe(
train_df.copy(), target.copy(), test_df
)
else:
logging.info("Preprocessing input data with deep copying input data.")
logging.info("Preprocessing input data without deep copying.")
new_train, new_target, test_df = generator.preprocess_data(
train_df, target, test_df
)
Expand All @@ -76,6 +78,19 @@ def generate_data_pipe(
new_train, new_target = generator.adversarial_filtering(
new_train, new_target, test_df
)
if constraints:
from .constraints import ConstraintEngine
logging.info("Applying constraints")
engine = ConstraintEngine(constraints, strategy="fix")
# Temporarily attach target to keep rows aligned
target_col = "__constraint_target__"
if new_target is not None:
new_train[target_col] = new_target.values if hasattr(new_target, 'values') else new_target
new_train = engine.apply(new_train)
if new_target is not None:
new_target = new_train[target_col].reset_index(drop=True)
new_train = new_train.drop(columns=[target_col]).reset_index(drop=True)

gc.collect()

logging.info("Total finishing, returning data")
Expand Down
157 changes: 157 additions & 0 deletions src/tabgan/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
"""
Constraint system for enforcing business rules on generated data.

Constraints are applied as a post-generation step — after the main
generation pipeline produces synthetic rows, the ConstraintEngine filters
or repairs rows that violate the declared rules.
"""

import logging
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import pandas as pd

__all__ = [
"Constraint",
"RangeConstraint",
"UniqueConstraint",
"FormulaConstraint",
"RegexConstraint",
"ConstraintEngine",
]


class Constraint(ABC):
    """Abstract interface shared by all data constraints.

    A concrete constraint supplies a row-wise validity check
    (``is_satisfied``) and a repair step (``fix``).
    """

    @abstractmethod
    def is_satisfied(self, df: pd.DataFrame) -> pd.Series:
        """Return a boolean Series that is True for every compliant row of ``df``."""
        raise NotImplementedError

    @abstractmethod
    def fix(self, df: pd.DataFrame) -> pd.DataFrame:
        """Repair the violating rows of ``df`` and return the resulting DataFrame."""
        raise NotImplementedError

    def __repr__(self) -> str:
        # Use the runtime class so subclasses without their own __repr__
        # are still shown under their own name.
        cls_name = type(self).__name__
        return cls_name + "()"


class RangeConstraint(Constraint):
    """Enforce numeric column values within [min_val, max_val].

    Either bound may be omitted to get a one-sided range, but at least one
    must be given. Both bounds are inclusive.

    Args:
        column: Name of the column to check.
        min_val: Inclusive lower bound, or ``None`` for no lower bound.
        max_val: Inclusive upper bound, or ``None`` for no upper bound.

    Raises:
        ValueError: If neither bound is given, or if ``min_val`` is greater
            than ``max_val`` (no value could ever satisfy such a range).
    """

    def __init__(
        self,
        column: str,
        min_val: Optional[float] = None,
        max_val: Optional[float] = None,
    ):
        if min_val is None and max_val is None:
            raise ValueError("At least one of min_val or max_val must be specified")
        # An inverted range can never be satisfied and makes clipping
        # ill-defined, so reject it up front instead of silently dropping
        # every row later.
        if min_val is not None and max_val is not None and min_val > max_val:
            raise ValueError(
                f"min_val ({min_val!r}) must not be greater than max_val ({max_val!r})"
            )
        self.column = column
        self.min_val = min_val
        self.max_val = max_val

    def is_satisfied(self, df: pd.DataFrame) -> pd.Series:
        """Return True for rows whose value lies inside the configured range.

        Missing values compare as False against either bound, so they are
        reported as violations.
        """
        col = df[self.column]
        mask = pd.Series(True, index=df.index)
        if self.min_val is not None:
            mask &= col >= self.min_val
        if self.max_val is not None:
            mask &= col <= self.max_val
        return mask

    def fix(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with the column clipped into the range.

        The input frame is left untouched; a ``None`` bound leaves that side
        unclipped.
        """
        df = df.copy()
        df[self.column] = df[self.column].clip(lower=self.min_val, upper=self.max_val)
        return df

    def __repr__(self) -> str:
        return f"RangeConstraint(column={self.column!r}, min={self.min_val}, max={self.max_val})"


class UniqueConstraint(Constraint):
    """Require every value of one column to appear at most once.

    The first occurrence of each value is kept; later rows carrying the
    same value count as violations and are dropped by ``fix``.
    """

    def __init__(self, column: str):
        self.column = column

    def is_satisfied(self, df: pd.DataFrame) -> pd.Series:
        """Return True for the first row of each distinct column value."""
        repeated = df[self.column].duplicated(keep="first")
        return ~repeated

    def fix(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return ``df`` without repeated values, renumbered from zero."""
        deduplicated = df.drop_duplicates(subset=[self.column], keep="first")
        return deduplicated.reset_index(drop=True)

    def __repr__(self) -> str:
        return "UniqueConstraint(column={!r})".format(self.column)


class FormulaConstraint(Constraint):
    """Enforce a boolean expression evaluated via ``pd.DataFrame.eval``.

    Example expressions:
    - ``"end_date > start_date"``
    - ``"price * quantity == total"``
    - ``"age >= 0"``

    NOTE(security): the expression is handed to ``DataFrame.eval`` as-is.
    Only use expressions written by the developer; do not build them from
    untrusted input.
    """

    def __init__(self, expression: str):
        self.expression = expression

    def is_satisfied(self, df: pd.DataFrame) -> pd.Series:
        """Evaluate the expression and return a plain boolean row mask.

        Rows for which the expression yields a missing value are treated as
        violations. A scalar result is broadcast to every row.

        Raises:
            ValueError: If the expression does not evaluate to one boolean
                value per row (e.g. an arithmetic expression or an
                assignment), since such a result cannot be used as a filter.
        """
        result = df.eval(self.expression)
        if isinstance(result, pd.DataFrame):
            # Assignment-style expressions ("c = a + b") return a whole frame.
            raise ValueError(
                f"Expression {self.expression!r} must be a boolean condition, not an assignment"
            )
        if not isinstance(result, pd.Series):
            # Scalars (e.g. "1 < 2") and bare arrays are aligned to df's rows.
            result = pd.Series(result, index=df.index)

        is_boolean = pd.api.types.is_bool_dtype(result.dtype) or (
            result.dtype == object
            and pd.api.types.infer_dtype(result, skipna=True) == "boolean"
        )
        if not is_boolean:
            # Without this check a numeric result would be interpreted as
            # column labels by ``df[mask]`` instead of as a row filter.
            raise ValueError(
                f"Expression {self.expression!r} must evaluate to booleans, "
                f"got dtype {result.dtype}"
            )
        # Missing outcomes count as "not satisfied".
        return result.where(result.notna(), False).astype(bool)

    def fix(self, df: pd.DataFrame) -> pd.DataFrame:
        """Drop the rows that violate the expression and renumber the index.

        A formula cannot be repaired automatically, so fixing is filtering.
        """
        mask = self.is_satisfied(df)
        return df[mask].reset_index(drop=True)

    def __repr__(self) -> str:
        return f"FormulaConstraint({self.expression!r})"


class RegexConstraint(Constraint):
    """Enforce that string values in a column match a regular expression.

    The whole value must match (``fullmatch`` semantics). Missing values
    never satisfy the constraint.

    Args:
        column: Name of the column to check.
        pattern: Regular expression the full value has to match.

    Raises:
        re.error: If ``pattern`` is not a valid regular expression.
    """

    def __init__(self, column: str, pattern: str):
        self.column = column
        self.pattern = pattern
        # Compiling eagerly makes an invalid pattern fail at construction
        # time rather than in the middle of a generation run.
        self._compiled = re.compile(pattern)

    def is_satisfied(self, df: pd.DataFrame) -> pd.Series:
        """Return True for rows whose value fully matches the pattern."""
        values = df[self.column]
        # ``astype(str)`` renders missing values as text such as "nan" or
        # "None", which could accidentally match the pattern. Record which
        # cells are really present before converting and mask the rest out.
        present = values.notna()
        matched = values.astype(str).str.fullmatch(self.pattern, na=False)
        return matched.astype(bool) & present

    def fix(self, df: pd.DataFrame) -> pd.DataFrame:
        """Drop the rows whose value does not match and renumber the index.

        Non-matching text cannot be repaired automatically, so fixing is
        filtering.
        """
        mask = self.is_satisfied(df)
        return df[mask].reset_index(drop=True)

    def __repr__(self) -> str:
        return f"RegexConstraint(column={self.column!r}, pattern={self.pattern!r})"


class ConstraintEngine:
    """Apply a list of constraints to a DataFrame.

    Args:
        constraints: ``Constraint`` instances to enforce, in order. Any
            iterable is accepted; it is copied into a list so that later
            changes to the caller's collection (or exhaustion of a one-shot
            iterator) do not affect the engine. ``None`` means no
            constraints.
        strategy: ``"filter"`` drops violating rows; ``"fix"`` attempts
            repair first, then filters remaining violations.

    Raises:
        ValueError: If ``strategy`` is neither ``"filter"`` nor ``"fix"``.
    """

    def __init__(self, constraints: List["Constraint"], strategy: str = "filter"):
        if strategy not in ("filter", "fix"):
            raise ValueError(f"strategy must be 'filter' or 'fix', got {strategy!r}")
        # Snapshot the constraints: the engine must behave the same on every
        # ``apply`` call regardless of what the caller does with its input.
        self.constraints = [] if constraints is None else list(constraints)
        self.strategy = strategy

    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        """Run every constraint over ``df`` and return the surviving rows.

        Constraints are applied sequentially, each one seeing the output of
        the previous one. The returned frame has a fresh zero-based index
        whenever at least one constraint ran; with no constraints ``df`` is
        returned unchanged.
        """
        initial_len = len(df)
        repair_first = self.strategy == "fix"
        for constraint in self.constraints:
            if repair_first:
                df = constraint.fix(df)
            # After fix (or directly if filter), drop remaining violations.
            mask = constraint.is_satisfied(df)
            df = df[mask].reset_index(drop=True)

        dropped = initial_len - len(df)
        if dropped > 0:
            # dropped > 0 implies initial_len > 0, so the ratio is safe.
            logging.info(
                "ConstraintEngine: dropped %d rows (%s)",
                dropped,
                f"{dropped / initial_len:.1%}",
            )
        return df
2 changes: 1 addition & 1 deletion src/tabgan/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __init__(self, cols):
self.cols = cols
self.counts_dict = None

def fit(self, X: pd.DataFrame):
def fit(self, X: pd.DataFrame, y=None):
counts_dict = {}
for col in self.cols:
values, counts = np.unique(X[col], return_counts=True)
Expand Down
Loading
Loading