Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dpsynth/discrete_mechanisms/aim.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,14 @@ def __call__(
rho_per_round = zcdp_rho / max_rounds

if initial_measurements is None:
rho_remaining -= self.one_way_budget_fraction * zcdp_rho
one_way_zcdp_rho = self.one_way_budget_fraction * zcdp_rho
rho_remaining -= one_way_zcdp_rho
marginal_queries = [cl for cl in candidates.keys() if len(cl) == 1]
measurements = common.measure_marginals_with_noise(
rng,
data,
marginal_queries=marginal_queries,
gdp_sigma=zcdp_rho * self.one_way_budget_fraction,
gdp_sigma=accounting.zcdp_gaussian_sigma(one_way_zcdp_rho),
)
else:
measurements = list(initial_measurements)
Expand Down
5 changes: 3 additions & 2 deletions dpsynth/discrete_mechanisms/mst.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,13 @@ def __call__(
budget_remaining = self.zcdp_rho

if initial_measurements is None:
budget_remaining -= self.one_way_budget_fraction * self.zcdp_rho
one_way_zcdp_rho = self.one_way_budget_fraction * self.zcdp_rho
budget_remaining -= one_way_zcdp_rho
one_way_measurements = common.measure_marginals_with_noise(
rng,
data,
marginal_queries=[(a,) for a in data.domain],
gdp_sigma=self.zcdp_rho * self.one_way_budget_fraction,
gdp_sigma=accounting.zcdp_gaussian_sigma(one_way_zcdp_rho),
)
else:
one_way_measurements = initial_measurements
Expand Down
20 changes: 20 additions & 0 deletions tests/discrete_mechanisms/aim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from absl.testing import absltest
from dpsynth.discrete_mechanisms import accounting
from dpsynth.discrete_mechanisms import aim
from dpsynth.discrete_mechanisms import aim_gdp
import mbi
Expand Down Expand Up @@ -65,6 +68,23 @@ def test_uncalibrated_aim_gdp_raises(self):
with self.assertRaisesRegex(ValueError, "calibrate"):
config(np.random.default_rng(0), data)

def test_aim_initial_measurement_uses_gaussian_sigma(self):
data = mbi.Dataset.synthetic(mbi.Domain(["a"], [3]), N=10)
config = aim.AIMMechanism(
workload=[("a",)], max_rounds=1, one_way_budget_fraction=0.25
).calibrate(zcdp_rho=2.0)
expected_sigma = accounting.zcdp_gaussian_sigma(0.5)

def assert_sigma(*_, gdp_sigma, **__):
self.assertAlmostEqual(gdp_sigma, expected_sigma)
raise RuntimeError("stop after checking sigma")

with mock.patch.object(
aim.common, "measure_marginals_with_noise", side_effect=assert_sigma
):
with self.assertRaisesRegex(RuntimeError, "checking sigma"):
config(np.random.default_rng(0), data)


if __name__ == "__main__":
absltest.main()
20 changes: 20 additions & 0 deletions tests/discrete_mechanisms/mst_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from absl.testing import absltest
import dp_accounting
from dpsynth.discrete_mechanisms import accounting
from dpsynth.discrete_mechanisms import mst
import mbi
import numpy as np
Expand Down Expand Up @@ -97,6 +100,23 @@ def test_dp_event_returns_zcdp(self):
event = config.dp_event
self.assertIsInstance(event, dp_accounting.ZCDpEvent)

def test_initial_measurement_uses_gaussian_sigma(self):
data = mbi.Dataset.synthetic(mbi.Domain(['a'], [3]), N=10)
config = mst.MSTMechanism(one_way_budget_fraction=0.25).calibrate(
zcdp_rho=2.0
)
expected_sigma = accounting.zcdp_gaussian_sigma(0.5)

def assert_sigma(*_, gdp_sigma, **__):
self.assertAlmostEqual(gdp_sigma, expected_sigma)
raise RuntimeError('stop after checking sigma')

with mock.patch.object(
mst.common, 'measure_marginals_with_noise', side_effect=assert_sigma
):
with self.assertRaisesRegex(RuntimeError, 'checking sigma'):
config(np.random.default_rng(0), data)


if __name__ == '__main__':
absltest.main()