From 68e1e2ff1473b3e2eb3fb97f204ed7a21ca2579d Mon Sep 17 00:00:00 2001 From: tianhao Date: Wed, 17 Jun 2026 23:52:36 +0800 Subject: [PATCH] fix initial measurement sigma --- dpsynth/discrete_mechanisms/aim.py | 5 +++-- dpsynth/discrete_mechanisms/mst.py | 5 +++-- tests/discrete_mechanisms/aim_test.py | 20 ++++++++++++++++++++ tests/discrete_mechanisms/mst_test.py | 20 ++++++++++++++++++++ 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/dpsynth/discrete_mechanisms/aim.py b/dpsynth/discrete_mechanisms/aim.py index d4ea143..71a11a8 100644 --- a/dpsynth/discrete_mechanisms/aim.py +++ b/dpsynth/discrete_mechanisms/aim.py @@ -190,13 +190,14 @@ def __call__( rho_per_round = zcdp_rho / max_rounds if initial_measurements is None: - rho_remaining -= self.one_way_budget_fraction * zcdp_rho + one_way_zcdp_rho = self.one_way_budget_fraction * zcdp_rho + rho_remaining -= one_way_zcdp_rho marginal_queries = [cl for cl in candidates.keys() if len(cl) == 1] measurements = common.measure_marginals_with_noise( rng, data, marginal_queries=marginal_queries, - gdp_sigma=zcdp_rho * self.one_way_budget_fraction, + gdp_sigma=accounting.zcdp_gaussian_sigma(one_way_zcdp_rho), ) else: measurements = list(initial_measurements) diff --git a/dpsynth/discrete_mechanisms/mst.py b/dpsynth/discrete_mechanisms/mst.py index a8a3b3a..77be029 100644 --- a/dpsynth/discrete_mechanisms/mst.py +++ b/dpsynth/discrete_mechanisms/mst.py @@ -225,12 +225,13 @@ def __call__( budget_remaining = self.zcdp_rho if initial_measurements is None: - budget_remaining -= self.one_way_budget_fraction * self.zcdp_rho + one_way_zcdp_rho = self.one_way_budget_fraction * self.zcdp_rho + budget_remaining -= one_way_zcdp_rho one_way_measurements = common.measure_marginals_with_noise( rng, data, marginal_queries=[(a,) for a in data.domain], - gdp_sigma=self.zcdp_rho * self.one_way_budget_fraction, + gdp_sigma=accounting.zcdp_gaussian_sigma(one_way_zcdp_rho), ) else: one_way_measurements = initial_measurements diff --git a/tests/discrete_mechanisms/aim_test.py b/tests/discrete_mechanisms/aim_test.py index 4924a9a..7040398 100644 --- a/tests/discrete_mechanisms/aim_test.py +++ b/tests/discrete_mechanisms/aim_test.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + from absl.testing import absltest +from dpsynth.discrete_mechanisms import accounting from dpsynth.discrete_mechanisms import aim from dpsynth.discrete_mechanisms import aim_gdp import mbi @@ -65,6 +68,23 @@ def test_uncalibrated_aim_gdp_raises(self): with self.assertRaisesRegex(ValueError, "calibrate"): config(np.random.default_rng(0), data) + def test_aim_initial_measurement_uses_gaussian_sigma(self): + data = mbi.Dataset.synthetic(mbi.Domain(["a"], [3]), N=10) + config = aim.AIMMechanism( + workload=[("a",)], max_rounds=1, one_way_budget_fraction=0.25 + ).calibrate(zcdp_rho=2.0) + expected_sigma = accounting.zcdp_gaussian_sigma(0.5) + + def assert_sigma(*_, gdp_sigma, **__): + self.assertAlmostEqual(gdp_sigma, expected_sigma) + raise RuntimeError("stop after checking sigma") + + with mock.patch.object( + aim.common, "measure_marginals_with_noise", side_effect=assert_sigma + ): + with self.assertRaisesRegex(RuntimeError, "checking sigma"): + config(np.random.default_rng(0), data) + if __name__ == "__main__": absltest.main() diff --git a/tests/discrete_mechanisms/mst_test.py b/tests/discrete_mechanisms/mst_test.py index 7f607f2..6bdfd07 100644 --- a/tests/discrete_mechanisms/mst_test.py +++ b/tests/discrete_mechanisms/mst_test.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + from absl.testing import absltest import dp_accounting +from dpsynth.discrete_mechanisms import accounting from dpsynth.discrete_mechanisms import mst import mbi import numpy as np @@ -97,6 +100,23 @@ def test_dp_event_returns_zcdp(self): event = config.dp_event self.assertIsInstance(event, dp_accounting.ZCDpEvent) + def test_initial_measurement_uses_gaussian_sigma(self): + data = mbi.Dataset.synthetic(mbi.Domain(['a'], [3]), N=10) + config = mst.MSTMechanism(one_way_budget_fraction=0.25).calibrate( + zcdp_rho=2.0 + ) + expected_sigma = accounting.zcdp_gaussian_sigma(0.5) + + def assert_sigma(*_, gdp_sigma, **__): + self.assertAlmostEqual(gdp_sigma, expected_sigma) + raise RuntimeError('stop after checking sigma') + + with mock.patch.object( + mst.common, 'measure_marginals_with_noise', side_effect=assert_sigma + ): + with self.assertRaisesRegex(RuntimeError, 'checking sigma'): + config(np.random.default_rng(0), data) + if __name__ == '__main__': absltest.main()