From 915ae452f32ebee2d39b4261e9e279544046e9f8 Mon Sep 17 00:00:00 2001 From: Ryan McKenna Date: Thu, 18 Jun 2026 10:19:41 -0700 Subject: [PATCH] =?UTF-8?q?Deprecate=20data=5Fgeneration=5Fv2;=20rename=20?= =?UTF-8?q?DataGenerationV3=20=E2=86=92=20TabularSynthesizer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace data_generation_v2 with the new TabularSynthesizer as the primary in-memory entry point: - Rename DataGenerationV3 class to TabularSynthesizer (backward-compat alias kept). - __init__.py now exports TabularSynthesizer instead of generate (from v2). - pydantic_api.py migrated to use TabularSynthesizer. Categorical possible_values are now stringified for v3 compatibility (enums → .value, bools/None → str). - bin/main.py migrated from dpsynth.generate() to dpsynth.TabularSynthesizer. - BUILD deps updated for __init__ and pydantic_api targets. - README.md updated to reference TabularSynthesizer as the recommended API. - Examples and colab notebooks updated to use the v3 class-based API. - data_generation_v2.generate() now emits a DeprecationWarning. PiperOrigin-RevId: 934424293 --- README.md | 11 +- dpsynth/__init__.py | 2 +- dpsynth/bin/main.py | 12 +- dpsynth/data_generation_v2.py | 12 +- dpsynth/data_generation_v3.py | 18 +- ...led_example_theory_and_in_memory_api.ipynb | 600 ++++++++++-------- dpsynth/examples/quickstart.ipynb | 105 +-- dpsynth/pydantic_api.py | 90 ++- tests/data_generation_v3_test.py | 20 +- tests/pydantic_api_test.py | 14 +- 10 files changed, 519 insertions(+), 365 deletions(-) diff --git a/README.md b/README.md index a43a8eb..1d2c519 100644 --- a/README.md +++ b/README.md @@ -35,8 +35,8 @@ inference), they were developed independently and have different trade-offs: ### 1. 
In-Memory (Local) Mode -**Entry point:** [`dpsynth.generate()`](dpsynth/__init__.py) (backed by -[`data_generation_v2.py`](dpsynth/data_generation_v2.py)) +**Entry point:** [`dpsynth.TabularSynthesizer`](dpsynth/__init__.py) (backed by +[`data_generation_v3.py`](dpsynth/data_generation_v3.py)) Designed for **datasets that fit in memory** (e.g., Pandas DataFrames). We have tested this on datasets up to ~100M rows, though performance will depend on the @@ -127,8 +127,9 @@ These modules are used by both the in-memory and pipeline code paths: * **[`discrete_mechanisms/`](dpsynth/discrete_mechanisms/README.md)**: Local, single-machine DP mechanisms (AIM, MST, etc.) and shared mathematical utilities like domain compression. -* **[`data_generation_v2.py`](dpsynth/data_generation_v2.py)**: The end-to-end - in-memory generation pipeline. This is what `dpsynth.generate()` calls. +* **[`data_generation_v3.py`](dpsynth/data_generation_v3.py)**: The + end-to-end in-memory generation pipeline. This is what + `dpsynth.TabularSynthesizer` exposes. * **[`local_mode/`](dpsynth/local_mode/)**: Locally-optimized DP primitives for quantiles and partition selection (NumPy/SciPy-based). 
* **[`pydantic_api.py`](dpsynth/pydantic_api.py)**: API for synthesizing @@ -166,7 +167,7 @@ These modules are used by both the in-memory and pipeline code paths: | Scenario | Recommended | |---|---| -| Fits in memory, Pandas workflow | **In-Memory** (`dpsynth.generate`) | +| Fits in memory, Pandas workflow | **In-Memory** (`dpsynth.TabularSynthesizer`) | | Discrete data, precomputed marginals | **In-Memory** (`discrete_mechanisms`) | | Large-scale, distributed processing | **Pipeline** (`data_generation`) | | Marginals from an external system | **Post-Processing** | diff --git a/dpsynth/__init__.py b/dpsynth/__init__.py index 6cc9ca5..976da41 100644 --- a/dpsynth/__init__.py +++ b/dpsynth/__init__.py @@ -18,6 +18,6 @@ __version__ = '0.1.0' from dpsynth import discrete_mechanisms from dpsynth import domain -from dpsynth.data_generation_v2 import generate +from dpsynth.data_generation_v3 import TabularSynthesizer from dpsynth.domain import CategoricalAttribute from dpsynth.domain import NumericalAttribute diff --git a/dpsynth/bin/main.py b/dpsynth/bin/main.py index b3c578c..4cc6ca1 100644 --- a/dpsynth/bin/main.py +++ b/dpsynth/bin/main.py @@ -105,13 +105,11 @@ def main(_): case _: raise ValueError(f'Unknown mechanism: {_MECHANISM.value}') - synthetic_df = dpsynth.generate( - df, - attribute_domains, - epsilon=_EPSILON.value, - delta=_DELTA.value, - discrete_config=mechanism_config, - ) + mechanism = dpsynth.TabularSynthesizer( + domains=attribute_domains, + discrete_mechanism=mechanism_config, + ).calibrate(epsilon=_EPSILON.value, delta=_DELTA.value) + synthetic_df = mechanism(np.random.default_rng(_SEED.value), df) synthetic_df.to_csv(_OUTPUT_PATH.value, index=False) diff --git a/dpsynth/data_generation_v2.py b/dpsynth/data_generation_v2.py index 7d83c9e..f849e47 100644 --- a/dpsynth/data_generation_v2.py +++ b/dpsynth/data_generation_v2.py @@ -14,11 +14,15 @@ """Implementation of an end-to-end DP synthetic data generation mechanism. 
-In this module there is implementation to run locally. +.. deprecated:: + This module is deprecated. Use + :class:`dpsynth.data_generation_v3.TabularSynthesizer` + instead. """ from collections.abc import Mapping, Sequence from typing import TypeAlias +import warnings from absl import logging import dp_accounting @@ -133,6 +137,12 @@ def generate( Returns: A synthetic dataset. """ + warnings.warn( + 'data_generation_v2.generate() is deprecated. Use' + ' data_generation_v3.TabularSynthesizer instead.', + DeprecationWarning, + stacklevel=2, + ) assert 0 <= one_way_marginal_budget_fraction <= 1 if not skip_compression and cross_attribute_constraints: raise ValueError( diff --git a/dpsynth/data_generation_v3.py b/dpsynth/data_generation_v3.py index d4172f8..b2ef1ad 100644 --- a/dpsynth/data_generation_v3.py +++ b/dpsynth/data_generation_v3.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""End-to-end DP synthetic data generation using local mode primitives.""" +"""End-to-end DP synthetic tabular data generation using local mode primitives.""" from __future__ import annotations @@ -72,7 +72,7 @@ def _create_initializers( @dataclasses.dataclass -class DataGenerationV3(primitives.DPMechanism): +class TabularSynthesizer(primitives.DPMechanism): """End-to-end DP synthetic data generation mechanism. This mechanism encodes input categorical and numerical data into a discrete @@ -82,8 +82,8 @@ class DataGenerationV3(primitives.DPMechanism): Usage:: - v3 = DataGenerationV3(domains=domains) - calibrated = v3.calibrate(zcdp_rho=1.0) + synth = TabularSynthesizer(domains=domains) + calibrated = synth.calibrate(zcdp_rho=1.0) synthetic_df = calibrated(rng, df) Attributes: @@ -110,7 +110,7 @@ def calibrate( delta: float | None = None, numerical_bins: int = 32, init_budget_fraction: float = 0.1, - ) -> DataGenerationV3: + ) -> TabularSynthesizer: """Returns a calibrated copy of this mechanism. 
Supports two calibration modes: @@ -133,7 +133,7 @@ def calibrate( init_budget_fraction: Fraction of total budget for initialization. Returns: - A new DataGenerationV3 instance with calibrated sub-mechanisms. + A new TabularSynthesizer instance with calibrated sub-mechanisms. Raises: ValueError: If arguments are invalid or delta is missing when required. @@ -224,7 +224,7 @@ def _calibrate_approx_dp( init_budget_fraction: Fraction of zCDP budget for initialization. Returns: - A new DataGenerationV3 instance with calibrated sub-mechanisms. + A new TabularSynthesizer instance with calibrated sub-mechanisms. """ inits = self.initializers or _create_initializers( self.domains, numerical_bins, init_delta @@ -374,3 +374,7 @@ def __call__( column_order = [col for col in data.columns if col in self.domains] return pd.DataFrame(synthetic_columns)[column_order] + + +# Backward-compatible alias. +DataGenerationV3 = TabularSynthesizer diff --git a/dpsynth/examples/detailed_example_theory_and_in_memory_api.ipynb b/dpsynth/examples/detailed_example_theory_and_in_memory_api.ipynb index 2286b7f..38826bd 100644 --- a/dpsynth/examples/detailed_example_theory_and_in_memory_api.ipynb +++ b/dpsynth/examples/detailed_example_theory_and_in_memory_api.ipynb @@ -1,10 +1,8 @@ { "cells": [ { + "id": "536f643c", "cell_type": "markdown", - "metadata": { - "id": "ExmRcxAxASgM" - }, "source": [ "# Differentially Private Synthetic Tabular Data (go/dp-tabular-data-colab)\n", "\n", @@ -14,7 +12,7 @@ " Given a collection of records as input, we generate a collection of records as output that shares the same schema as the input data, while also trying to preserve important statistical properties of the original data. The mechanisms we currently implement follow the **Select-Measure-Generate** paradigm. \n", "\n", " 0. **Discretize** numeric attributes so all attributes are categorical.\n", - " 1. 
**Select** a collection of queries to measure \u2014 typically low-dimensional marginals, or queries of the form \"SELECT col1, col2, COUNT(*) FROM DATA GROUP BY col1, col2\"\n", + " 1. **Select** a collection of queries to measure — typically low-dimensional marginals, or queries of the form \"SELECT col1, col2, COUNT(*) FROM DATA GROUP BY col1, col2\"\n", " 2. **Measure** the selected queries privately using a noise-addition mechanism.\n", " 3. **Generate** synthetic data that best explains the noisy measurements.\n", "\n", @@ -44,12 +42,16 @@ "2. **MST** privately selects and measures some subset of 2-way marginals, fits a model to the noisy observations, then samples synthetic data. It typically runs in around 10 minutes.\n", "3. **AIM** iteratively selects and measures a low-order marginal, one at a time, updates its model of the data distribution, and repeats until the privacy budget is consumed. It provides the best utility but requires the most runtime. The runtime/utility trade-off can be configured via AIM's kwargs.\n", "\n" - ] + ], + "metadata": { + "id": "ExmRcxAxASgM" + } }, { + "id": "6400d5b5", "cell_type": "code", "source": [ - "#@title Install libraries\n", + "# @title Install libraries\n", "!pip install git+https://github.com/google/dpsynth.git\n", "!pip install stdmetrics\n", "!pip install scikit-learn\n", @@ -62,21 +64,23 @@ "outputs": [] }, { + "id": "915b40a0", "cell_type": "code", - "execution_count": null, + "source": [ + "# @title Import libraries\n", + "import dpsynth\n", + "import pandas as pd\n", + "import sdmetrics" + ], "metadata": { "id": "kOtCLaN8AQTL", "cellView": "form" }, - "outputs": [], - "source": [ - "#@title Import libraries\n", - "import pandas as pd\n", - "import dpsynth\n", - "import sdmetrics" - ] + "execution_count": null, + "outputs": [] }, { + "id": "9fd3df1a", "cell_type": "markdown", "source": [ "## Download Dataset\n", @@ -88,6 +92,7 @@ } }, { + "id": "ebeba744", "cell_type": "code", "source": [ "# Install Kaggle 
API\n", @@ -127,10 +132,8 @@ "outputs": [] }, { + "id": "f2cfbc9f", "cell_type": "markdown", - "metadata": { - "id": "zBGo9GJUKt4z" - }, "source": [ "# Load Data and Domain\n", "\n", @@ -138,222 +141,264 @@ "\n", "* **Data**: The data may contain a mix of categorical and numerical attributes. For this example, we load the UCI adult dataset, which is a small dataset derived from Census sources.\n", "* **Domain**: The domain specifies what values are possible for each attribute in the dataset. For numerical attributes like age, this is the minimum and maximum possible value (or some approximation thereof). For categorical attributes like education, this is just a list (like ['HS-grad', 'Bachelors', 'Masters', ...]). The domain information can be supplied externally in the form of a [yaml file](adult_domain.yaml)." - ] + ], + "metadata": { + "id": "zBGo9GJUKt4z" + } }, { + "id": "1f76632b", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "h3CB36tJKU99" - }, - "outputs": [], "source": [ "data = pd.read_csv('dpsynth/examples/adult.csv')\n", "data.head()" - ] + ], + "metadata": { + "id": "h3CB36tJKU99" + }, + "execution_count": null, + "outputs": [] }, { + "id": "5aa2f7e7", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "joxIqzA4kT2z" - }, - "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", + "\n", "data, test = train_test_split(data, test_size=0.3, random_state=42)\n", "print(data.shape)\n", "print(test.shape)" - ] + ], + "metadata": { + "id": "joxIqzA4kT2z" + }, + "execution_count": null, + "outputs": [] }, { + "id": "ef6b2cd5", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gE0Ky5dOK2vu" - }, - "outputs": [], "source": [ - "attribute_domains = dpsynth.domain.from_yaml_file('dpsynth/examples/adult_domain.yaml')\n", + "attribute_domains = dpsynth.domain.from_yaml_file(\n", + " 'dpsynth/examples/adult_domain.yaml'\n", + ")\n", "# Below we look at how two 
columns (numerical and categorical) are represented.\n", "# Age is a numerical attribute between 17 and 90.\n", "print(attribute_domains['age'])\n", "# Education is a categorical value with the given possible values.\n", "print(attribute_domains['education'])" - ] + ], + "metadata": { + "id": "gE0Ky5dOK2vu" + }, + "execution_count": null, + "outputs": [] }, { + "id": "fc9671d8", "cell_type": "markdown", - "metadata": { - "id": "aUwmX5ZNL2eE" - }, "source": [ "# Generate DP Synthetic Data\n", "\n", "Once we have the data and domain information, generating DP synthetic data takes only a few lines, shown below: construct the synthesizer, calibrate it to a privacy budget, and call it on the data. Below we use the MST mechanism, although other base mechanisms like AIM can be used with a one-line change." - ] + ], + "metadata": { + "id": "aUwmX5ZNL2eE" + } }, { + "id": "800c7c4c", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_ZIawFyUK80V" - }, - "outputs": [], "source": [ "# MST requires a couple of minutes to run, while AIM requires about\n", "# ~40 minutes to run. 
For this colab notebook, we therefore only run MST.\n", "\n", - "choose_mechanism = \"mst\" #@param [ \"mst\", \"aim\"]\n", + "choose_mechanism = \"mst\" # @param [ \"mst\", \"aim\"]\n", "\n", "if choose_mechanism == \"mst\":\n", - " mechanism = dpsynth.discrete_mechanisms.maximum_spanning_tree_mechanism\n", + " mechanism = dpsynth.discrete_mechanisms.MSTMechanism()\n", "else:\n", - " mechanism = dpsynth.discrete_mechanisms.adaptive_iterative_mechanism\n", + " mechanism = dpsynth.discrete_mechanisms.AIMMechanism()\n", "\n", "print(\"Started generating synthetic data, this may take a while...\")\n", "import warnings\n", + "from dpsynth import data_generation_v3\n", + "import numpy as np\n", + "\n", "with warnings.catch_warnings():\n", - " warnings.simplefilter(\"ignore\") # suppress some noisy warnings\n", - " synthetic_data = dpsynth.generate(\n", - " data=data,\n", + " warnings.simplefilter(\"ignore\") # suppress some noisy warnings\n", + " synth = data_generation_v3.TabularSynthesizer(\n", " domains=attribute_domains,\n", + " discrete_mechanism=mechanism,\n", + " ).calibrate(\n", " epsilon=1.0,\n", - " delta=1.0 / data.shape[0]**2, # 1 / N^2 where N is the number of rows in the dataset\n", - " mechanism=mechanism,\n", - " seed=577215664,\n", + " delta=1.0\n", + " / data.shape[0]\n", + " ** 2, # 1 / N^2 where N is the number of rows in the dataset\n", " )\n", + " synthetic_data = synth(np.random.default_rng(), data)\n", "print(\"Done, synthetic data is generated!\")" - ] - }, - { + ], "metadata": { - "id": "AW1JZ7t3p38J" + "id": "_ZIawFyUK80V" }, + "execution_count": null, + "outputs": [] + }, + { + "id": "3c96eac5", "cell_type": "markdown", "source": [ "Now let's take a look at the generated data. The number of generated rows should match the number of rows in the training dataset only approximately because noise is added to preserve anonymity." 
- ] + ], + "metadata": { + "id": "AW1JZ7t3p38J" + } }, { - "metadata": { - "id": "Ih4cUT58pN7-" - }, + "id": "3584fb60", "cell_type": "code", "source": [ - "print(f\"Generated {synthetic_data.shape[0]} rows of synthetic data. Training dataset has {data.shape[0]} rows (diff is {synthetic_data.shape[0] - data.shape[0]} which is {(synthetic_data.shape[0] - data.shape[0]) / data.shape[0] * 100:.2f}%).\")\n", + "print(\n", + " f\"Generated {synthetic_data.shape[0]} rows of synthetic data. Training\"\n", + " f\" dataset has {data.shape[0]} rows (diff is\"\n", + " f\" {synthetic_data.shape[0] - data.shape[0]} which is\"\n", + " f\" {(synthetic_data.shape[0] - data.shape[0]) / data.shape[0] * 100:.2f}%).\"\n", + ")\n", "synthetic_data.head()" ], - "outputs": [], - "execution_count": null + "metadata": { + "id": "Ih4cUT58pN7-" + }, + "execution_count": null, + "outputs": [] }, { + "id": "5458703b", "cell_type": "markdown", - "metadata": { - "id": "FbF1nRVCMoIR" - }, "source": [ "# Inspect and evaluate the synthetic data\n", "Now we'll check that the synthetic data preserves important statistical properties of the real data." 
- ] + ], + "metadata": { + "id": "FbF1nRVCMoIR" + } }, { + "id": "91176b5c", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aZEvW2dMMqj1" - }, - "outputs": [], "source": [ "def compare_histograms(column: str | list[str]) -> pd.DataFrame:\n", - " return pd.DataFrame({'Synthetic': synthetic_data[column].value_counts(), 'Real': data[column].value_counts()}).fillna(0)\n", + " return pd.DataFrame({\n", + " 'Synthetic': synthetic_data[column].value_counts(),\n", + " 'Real': data[column].value_counts(),\n", + " }).fillna(0)\n", "\n", "\n", "compare_histograms('race')" - ] + ], + "metadata": { + "id": "aZEvW2dMMqj1" + }, + "execution_count": null, + "outputs": [] }, { + "id": "a3014919", "cell_type": "code", - "execution_count": null, + "source": [ + "compare_histograms(['workclass', 'income>50K'])" + ], "metadata": { "id": "RWc2JVZ4NWya" }, - "outputs": [], - "source": [ - "compare_histograms(['workclass', 'income>50K'])" - ] + "execution_count": null, + "outputs": [] }, { + "id": "bf89d989", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kTFBaB4NO_zU" - }, - "outputs": [], "source": [ - "def compare_groupby_mean(groupby_cols: str | list[str], agg_col: str) -> pd.DataFrame:\n", - " return pd.DataFrame({'Synthetic': synthetic_data.groupby(groupby_cols)[agg_col].mean(), 'Real': data.groupby(groupby_cols)[agg_col].mean()})\n", + "def compare_groupby_mean(\n", + " groupby_cols: str | list[str], agg_col: str\n", + ") -> pd.DataFrame:\n", + " return pd.DataFrame({\n", + " 'Synthetic': synthetic_data.groupby(groupby_cols)[agg_col].mean(),\n", + " 'Real': data.groupby(groupby_cols)[agg_col].mean(),\n", + " })\n", + "\n", "\n", "compare_groupby_mean('sex', 'age')" - ] + ], + "metadata": { + "id": "kTFBaB4NO_zU" + }, + "execution_count": null, + "outputs": [] }, { + "id": "7e1b82b9", "cell_type": "markdown", - "metadata": { - "id": "2oU2MG4ZDFx4" - }, "source": [ "**1. 
Downstream task utility**\n", "\n", "The efficacy score reports the utility of synthetic data in downstream tasks and compares to train data utility. Ideally, you want the synth data utility score to be *close* to that of the real data because it indicates that the synth data can be used as replacement for the real data in the application setting.\n", "\n", "In our setting we will check on the following problem: given the tabular data we have we will train a binary classifier model to predict whether a person has income greater than 50K or not. The expectation is that performance of classifier trained on synthetic and real data will be comparable. Performance will be evaluated via f1 score." - ] + ], + "metadata": { + "id": "2oU2MG4ZDFx4" + } }, { + "id": "c9187d3e", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nztDFIoOlSf_" - }, - "outputs": [], "source": [ "from sdmetrics.single_table import BinaryDecisionTreeClassifier\n", "from sklearn.metrics import f1_score" - ] + ], + "metadata": { + "id": "nztDFIoOlSf_" + }, + "execution_count": null, + "outputs": [] }, { + "id": "26c1a56c", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0qxy8C5qlh7W" - }, - "outputs": [], "source": [ - "target=\"income>50K\"\n", + "target = 'income>50K'\n", "# only one instance of holand and not seen in training data, sdmetrics transformer doesn't handle it sadly.\n", "test = test[test['native-country'] != 'Holand-Netherlands']\n", "\n", - "scorer=lambda y_true, y_pred: f1_score(y_true, y_pred, pos_label=0)\n", - "score_real = BinaryDecisionTreeClassifier.compute(test_data=test, train_data=data, target=target, scorer=scorer)\n", - "score_synth = BinaryDecisionTreeClassifier.compute(test_data=test, train_data=synthetic_data, target=target, scorer=scorer)" - ] + "scorer = lambda y_true, y_pred: f1_score(y_true, y_pred, pos_label=0)\n", + "score_real = BinaryDecisionTreeClassifier.compute(\n", + " test_data=test, train_data=data, 
target=target, scorer=scorer\n", + ")\n", + "score_synth = BinaryDecisionTreeClassifier.compute(\n", + " test_data=test, train_data=synthetic_data, target=target, scorer=scorer\n", + ")" + ], + "metadata": { + "id": "0qxy8C5qlh7W" + }, + "execution_count": null, + "outputs": [] }, { + "id": "cf98eb04", "cell_type": "code", - "execution_count": null, + "source": [ + "print(f\"Real data (baseline): {score_real}\\nSynth data: {score_synth}\")" + ], "metadata": { "id": "2yIEIYSqnnrH" }, - "outputs": [], - "source": [ - "print(f\"Real data (baseline): {score_real}\\nSynth data: {score_synth}\")" - ] + "execution_count": null, + "outputs": [] }, { - "metadata": { - "id": "4Ij1_ibGcNQs" - }, + "id": "3e5016ef", "cell_type": "markdown", "source": [ "**Improvements of recall metric on downstream binary classification tasks**\n", @@ -365,102 +410,110 @@ "* (baseline) 0.5: Augmenting the real data with synthetic data does not change the ML classifier's recall at all\n", "\n", "* (worst) 0.0: Augmenting the real data with synthetic data decreases the ML classifier's recall by the most it possibly can (by 100%)." 
- ] + ], + "metadata": { + "id": "4Ij1_ibGcNQs" + } }, { + "id": "426dc080", "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "oyBfCx4RoUXT" - }, - "outputs": [], "source": [ "# @title Create metadata\n", "# ideally, this should be created manually but lazily inferring the dtype from data\n", - "import pandas as pd\n", - "import re\n", "from datetime import datetime\n", + "import re\n", "from IPython.display import clear_output\n", + "import pandas as pd\n", "\n", "\n", "# Function to infer the column type based on its data\n", "def infer_column_type(series):\n", - " # Check if all values are numerical (either integer or float)\n", - " if pd.api.types.is_numeric_dtype(series):\n", - " return \"numerical\"\n", + " # Check if all values are numerical (either integer or float)\n", + " if pd.api.types.is_numeric_dtype(series):\n", + " return \"numerical\"\n", "\n", - " # Check if all values are boolean-like\n", - " elif pd.api.types.is_bool_dtype(series):\n", - " return \"boolean\"\n", + " # Check if all values are boolean-like\n", + " elif pd.api.types.is_bool_dtype(series):\n", + " return \"boolean\"\n", "\n", - " # Check if the column can be converted to datetime (if not already)\n", - " elif pd.to_datetime(series, errors='coerce').notna().all():\n", - " return \"datetime\"\n", + " # Check if the column can be converted to datetime (if not already)\n", + " elif pd.to_datetime(series, errors=\"coerce\").notna().all():\n", + " return \"datetime\"\n", "\n", - " # If values seem like address-like, can be categorized as 'address'\n", - " # For simplicity, assume any column with \"street\", \"city\", or \"zipcode\" is an address\n", - " elif series.apply(lambda x: bool(re.search(r\"(street|city|zip|address)\", str(x), re.I))).any():\n", - " return \"address\"\n", + " # If values seem like address-like, can be categorized as 'address'\n", + " # For simplicity, assume any column with \"street\", \"city\", or \"zipcode\" is an 
address\n", + " elif series.apply(\n", + " lambda x: bool(re.search(r\"(street|city|zip|address)\", str(x), re.I))\n", + " ).any():\n", + " return \"address\"\n", + "\n", + " # Otherwise, consider it categorical (text or mixed values)\n", + " return \"categorical\"\n", "\n", - " # Otherwise, consider it categorical (text or mixed values)\n", - " return \"categorical\"\n", "\n", "# Function to read the CSV and generate the metadata dictionary\n", "def create_metadata_from_csv(df):\n", - " metadata = {\n", - " \"primary_key\": \"\", # You could infer this based on uniqueness in the data\n", - " \"columns\": {}\n", - " }\n", + " metadata = {\n", + " \"primary_key\": \"\", # You could infer this based on uniqueness in the data\n", + " \"columns\": {},\n", + " }\n", + "\n", + " # Iterate through columns and infer their types\n", + " for column_name in df.columns:\n", + " column_data = df[column_name]\n", "\n", - " # Iterate through columns and infer their types\n", - " for column_name in df.columns:\n", - " column_data = df[column_name]\n", + " # Infer the type of the column\n", + " column_type = infer_column_type(column_data)\n", "\n", - " # Infer the type of the column\n", - " column_type = infer_column_type(column_data)\n", + " # Initialize the metadata for the column\n", + " column_metadata = {\"sdtype\": column_type}\n", "\n", - " # Initialize the metadata for the column\n", - " column_metadata = {\"sdtype\": column_type}\n", + " # Add additional information for datetime columns\n", + " if column_type == \"datetime\":\n", + " column_metadata[\"datetime_format\"] = \"%Y-%m-%d\" # default format\n", "\n", - " # Add additional information for datetime columns\n", - " if column_type == \"datetime\":\n", - " column_metadata[\"datetime_format\"] = \"%Y-%m-%d\" # default format\n", + " # Example: Adding special rules for \"address\" column\n", + " if column_type == \"address\":\n", + " column_metadata[\"pii\"] = (\n", + " True # Assuming addresses contain personal 
information\n", + " )\n", "\n", - " # Example: Adding special rules for \"address\" column\n", - " if column_type == \"address\":\n", - " column_metadata[\"pii\"] = True # Assuming addresses contain personal information\n", + " # Add the column metadata to the dictionary\n", + " metadata[\"columns\"][column_name] = column_metadata\n", "\n", - " # Add the column metadata to the dictionary\n", - " metadata[\"columns\"][column_name] = column_metadata\n", + " return metadata\n", "\n", - " return metadata\n", "\n", "# Example usage\n", "metadata_dict = create_metadata_from_csv(data)\n", "clear_output()" - ] + ], + "metadata": { + "cellView": "form", + "id": "oyBfCx4RoUXT" + }, + "execution_count": null, + "outputs": [] }, { + "id": "b37f8474", "cell_type": "code", - "execution_count": null, + "source": [ + "metadata_dict" + ], "metadata": { "id": "2mHl7AVmoidz" }, - "outputs": [], - "source": [ - "metadata_dict" - ] + "execution_count": null, + "outputs": [] }, { + "id": "e5de679f", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6BKHThQdoYdv" - }, - "outputs": [], "source": [ "from sdmetrics.single_table.data_augmentation import BinaryClassifierRecallEfficacy\n", + "\n", "breakdown = BinaryClassifierRecallEfficacy.compute_breakdown(\n", " real_training_data=data,\n", " synthetic_data=synthetic_data,\n", @@ -468,17 +521,20 @@ " metadata=metadata_dict,\n", " prediction_column_name=target,\n", " minority_class_label=\">50K\",\n", - " classifier='XGBoost',\n", + " classifier=\"XGBoost\",\n", " fixed_precision_value=0.9,\n", ")\n", "breakdown" - ] + ], + "metadata": { + "id": "6BKHThQdoYdv" + }, + "execution_count": null, + "outputs": [] }, { + "id": "60e5c941", "cell_type": "markdown", - "metadata": { - "id": "2sLapxx7C-vF" - }, "source": [ "**2. 
Quality Report**\n", "\n", @@ -489,80 +545,87 @@ "\n", "The scores are in [0; 1] range, the higher the better.\n", "\n" - ] + ], + "metadata": { + "id": "2sLapxx7C-vF" + } }, { + "id": "1549e9ca", "cell_type": "code", - "execution_count": null, + "source": [ + "from sdmetrics.reports.single_table import QualityReport" + ], "metadata": { "id": "pXOsecXIpjsN" }, - "outputs": [], - "source": [ - "from sdmetrics.reports.single_table import QualityReport" - ] + "execution_count": null, + "outputs": [] }, { + "id": "e4766a85", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UeyIoedopqmN" - }, - "outputs": [], "source": [ "quality_report = QualityReport()\n", "# between real train and synthetic\n", "_ = quality_report.generate(data, synthetic_data, metadata_dict)" - ] - }, - { + ], "metadata": { - "id": "fiTFYVHn0k_6" + "id": "UeyIoedopqmN" }, + "execution_count": null, + "outputs": [] + }, + { + "id": "7d4997c7", "cell_type": "markdown", "source": [ "Let's see Column Shapes scores per each column" - ] + ], + "metadata": { + "id": "fiTFYVHn0k_6" + } }, { - "metadata": { - "id": "Woc3DUZfyw0g" - }, + "id": "c386e84d", "cell_type": "code", "source": [ "fig = quality_report.get_visualization(property_name='Column Shapes')\n", "fig.show()" ], - "outputs": [], - "execution_count": null - }, - { "metadata": { - "id": "NmG1hyLN02Vw" + "id": "Woc3DUZfyw0g" }, + "execution_count": null, + "outputs": [] + }, + { + "id": "128882c2", "cell_type": "markdown", "source": [ "Let's see the scores for each pair of columns - the greener the plot is the better.\n", "\n", "Additionally there will be \"Numerical correlations\" plots that show correlations inside the dataset, it is the best if two plots match as much as possible." 
- ] + ], + "metadata": { + "id": "NmG1hyLN02Vw" + } }, { + "id": "e220417c", "cell_type": "code", - "execution_count": null, + "source": [ + "quality_report.get_visualization(property_name='Column Pair Trends')" + ], "metadata": { "id": "yFzHP1KnptMd" }, - "outputs": [], - "source": [ - "quality_report.get_visualization(property_name='Column Pair Trends')" - ] + "execution_count": null, + "outputs": [] }, { + "id": "328a1aef", "cell_type": "markdown", - "metadata": { - "id": "FOXbzl8-C4oZ" - }, "source": [ "**3. Diagnostic Report**\n", "\n", @@ -573,50 +636,54 @@ "\n", "The scores have to be always 100% because they are basic validity checks.\n", "In our case data validity score is not 100% but very close to it because of the problem with `Holand-Netherlands`." - ] + ], + "metadata": { + "id": "FOXbzl8-C4oZ" + } }, { + "id": "de07f527", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QLYYvNDCp_7O" - }, - "outputs": [], "source": [ "from sdmetrics.reports.single_table import DiagnosticReport\n", "\n", "diagnostic_report = DiagnosticReport()" - ] + ], + "metadata": { + "id": "QLYYvNDCp_7O" + }, + "execution_count": null, + "outputs": [] }, { + "id": "e4f14a1e", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "y-08QRCLqCpz" - }, - "outputs": [], "source": [ "# between real train and synthetic data\n", "diagnostic_report.generate(data, synthetic_data, metadata_dict)" - ] + ], + "metadata": { + "id": "y-08QRCLqCpz" + }, + "execution_count": null, + "outputs": [] }, { + "id": "d8d68c48", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hTwMvvqjqKMk" - }, - "outputs": [], "source": [ "fig = diagnostic_report.get_visualization(property_name='Data Validity')\n", "fig.show()" - ] + ], + "metadata": { + "id": "hTwMvvqjqKMk" + }, + "execution_count": null, + "outputs": [] }, { + "id": "7b779403", "cell_type": "markdown", - "metadata": { - "id": "7Xqk1ezgCrmG" - }, "source": [ "**4. 
New Row Synthesis**\n", "\n", @@ -625,33 +692,31 @@ "(best) 1.0: The rows in the synthetic data are all new. There are no matches with the real data.\n", "\n", "(worst) 0.0: All the rows in the synthetic data are copies of rows in the real data. Definitely undesired." - ] + ], + "metadata": { + "id": "7Xqk1ezgCrmG" + } }, { + "id": "df5e0dce", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NGXSh86lqNcw" - }, - "outputs": [], "source": [ "# calculate whether the synthetic data is new or whether it's an exact copy of the real data\n", "\n", "# Quite slow so, might want to consider subsampling.\n", "from sdmetrics.single_table import NewRowSynthesis\n", "\n", - "NewRowSynthesis.compute(\n", - " data,\n", - " synthetic_data,\n", - " metadata_dict\n", - ")" - ] + "NewRowSynthesis.compute(data, synthetic_data, metadata_dict)" + ], + "metadata": { + "id": "NGXSh86lqNcw" + }, + "execution_count": null, + "outputs": [] }, { + "id": "e3c1c60c", "cell_type": "markdown", - "metadata": { - "id": "5O9O1GBbCjn0" - }, "source": [ "**5. \"Privacy\": Disclosure protection**\n", "\n", @@ -664,26 +729,26 @@ "Scores between 0.0 and 1.0 indicate the relative risk of disclosure. For example, a score of 0.825 indicates that the synthetic data has 82.5% of the protection that random data would provide.\n", "\n", "Read more [in this documentation](https://docs.sdv.dev/sdmetrics/metrics/metrics-glossary/disclosureprotection)." 
- ] + ], + "metadata": { + "id": "5O9O1GBbCjn0" + } }, { + "id": "cc42dade", "cell_type": "code", - "execution_count": null, + "source": [ + "data.columns" + ], "metadata": { "id": "1FEMv9H5qe6f" }, - "outputs": [], - "source": [ - "data.columns" - ] + "execution_count": null, + "outputs": [] }, { + "id": "24405bd5", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XmVqp_NeqRQW" - }, - "outputs": [], "source": [ "from sdmetrics.single_table import DisclosureProtectionEstimate\n", "\n", @@ -698,21 +763,26 @@ " 'race',\n", " 'sex',\n", " 'native-country',\n", - " 'income>50K'\n", + " 'income>50K',\n", "]\n", "score = DisclosureProtectionEstimate.compute_breakdown(\n", " real_data=data,\n", " synthetic_data=synthetic_data,\n", " known_column_names=['age', 'native-country'],\n", " sensitive_column_names=['income>50K'],\n", - " continuous_column_names=list(set(data.columns)-set(discrete_columns)),\n", + " continuous_column_names=list(set(data.columns) - set(discrete_columns)),\n", " num_rows_subsample=2500,\n", " num_iterations=100,\n", - " verbose=True\n", + " verbose=True,\n", ")\n", "\n", "score" - ] + ], + "metadata": { + "id": "XmVqp_NeqRQW" + }, + "execution_count": null, + "outputs": [] } ], "metadata": { @@ -728,6 +798,6 @@ "name": "python" } }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 0, + "nbformat": 4 } diff --git a/dpsynth/examples/quickstart.ipynb b/dpsynth/examples/quickstart.ipynb index a83e33b..28c87bb 100644 --- a/dpsynth/examples/quickstart.ipynb +++ b/dpsynth/examples/quickstart.ipynb @@ -13,18 +13,18 @@ } }, { - "metadata": { - "cellView": "form", - "id": "267so-uABJnh" - }, "id": "267so-uABJnh", "cell_type": "code", "source": [ - "#@title Install DPSynth\n", + "# @title Install DPSynth\n", "!pip install git+https://github.com/google/dpsynth.git" ], - "outputs": [], - "execution_count": null + "metadata": { + "cellView": "form", + "id": "267so-uABJnh" + }, + "execution_count": null, + "outputs": [] }, { "id": 
"dcb89b1a", @@ -86,36 +86,44 @@ "source": [ "import os\n", "import dpsynth\n", + "from dpsynth import data_generation_v3\n", "from dpsynth import discrete_mechanisms\n", "from dpsynth import domain\n", + "import numpy as np\n", "import pandas as pd\n", "\n", "# 1. Load the downloaded dataset\n", "sensitive_df = pd.read_csv(os.path.join(ds_path, 'adult.csv'))\n", "\n", - "print(\"Original Sensitive Data (sample):\")\n", + "print('Original Sensitive Data (sample):')\n", "print(sensitive_df[['age', 'education']].head())\n", "\n", "# 2. Define domain schema\n", "# We consider age as Categorical to avoid discretization issues as suggested by reviewer.\n", "attribute_domains = {\n", - " 'age': domain.CategoricalAttribute(possible_values=list(range(17, 91))), # Adult dataset age range\n", - " 'education': domain.CategoricalAttribute(possible_values=sensitive_df['education'].unique().tolist())\n", + " 'age': domain.CategoricalAttribute(\n", + " possible_values=list(range(17, 91))\n", + " ), # Adult dataset age range\n", + " 'education': domain.CategoricalAttribute(\n", + " possible_values=sensitive_df['education'].unique().tolist()\n", + " ),\n", "}\n", "\n", - "# 3. Configure the synthesis mechanism (MST as default)\n", - "mst_config = discrete_mechanisms.MSTConfig(seed=42)\n", + "# 3. Configure and run the synthesis mechanism (MST as default)\n", + "mechanism = data_generation_v3.TabularSynthesizer(\n", + " domains=attribute_domains,\n", + " discrete_mechanism=discrete_mechanisms.MSTConfig(seed=42),\n", + ").calibrate(epsilon=1.0, delta=1e-5)\n", "\n", "# 4. 
Generate Differentially Private synthetic data\n", - "synthetic_df = dpsynth.generate(\n", - " data=sensitive_df[['age', 'education']], # Restrict to these columns for simplicity\n", - " domains=attribute_domains,\n", - " epsilon=1.0,\n", - " delta=1e-5,\n", - " discrete_config=mst_config,\n", + "synthetic_df = mechanism(\n", + " np.random.default_rng(),\n", + " sensitive_df[\n", + " ['age', 'education']\n", + " ], # Restrict to these columns for simplicity\n", ")\n", "\n", - "print(\"\\nGenerated Synthetic Data (sample):\")\n", + "print('\\nGenerated Synthetic Data (sample):')\n", "print(synthetic_df.head())" ], "metadata": { @@ -142,23 +150,24 @@ "source": [ "import apache_beam as beam\n", "from dpsynth import data_generation\n", - "from dpsynth.pipeline_transformations import types\n", "from dpsynth import domain\n", "from dpsynth.dataset_descriptors import csv_descriptor\n", - "import pipeline_dp\n", + "from dpsynth.pipeline_transformations import types\n", "import pandas as pd\n", + "import pipeline_dp\n", "\n", "# 1. Define domain schema\n", "# We use a small sample to create the descriptor\n", "sample_df = pd.read_csv('adult.csv', nrows=100)\n", "descriptor = csv_descriptor.get_dataset_descriptor_for_csv(\n", - " dataframe=sample_df,\n", - " field_names=[\"age\", \"education\"]\n", + " dataframe=sample_df, field_names=['age', 'education']\n", ")\n", "\n", "attribute_domains = {\n", " 'age': domain.CategoricalAttribute(possible_values=list(range(17, 91))),\n", - " 'education': domain.CategoricalAttribute(possible_values=sample_df['education'].unique().tolist())\n", + " 'education': domain.CategoricalAttribute(\n", + " possible_values=sample_df['education'].unique().tolist()\n", + " ),\n", "}\n", "\n", "# 2. Configure the synthesis mechanism\n", @@ -173,29 +182,31 @@ "\n", "# 3. 
Run the pipeline with BeamBackend\n", "with beam.Pipeline() as pipeline:\n", - " # Load data from CSV file in Beam\n", - " # We use a simple map to parse lines for simplicity\n", - " def csv_to_dict(line):\n", - " values = line.split(',')\n", - " # Assuming standard order: age is 0, education is 3\n", - " return {'age': int(values[0]), 'education': values[3].strip()}\n", - "\n", - " raw_records = (\n", - " pipeline\n", - " | \"ReadCSV\" >> beam.io.ReadFromText('adult.csv', skip_header_lines=1)\n", - " | \"ParseCSV\" >> beam.Map(csv_to_dict)\n", - " )\n", - "\n", - " beam_backend = pipeline_dp.BeamBackend()\n", - "\n", - " synthetic_records = data_generation.generate(\n", - " input_data=raw_records,\n", - " config=config,\n", - " backend=beam_backend,\n", - " )\n", - "\n", - " # Print sample of results\n", - " synthetic_records | \"Sample\" >> beam.combiners.Sample.FixedSizeGlobally(5) | \"Print\" >> beam.Map(print)" + " # Load data from CSV file in Beam\n", + " # We use a simple map to parse lines for simplicity\n", + " def csv_to_dict(line):\n", + " values = line.split(',')\n", + " # Assuming standard order: age is 0, education is 3\n", + " return {'age': int(values[0]), 'education': values[3].strip()}\n", + "\n", + " raw_records = (\n", + " pipeline\n", + " | 'ReadCSV' >> beam.io.ReadFromText('adult.csv', skip_header_lines=1)\n", + " | 'ParseCSV' >> beam.Map(csv_to_dict)\n", + " )\n", + "\n", + " beam_backend = pipeline_dp.BeamBackend()\n", + "\n", + " synthetic_records = data_generation.generate(\n", + " input_data=raw_records,\n", + " config=config,\n", + " backend=beam_backend,\n", + " )\n", + "\n", + " # Print sample of results\n", + " synthetic_records | 'Sample' >> beam.combiners.Sample.FixedSizeGlobally(\n", + " 5\n", + " ) | 'Print' >> beam.Map(print)" ], "metadata": { "id": "8a0a9c11" diff --git a/dpsynth/pydantic_api.py b/dpsynth/pydantic_api.py index fbef5bf..97212a1 100644 --- a/dpsynth/pydantic_api.py +++ b/dpsynth/pydantic_api.py @@ -12,7 +12,7 @@ # 
See the License for the specific language governing permissions and # limitations under the License. -"""Synthetic Tabular Data API for synthesizing collections of pydantic Models.""" +"""Pydantic-based API for DP synthetic tabular data generation.""" from collections.abc import Iterable import enum @@ -22,9 +22,10 @@ import typing # for typing.get_origin and typing.get_args from typing import Any, Literal, TypeVar -from dpsynth import data_generation_v2 +from dpsynth import data_generation_v3 from dpsynth import discrete_mechanisms from dpsynth import domain +import numpy as np import pandas as pd import pydantic from pydantic.fields import annotated_types @@ -101,16 +102,16 @@ def _categorical_attribute_from_field_info( """Infers a CategoricalAttribute from a pydantic FieldInfo.""" optional, base_type = _get_base_type(field_info.annotation) if inspect.isclass(base_type) and issubclass(base_type, enum.Enum): - possible_values = list(base_type) + possible_values = [e.value for e in base_type] elif typing.get_origin(base_type) is Literal: - possible_values = list(typing.get_args(base_type)) + possible_values = [str(v) for v in typing.get_args(base_type)] elif base_type is bool: - possible_values = [False, True] + possible_values = [str(False), str(True)] else: raise ValueError(f"Unexpected type annotation: {base_type}.") if optional: - possible_values = [None] + possible_values + possible_values = [str(None)] + possible_values return domain.CategoricalAttribute( possible_values=possible_values, out_of_domain_index=0 @@ -164,12 +165,71 @@ def dp_synthetic_data_generation( data = list(data) cls = data[0].__class__ - synthetic = data_generation_v2.generate( - data=pd.DataFrame([user.model_dump() for user in data], dtype="object"), - domains=infer_domain_from_model(cls), - epsilon=epsilon, - delta=delta, - discrete_config=mechanism_config, - ) - - return [cls(**user) for _, user in synthetic.iterrows()] + domains_dict = infer_domain_from_model(cls) + mechanism = 
data_generation_v3.TabularSynthesizer( + domains=domains_dict, + discrete_mechanism=mechanism_config, + ).calibrate(epsilon=epsilon, delta=delta) + records = [ + { + k: v.value if isinstance(v, enum.Enum) else v + for k, v in user.model_dump().items() + } + for user in data + ] + df = pd.DataFrame(records) + + # Track missing rates for optional numerical columns so we can restore + # None values in the output after v3 (which can't handle NaN). + # Also ensure all numerical columns are proper numeric dtypes (model_dump + # may produce object-typed columns). + none_rates: dict[str, float] = {} + for col, attr in domains_dict.items(): + if isinstance(attr, domain.NumericalAttribute): + if not attr.clip_to_range: # optional field + rate = df[col].isna().mean() + none_rates[col] = rate + mid = (attr.min_value + attr.max_value) / 2 + df[col] = df[col].fillna(mid) + df[col] = pd.to_numeric(df[col]) + + # Cast categorical columns to str so that mixed types (None, bool, str) + # are all homogeneous and sortable for v3's np.unique-based encoder. + for col, attr in domains_dict.items(): + if isinstance(attr, domain.CategoricalAttribute): + df[col] = df[col].astype(str) + synthetic = mechanism(np.random.default_rng(), df) + + # Post-process synthetic output for pydantic compatibility. + rng = np.random.default_rng() + for col, attr in domains_dict.items(): + if isinstance(attr, domain.CategoricalAttribute): + # Convert str("None") back to Python None for optional categoricals. + synthetic[col] = synthetic[col].where( + synthetic[col] != str(None), other=np.nan + ) + elif isinstance(attr, domain.NumericalAttribute): + # v3 outputs continuous floats; round int-typed columns for pydantic. + # For optional columns (clip_to_range=False), v3 may include NaN for + # out-of-domain values, so we round only non-NaN cells. 
+ if attr.dtype == "int": + col_data = pd.to_numeric(synthetic[col], errors="coerce") + non_nan = col_data.notna() + col_data.loc[non_nan] = col_data.loc[non_nan].round() + synthetic[col] = col_data + # Restore None values for optional numerical fields. + if col in none_rates and none_rates[col] > 0: + mask = rng.random(len(synthetic)) < none_rates[col] + synthetic[col] = synthetic[col].where(~mask, other=np.nan) + + # iterrows() converts None to NaN for numeric columns; pydantic needs None. + def _safe_isnan(v): + try: + return pd.isna(v) and not isinstance(v, str) + except (ValueError, TypeError): + return False + + return [ + cls(**{k: None if _safe_isnan(v) else v for k, v in row.items()}) + for _, row in synthetic.iterrows() + ] diff --git a/tests/data_generation_v3_test.py b/tests/data_generation_v3_test.py index ba979e7..bf00e16 100644 --- a/tests/data_generation_v3_test.py +++ b/tests/data_generation_v3_test.py @@ -21,7 +21,7 @@ import numpy as np import pandas as pd -DataGenerationV3 = data_generation_v3.DataGenerationV3 +TabularSynthesizer = data_generation_v3.TabularSynthesizer class DataGenerationV3Test(absltest.TestCase): @@ -37,7 +37,7 @@ def test_end_to_end_categorical(self): } df = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': ['x', 'y', 'z']}) rng = np.random.default_rng(0) - calibrated = DataGenerationV3(domains=domains).calibrate(zcdp_rho=100.0) + calibrated = TabularSynthesizer(domains=domains).calibrate(zcdp_rho=100.0) synthetic_df = calibrated(rng, df) self.assertIsInstance(synthetic_df, pd.DataFrame) self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) @@ -49,7 +49,7 @@ def test_end_to_end_numerical(self): } df = pd.DataFrame({'A': [5, 5, 0], 'B': [5, -10, -5]}, dtype=float) rng = np.random.default_rng(0) - calibrated = DataGenerationV3(domains=domains).calibrate(zcdp_rho=100.0) + calibrated = TabularSynthesizer(domains=domains).calibrate(zcdp_rho=100.0) synthetic_df = calibrated(rng, df) 
self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) for col, attr in domains.items(): @@ -64,7 +64,7 @@ def test_end_to_end_mixed_domain(self): } df = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': [1.0, 5.0, 10.0]}) rng = np.random.default_rng(0) - calibrated = DataGenerationV3(domains=domains).calibrate( + calibrated = TabularSynthesizer(domains=domains).calibrate( zcdp_rho=100.0, delta=1e-5 ) synthetic_df = calibrated(rng, df) @@ -82,7 +82,7 @@ def test_end_to_end_with_epsilon_delta(self): } df = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': ['x', 'y', 'z']}) rng = np.random.default_rng(0) - calibrated = DataGenerationV3(domains=domains).calibrate( + calibrated = TabularSynthesizer(domains=domains).calibrate( epsilon=100, delta=0.1 ) synthetic_df = calibrated(rng, df) @@ -94,7 +94,7 @@ def test_raises_on_freeform_text_attribute(self): 'A': domain.CategoricalAttribute(possible_values=['a', 'b']), 'text': domain.FreeFormTextAttribute(max_tokens=128), } - v3 = DataGenerationV3(domains=domains) + v3 = TabularSynthesizer(domains=domains) with self.assertRaises(ValueError): v3.calibrate(zcdp_rho=1.0) @@ -106,7 +106,7 @@ def test_raises_when_not_calibrated(self): } df = pd.DataFrame({'A': ['a', 'b', 'c']}) rng = np.random.default_rng(0) - v3 = DataGenerationV3(domains=domains) + v3 = TabularSynthesizer(domains=domains) with self.assertRaises(ValueError): v3(rng, df) @@ -116,7 +116,7 @@ def test_dp_event_returns_composed_event(self): possible_values=['a', 'b', 'c'], out_of_domain_index=0 ), } - calibrated = DataGenerationV3(domains=domains).calibrate(zcdp_rho=100.0) + calibrated = TabularSynthesizer(domains=domains).calibrate(zcdp_rho=100.0) self.assertIsInstance(calibrated.dp_event, dp_accounting.ComposedDpEvent) def test_calibrate_raises_on_conflicting_params(self): @@ -125,7 +125,7 @@ def test_calibrate_raises_on_conflicting_params(self): possible_values=['a', 'b', 'c'], out_of_domain_index=0 ), } - v3 = DataGenerationV3(domains=domains) + v3 = 
TabularSynthesizer(domains=domains) with self.assertRaises(ValueError): v3.calibrate(zcdp_rho=1.0, epsilon=1.0, delta=1e-5) @@ -140,7 +140,7 @@ def test_calibrate_small_epsilon(self): } df = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': ['x', 'y', 'z']}) rng = np.random.default_rng(0) - calibrated = DataGenerationV3(domains=domains).calibrate( + calibrated = TabularSynthesizer(domains=domains).calibrate( epsilon=0.2, delta=1e-5 ) synthetic_df = calibrated(rng, df) diff --git a/tests/pydantic_api_test.py b/tests/pydantic_api_test.py index 36c751a..a91e1b7 100644 --- a/tests/pydantic_api_test.py +++ b/tests/pydantic_api_test.py @@ -140,7 +140,7 @@ def test_categorical_attribute_from_field_info(self): self.assertEqual( attr_bool, domain.CategoricalAttribute( - possible_values=[False, True], out_of_domain_index=0 + possible_values=["False", "True"], out_of_domain_index=0 ), ) @@ -151,7 +151,7 @@ def test_categorical_attribute_from_field_info(self): self.assertEqual( attr_enum_opt, domain.CategoricalAttribute( - possible_values=[None, Color.RED, Color.GREEN, Color.BLUE], + possible_values=["None", "red", "green", "blue"], out_of_domain_index=0, ), ) @@ -180,10 +180,10 @@ def test_infer_domain_from_model(self): dtype="float", ), "is_member": domain.CategoricalAttribute( - possible_values=[False, True], out_of_domain_index=0 + possible_values=["False", "True"], out_of_domain_index=0 ), "favorite_color": domain.CategoricalAttribute( - possible_values=[Color.RED, Color.GREEN, Color.BLUE], + possible_values=["red", "green", "blue"], out_of_domain_index=0, ), "optional_code": domain.NumericalAttribute( @@ -193,14 +193,14 @@ def test_infer_domain_from_model(self): possible_values=["active", "inactive"], out_of_domain_index=0 ), "optional_bool": domain.CategoricalAttribute( - possible_values=[None, False, True], out_of_domain_index=0 + possible_values=["None", "False", "True"], out_of_domain_index=0 ), "optional_enum": domain.CategoricalAttribute( - possible_values=[None, Color.RED, 
Color.GREEN, Color.BLUE], + possible_values=["None", "red", "green", "blue"], out_of_domain_index=0, ), "optional_literal": domain.CategoricalAttribute( - possible_values=[None, "a", "b"], out_of_domain_index=0 + possible_values=["None", "a", "b"], out_of_domain_index=0 ), } self.assertEqual(domain_spec, expected_domain_spec)