From 76bbb18cd8c286ff078932f1ed9f58c0deed5660 Mon Sep 17 00:00:00 2001 From: Ryan McKenna Date: Sat, 20 Jun 2026 12:49:47 -0700 Subject: [PATCH] Return result dataclasses from DP mechanism __call__ methods Wrap all DPMechanism.__call__ return types in typed result dataclasses to enable future extensibility (e.g., extra measured information, or additional auxiliary outputs) without breaking existing callers or serialized objects. New optional fields can be added to any result class with default=None. PiperOrigin-RevId: 935368864 --- dpsynth/data_generation_v2.py | 4 +- dpsynth/data_generation_v3.py | 19 +++++-- dpsynth/discrete_mechanisms/__init__.py | 6 +++ dpsynth/discrete_mechanisms/aim.py | 13 +++-- dpsynth/discrete_mechanisms/aim_gdp.py | 13 +++-- dpsynth/discrete_mechanisms/direct.py | 11 ++++- dpsynth/discrete_mechanisms/independent.py | 11 ++++- dpsynth/discrete_mechanisms/mst.py | 13 +++-- dpsynth/discrete_mechanisms/swift.py | 13 +++-- dpsynth/local_mode/initialization.py | 7 +-- dpsynth/local_mode/primitives.py | 49 ++++++++++++++++--- tests/data_generation_v3_test.py | 15 ++++-- tests/discrete_mechanisms/aim_test.py | 4 +- tests/discrete_mechanisms/direct_test.py | 4 +- tests/discrete_mechanisms/independent_test.py | 4 +- tests/discrete_mechanisms/mst_test.py | 2 +- tests/discrete_mechanisms/swift_test.py | 2 +- tests/local_mode/primitives_test.py | 8 +-- 18 files changed, 149 insertions(+), 49 deletions(-) diff --git a/dpsynth/data_generation_v2.py b/dpsynth/data_generation_v2.py index 7d83c9e..dad7e54 100644 --- a/dpsynth/data_generation_v2.py +++ b/dpsynth/data_generation_v2.py @@ -265,14 +265,14 @@ def generate( cross_attribute_constraints, discrete.domain ) - model = discrete_config.calibrate(zcdp_rho=discrete_zcdp_rho)( + result = discrete_config.calibrate(zcdp_rho=discrete_zcdp_rho)( rng, data=discrete, initial_measurements=one_way_measurements, initial_potentials=initial_potentials, ) - synthetic_data = model.synthetic_data() + synthetic_data = 
result.model.synthetic_data() logging.info('[SynthKit Tabular]: Generated discrete synthetic data.') # Convert synthetic data back to the original domain. diff --git a/dpsynth/data_generation_v3.py b/dpsynth/data_generation_v3.py index d4172f8..154d076 100644 --- a/dpsynth/data_generation_v3.py +++ b/dpsynth/data_generation_v3.py @@ -71,6 +71,13 @@ def _create_initializers( return initializers +@dataclasses.dataclass +class DataGenerationResult: + """Result of end-to-end DP synthetic data generation.""" + + synthetic_data: pd.DataFrame + + @dataclasses.dataclass class DataGenerationV3(primitives.DPMechanism): """End-to-end DP synthetic data generation mechanism. @@ -297,7 +304,7 @@ def dp_event(self) -> dp_accounting.DpEvent: def __call__( self, rng: np.random.Generator, data: pd.DataFrame - ) -> pd.DataFrame: + ) -> DataGenerationResult: """Generates differentially private synthetic data. Args: @@ -306,7 +313,7 @@ def __call__( specified in ``domains``. Returns: - A synthetic DataFrame with the same domain columns as the input. + A DataGenerationResult containing the synthetic DataFrame. Raises: ValueError: If calibrate() has not been called or if required columns are @@ -349,13 +356,13 @@ def __call__( initial_potentials = constraints.get_initial_parameters( self.cross_attribute_constraints, discrete.domain ) - model = self.discrete_mechanism( + mechanism_result = self.discrete_mechanism( rng, data=discrete, initial_measurements=one_way_measurements, initial_potentials=initial_potentials, ) - synthetic_data = model.synthetic_data() + synthetic_data = mechanism_result.model.synthetic_data() logging.info('[DPSynth]: Generated discrete synthetic data.') # Phase 4: Decode synthetic data back to original domain. 
@@ -373,4 +380,6 @@ def __call__( logging.info('[DPSynth]: Converted data back to original domain.') column_order = [col for col in data.columns if col in self.domains] - return pd.DataFrame(synthetic_columns)[column_order] + return DataGenerationResult( + synthetic_data=pd.DataFrame(synthetic_columns)[column_order] + ) diff --git a/dpsynth/discrete_mechanisms/__init__.py b/dpsynth/discrete_mechanisms/__init__.py index f29ef74..24e3473 100644 --- a/dpsynth/discrete_mechanisms/__init__.py +++ b/dpsynth/discrete_mechanisms/__init__.py @@ -17,11 +17,17 @@ # pylint: disable=g-importing-member from dpsynth.discrete_mechanisms.aim import AIMMechanism +from dpsynth.discrete_mechanisms.aim import AIMMechanismResult from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanism +from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanismResult from dpsynth.discrete_mechanisms.direct import DirectMechanism +from dpsynth.discrete_mechanisms.direct import DirectMechanismResult from dpsynth.discrete_mechanisms.independent import IndependentMechanism +from dpsynth.discrete_mechanisms.independent import IndependentMechanismResult from dpsynth.discrete_mechanisms.mst import MSTMechanism +from dpsynth.discrete_mechanisms.mst import MSTMechanismResult from dpsynth.discrete_mechanisms.swift import SWIFTMechanism +from dpsynth.discrete_mechanisms.swift import SWIFTMechanismResult from dpsynth.local_mode.primitives import DPMechanism as DiscreteMechanism # Backwards-compatible aliases. diff --git a/dpsynth/discrete_mechanisms/aim.py b/dpsynth/discrete_mechanisms/aim.py index d4ea143..cb9fd58 100644 --- a/dpsynth/discrete_mechanisms/aim.py +++ b/dpsynth/discrete_mechanisms/aim.py @@ -90,6 +90,13 @@ def _worst_approximated( return keys[idx] +@dataclasses.dataclass +class AIMMechanismResult: + """Result of running the AIM mechanism.""" + + model: mbi.MarkovRandomField + + @dataclasses.dataclass class AIMMechanism(primitives.DPMechanism): """Configuration for the AIM mechanism. 
@@ -155,7 +162,7 @@ def __call__( *, initial_measurements: list[mbi.LinearMeasurement] | None = None, initial_potentials: mbi.CliqueVector | None = None, - ) -> mbi.MarkovRandomField: + ) -> AIMMechanismResult: """Runs the AIM mechanism on the given data. Args: @@ -165,7 +172,7 @@ def __call__( initial_potentials: Optional initial potentials (constraints). Returns: - A MarkovRandomField representing the estimated data distribution. + An AIMMechanismResult containing the estimated data distribution. """ if self.zcdp_rho is None: raise ValueError('Must call calibrate() before using the mechanism.') @@ -288,4 +295,4 @@ def __call__( sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round) logging.info('[AIM] Reducing sigma: %.1f', sigma) - return model + return AIMMechanismResult(model=model) diff --git a/dpsynth/discrete_mechanisms/aim_gdp.py b/dpsynth/discrete_mechanisms/aim_gdp.py index 624d5b0..7526c53 100644 --- a/dpsynth/discrete_mechanisms/aim_gdp.py +++ b/dpsynth/discrete_mechanisms/aim_gdp.py @@ -125,6 +125,13 @@ def _worst_approximated( return max(current_scores, key=current_scores.get) +@dataclasses.dataclass +class AIMGDPMechanismResult: + """Result of running the AIM-GDP mechanism.""" + + model: mbi.MarkovRandomField + + @dataclasses.dataclass class AIMGDPMechanism(primitives.DPMechanism): """Configuration for the AIM mechanism with Gaussian DP. @@ -201,7 +208,7 @@ def __call__( *, initial_measurements: list[mbi.LinearMeasurement] | None = None, initial_potentials: mbi.CliqueVector | None = None, - ) -> mbi.MarkovRandomField: + ) -> AIMGDPMechanismResult: """Runs the AIM-GDP mechanism on the given data. Args: @@ -211,7 +218,7 @@ def __call__( initial_potentials: Optional initial potentials (constraints). Returns: - A MarkovRandomField representing the estimated data distribution. + An AIMGDPMechanismResult containing the estimated data distribution. 
""" if self.gdp_sigma is None: raise ValueError('Must call calibrate() before using the mechanism.') @@ -358,4 +365,4 @@ def __call__( '[AIM] Increasing budget per round: %.5f', budget_per_round ) - return model + return AIMGDPMechanismResult(model=model) diff --git a/dpsynth/discrete_mechanisms/direct.py b/dpsynth/discrete_mechanisms/direct.py index 6ba9d56..1e1387e 100644 --- a/dpsynth/discrete_mechanisms/direct.py +++ b/dpsynth/discrete_mechanisms/direct.py @@ -24,6 +24,13 @@ import numpy as np +@dataclasses.dataclass +class DirectMechanismResult: + """Result of running the direct mechanism.""" + + model: mbi.MarkovRandomField + + @dataclasses.dataclass class DirectMechanism(primitives.DPMechanism): """Configuration for the direct mechanism. @@ -63,7 +70,7 @@ def __call__( *, initial_measurements: list[mbi.LinearMeasurement] | None = None, initial_potentials: mbi.CliqueVector | None = None, - ) -> mbi.MarkovRandomField: + ) -> DirectMechanismResult: """Generate synthetic data using user specified two way marginals.""" if self.gdp_sigma is None: raise ValueError('Must call calibrate() before using the mechanism.') @@ -88,4 +95,4 @@ def __call__( potentials=initial_potentials, marginal_oracle=marginal_oracle, ) - return model + return DirectMechanismResult(model=model) diff --git a/dpsynth/discrete_mechanisms/independent.py b/dpsynth/discrete_mechanisms/independent.py index 6485669..e88b32d 100644 --- a/dpsynth/discrete_mechanisms/independent.py +++ b/dpsynth/discrete_mechanisms/independent.py @@ -24,6 +24,13 @@ import numpy as np +@dataclasses.dataclass +class IndependentMechanismResult: + """Result of running the independent mechanism.""" + + model: mbi.MarkovRandomField + + @dataclasses.dataclass class IndependentMechanism(primitives.DPMechanism): """Configuration for the independent mechanism. 
@@ -60,7 +67,7 @@ def __call__( *, initial_measurements: list[mbi.LinearMeasurement] | None = None, initial_potentials: mbi.CliqueVector | None = None, - ) -> mbi.MarkovRandomField: + ) -> IndependentMechanismResult: """Generate synthetic data via the independent mechanism.""" if self.gdp_sigma is None: raise ValueError('Must call calibrate() before using the mechanism.') @@ -91,4 +98,4 @@ def __call__( potentials=potentials, marginal_oracle=marginal_oracle, ) - return model + return IndependentMechanismResult(model=model) diff --git a/dpsynth/discrete_mechanisms/mst.py b/dpsynth/discrete_mechanisms/mst.py index a8a3b3a..6fb62ac 100644 --- a/dpsynth/discrete_mechanisms/mst.py +++ b/dpsynth/discrete_mechanisms/mst.py @@ -156,6 +156,13 @@ def _select_two_way_marginal_queries( ) +@dataclasses.dataclass +class MSTMechanismResult: + """Result of running the MST mechanism.""" + + model: mbi.MarkovRandomField + + @dataclasses.dataclass class MSTMechanism(primitives.DPMechanism): """Configuration for the maximum spanning tree mechanism. @@ -201,7 +208,7 @@ def __call__( *, initial_measurements: list[mbi.LinearMeasurement] | None = None, initial_potentials: mbi.CliqueVector | None = None, - ) -> mbi.MarkovRandomField: + ) -> MSTMechanismResult: """Runs the MST mechanism on the given data. Args: @@ -212,7 +219,7 @@ def __call__( estimation. Returns: - A fitted MarkovRandomField model. + An MSTMechanismResult containing the estimated data distribution. Raises: ValueError: If calibrate() has not been called. 
@@ -271,4 +278,4 @@ def __call__( marginal_oracle=marginal_oracle, ) logging.info('[MST]: Fit distribution to the noisy measurements.') - return model + return MSTMechanismResult(model=model) diff --git a/dpsynth/discrete_mechanisms/swift.py b/dpsynth/discrete_mechanisms/swift.py index 2a698ef..4bce42f 100644 --- a/dpsynth/discrete_mechanisms/swift.py +++ b/dpsynth/discrete_mechanisms/swift.py @@ -44,6 +44,13 @@ import tqdm +@dataclasses.dataclass +class SWIFTMechanismResult: + """Result of running the SWIFT mechanism.""" + + model: mbi.MarkovRandomField + + @dataclasses.dataclass class SWIFTMechanism(primitives.DPMechanism): """Configuration for the SWIFT mechanism. @@ -94,7 +101,7 @@ def __call__( *, initial_measurements: Sequence[mbi.LinearMeasurement] | None = None, initial_potentials: mbi.CliqueVector | None = None, - ) -> mbi.MarkovRandomField: + ) -> SWIFTMechanismResult: """Runs the SWIFT mechanism on the given data. Args: @@ -105,7 +112,7 @@ def __call__( estimation. Returns: - A fitted MarkovRandomField model. + A SWIFTMechanismResult containing the estimated data distribution. Raises: ValueError: If calibrate() has not been called. @@ -197,7 +204,7 @@ def __call__( ) logging.info('[SWIFT] Estimated final model.') - return model + return SWIFTMechanismResult(model=model) def _is_supported(clique: mbi.Clique, tree: nx.Graph) -> bool: diff --git a/dpsynth/local_mode/initialization.py b/dpsynth/local_mode/initialization.py index cf05f3d..9611004 100644 --- a/dpsynth/local_mode/initialization.py +++ b/dpsynth/local_mode/initialization.py @@ -94,7 +94,7 @@ def __call__( ) -> ColumnMeasurement: """Returns a ColumnMeasurement with the discretization transform.""" # Dedup: concentrated data can make quantiles return duplicate edges. 
- edges = _validate_mechanism(self.mechanism)(rng, data) + edges = _validate_mechanism(self.mechanism)(rng, data).quantiles bin_edges = np.unique(np.asarray(edges, dtype=float)) cat_attr = vtx.categorical_attribute_from_edges(bin_edges, self.attribute) return ColumnMeasurement(cat_attr, bin_edges) @@ -136,7 +136,7 @@ def __call__( """Returns a ColumnMeasurement with the noisy histogram.""" mechanism = _validate_mechanism(self.mechanism) encoded = vtx.discrete_encode(data, self.attribute) - noisy_counts = mechanism(rng, encoded) + noisy_counts = mechanism(rng, encoded).counts measurement = mbi.LinearMeasurement( noisy_counts, (self.name,), stddev=mechanism.sigma ) @@ -185,7 +185,8 @@ def __call__( mechanism = _validate_mechanism(self.mechanism) # Map raw values to integer partition IDs for thresholding. unique_values, inverse = np.unique(data, return_inverse=True) - selected_ids, counts, _ = mechanism(rng, inverse) + result = mechanism(rng, inverse) + selected_ids, counts = result.selected_partitions, result.estimated_counts selected_values = list(unique_values[selected_ids]) # Build the discovered domain: default first, then selected values. diff --git a/dpsynth/local_mode/primitives.py b/dpsynth/local_mode/primitives.py index 37b65eb..59ebe36 100644 --- a/dpsynth/local_mode/primitives.py +++ b/dpsynth/local_mode/primitives.py @@ -32,6 +32,28 @@ import scipy.stats +@dataclasses.dataclass +class QuantileResult: + """Result of a differentially private quantile computation.""" + + quantiles: list[float] + + +@dataclasses.dataclass +class HistogramResult: + """Result of a differentially private histogram computation.""" + + counts: np.ndarray + + +@dataclasses.dataclass +class PartitionSelectionResult: + """Result of differentially private partition selection.""" + + selected_partitions: np.ndarray + estimated_counts: np.ndarray + + class DPMechanism(abc.ABC): """Abstract base class for differentially private mechanisms. 
@@ -522,12 +544,16 @@ def dp_event(self) -> dp_accounting.DpEvent: for eps in self._epsilon_levels ]) - def __call__(self, rng: np.random.Generator, data: np.ndarray) -> list[float]: + def __call__( + self, rng: np.random.Generator, data: np.ndarray + ) -> QuantileResult: """Computes differentially private quantiles.""" if self._epsilon_levels is None: raise ValueError(_UNCALIBRATED_MSG.format(param='_epsilon_levels')) - return _quantiles( - rng, data, self.lower, self.upper, np.asarray(self._epsilon_levels) + return QuantileResult( + quantiles=_quantiles( + rng, data, self.lower, self.upper, np.asarray(self._epsilon_levels) + ) ) @@ -557,11 +583,15 @@ def dp_event(self) -> dp_accounting.DpEvent: raise ValueError(_UNCALIBRATED_MSG.format(param='sigma')) return dp_accounting.GaussianDpEvent(noise_multiplier=self.sigma) - def __call__(self, rng: np.random.Generator, data: np.ndarray) -> np.ndarray: + def __call__( + self, rng: np.random.Generator, data: np.ndarray + ) -> HistogramResult: """Computes a differentially private histogram.""" if self.sigma is None: raise ValueError(_UNCALIBRATED_MSG.format(param='sigma')) - return _gaussian_histogram(rng, data, self.domain_size, self.sigma) + return HistogramResult( + counts=_gaussian_histogram(rng, data, self.domain_size, self.sigma) + ) @dataclasses.dataclass @@ -595,7 +625,7 @@ def dp_event(self) -> dp_accounting.DpEvent: def __call__( self, rng: np.random.Generator, data: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, float]: + ) -> PartitionSelectionResult: """Runs partition selection on integer-encoded partition IDs. Args: @@ -603,11 +633,14 @@ def __call__( data: 1D array of integer partition IDs. Returns: - A tuple of (selected_partitions, noisy_counts, sigma). + A ``PartitionSelectionResult`` with selected partitions and noisy counts. 
""" if self.sigma is None: raise ValueError(_UNCALIBRATED_MSG.format(param='sigma')) gdp_budget = np.inf if self.sigma == 0.0 else 1.0 / (self.sigma**2) - return select_partitions_gaussian_thresholding( + parts, counts, _ = select_partitions_gaussian_thresholding( rng, data, gdp_budget, self.delta ) + return PartitionSelectionResult( + selected_partitions=parts, estimated_counts=counts + ) diff --git a/tests/data_generation_v3_test.py b/tests/data_generation_v3_test.py index ba979e7..85bad5f 100644 --- a/tests/data_generation_v3_test.py +++ b/tests/data_generation_v3_test.py @@ -38,7 +38,8 @@ def test_end_to_end_categorical(self): df = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': ['x', 'y', 'z']}) rng = np.random.default_rng(0) calibrated = DataGenerationV3(domains=domains).calibrate(zcdp_rho=100.0) - synthetic_df = calibrated(rng, df) + result = calibrated(rng, df) + synthetic_df = result.synthetic_data self.assertIsInstance(synthetic_df, pd.DataFrame) self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) @@ -50,7 +51,8 @@ def test_end_to_end_numerical(self): df = pd.DataFrame({'A': [5, 5, 0], 'B': [5, -10, -5]}, dtype=float) rng = np.random.default_rng(0) calibrated = DataGenerationV3(domains=domains).calibrate(zcdp_rho=100.0) - synthetic_df = calibrated(rng, df) + result = calibrated(rng, df) + synthetic_df = result.synthetic_data self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) for col, attr in domains.items(): self.assertTrue( @@ -67,7 +69,8 @@ def test_end_to_end_mixed_domain(self): calibrated = DataGenerationV3(domains=domains).calibrate( zcdp_rho=100.0, delta=1e-5 ) - synthetic_df = calibrated(rng, df) + result = calibrated(rng, df) + synthetic_df = result.synthetic_data self.assertIsInstance(synthetic_df, pd.DataFrame) self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) @@ -85,7 +88,8 @@ def test_end_to_end_with_epsilon_delta(self): calibrated = DataGenerationV3(domains=domains).calibrate( epsilon=100, delta=0.1 ) - 
synthetic_df = calibrated(rng, df) + result = calibrated(rng, df) + synthetic_df = result.synthetic_data self.assertIsInstance(synthetic_df, pd.DataFrame) self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) @@ -143,7 +147,8 @@ def test_calibrate_small_epsilon(self): calibrated = DataGenerationV3(domains=domains).calibrate( epsilon=0.2, delta=1e-5 ) - synthetic_df = calibrated(rng, df) + result = calibrated(rng, df) + synthetic_df = result.synthetic_data self.assertIsInstance(synthetic_df, pd.DataFrame) self.assertListEqual(synthetic_df.columns.tolist(), ['A', 'B']) diff --git a/tests/discrete_mechanisms/aim_test.py b/tests/discrete_mechanisms/aim_test.py index 4924a9a..8ae8b82 100644 --- a/tests/discrete_mechanisms/aim_test.py +++ b/tests/discrete_mechanisms/aim_test.py @@ -27,7 +27,7 @@ def test_fits_one_way_marginals_with_aim(self): config = aim.AIMMechanism(workload=workload, max_rounds=4, pgm_iters=500) calibrated = config.calibrate(zcdp_rho=10000) - synthetic = calibrated(np.random.default_rng(0), data) + synthetic = calibrated(np.random.default_rng(0), data).model for col in data.domain: expected = data.project([col]).datavector() @@ -42,7 +42,7 @@ def test_fits_one_way_marginals_with_aim_gdp(self): workload=workload, max_rounds=4, pgm_iters=500 ) calibrated = config.calibrate(zcdp_rho=10000) - synthetic = calibrated(np.random.default_rng(0), data) + synthetic = calibrated(np.random.default_rng(0), data).model for col in data.domain: expected = data.project([col]).datavector() diff --git a/tests/discrete_mechanisms/direct_test.py b/tests/discrete_mechanisms/direct_test.py index c3a65e6..bab1c62 100644 --- a/tests/discrete_mechanisms/direct_test.py +++ b/tests/discrete_mechanisms/direct_test.py @@ -31,7 +31,9 @@ def test_fits_one_way_marginals(self): ], pgm_iters=500, ) - synthetic = config.calibrate(zcdp_rho=10000)(np.random.default_rng(0), data) + synthetic = config.calibrate(zcdp_rho=10000)( + np.random.default_rng(0), data + ).model for col in 
data.domain: expected = data.project([col]).datavector() diff --git a/tests/discrete_mechanisms/independent_test.py b/tests/discrete_mechanisms/independent_test.py index e3668b7..6c4cfb3 100644 --- a/tests/discrete_mechanisms/independent_test.py +++ b/tests/discrete_mechanisms/independent_test.py @@ -24,7 +24,9 @@ def test_fits_one_way_marginals(self): data = mbi.Dataset.synthetic(mbi.Domain(["a", "b", "c"], [3, 4, 5]), N=1000) config = independent.IndependentMechanism(pgm_iters=500) - synthetic = config.calibrate(zcdp_rho=10000)(np.random.default_rng(0), data) + synthetic = config.calibrate(zcdp_rho=10000)( + np.random.default_rng(0), data + ).model for col in data.domain: expected = data.project([col]).datavector() diff --git a/tests/discrete_mechanisms/mst_test.py b/tests/discrete_mechanisms/mst_test.py index 7f607f2..3bcf6ed 100644 --- a/tests/discrete_mechanisms/mst_test.py +++ b/tests/discrete_mechanisms/mst_test.py @@ -74,7 +74,7 @@ def test_fits_one_way_marginals(self): config = mst.MSTMechanism(pgm_iters=500).calibrate(zcdp_rho=10000) - synthetic = config(np.random.default_rng(0), data) + synthetic = config(np.random.default_rng(0), data).model for col in data.domain: expected = data.project([col]).datavector() diff --git a/tests/discrete_mechanisms/swift_test.py b/tests/discrete_mechanisms/swift_test.py index 2dd67f9..a6911d2 100644 --- a/tests/discrete_mechanisms/swift_test.py +++ b/tests/discrete_mechanisms/swift_test.py @@ -124,7 +124,7 @@ def test_fits_one_way_marginals(self): config = swift.SWIFTMechanism(pgm_iters=500).calibrate(zcdp_rho=10000) - synthetic = config(np.random.default_rng(0), data) + synthetic = config(np.random.default_rng(0), data).model for col in data.domain: expected = data.project([col]).datavector() diff --git a/tests/local_mode/primitives_test.py b/tests/local_mode/primitives_test.py index 5257c68..d66cbef 100644 --- a/tests/local_mode/primitives_test.py +++ b/tests/local_mode/primitives_test.py @@ -313,7 +313,7 @@ def 
test_calibrate_and_call(self): calibrated = mech.calibrate(zcdp_rho=100.0) data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]) result = calibrated(self.rng, data) - self.assertLen(result, 3) + self.assertLen(result.quantiles, 3) @parameterized.parameters([0.3, 1.0, 2.718]) def test_calibrate_default_ratio(self, zcdp_rho): @@ -366,13 +366,13 @@ def test_calibrate_and_call(self): calibrated = mech.calibrate(zcdp_rho=0.5) data = np.array([0, 0, 1, 1, 1, 2]) result = calibrated(self.rng, data) - self.assertLen(result, 4) - np.testing.assert_allclose(result, [2, 3, 1, 0], atol=5.0) + self.assertLen(result.counts, 4) + np.testing.assert_allclose(result.counts, [2, 3, 1, 0], atol=5.0) def test_direct_sigma(self): mech = primitives.DPGaussianHistogram(domain_size=3, sigma=0.0) data = np.array([0, 0, 1, 2, 2, 2]) - np.testing.assert_array_equal(mech(self.rng, data), [2, 1, 3]) + np.testing.assert_array_equal(mech(self.rng, data).counts, [2, 1, 3]) def test_dp_event_raises_before_calibration(self): mech = primitives.DPGaussianHistogram(domain_size=4)