Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions dpsynth/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from typing import Any, Literal, TypeAlias

import attr
import numpy as np
import pandas as pd
import yaml

Expand Down Expand Up @@ -141,6 +142,13 @@ class NumericalAttribute:
max_value] if they are a numeric type, and to min_value otherwise. If
False, out-of-domain values will be grouped together and treated as a
single special out-of-domain value.
sentinel: The value to assign to out-of-domain entries during reverse
discretization when ``clip_to_range`` is False. Defaults to ``None``,
which resolves to ``np.nan`` for numeric modes or ``''`` for
``interval_handling='interval'``. Set to an integer (e.g. ``-1``) to keep
the output array as an integer dtype instead of silently promoting to
float. When explicitly set, must be a string for interval mode and numeric
for other modes. Use ``resolved_sentinel`` to get the effective value.
dtype: The dtype of the data (either 'int' or 'float').
interval_handling: Controls how discretized intervals are converted back to
numerical values. 'midpoint' returns the interval midpoint (or the finite
Expand All @@ -153,6 +161,7 @@ class NumericalAttribute:
min_value: float = attr.field(converter=float)
max_value: float = attr.field(converter=float)
clip_to_range: bool = attr.field(default=True)
sentinel: float | int | str | None = attr.field(default=None)
dtype: str = attr.field(default='float')
interval_handling: str = attr.field(default='midpoint')
description: str | None = attr.field(default=None)
Expand All @@ -174,11 +183,34 @@ def _validate_dtype(self, *_):

@interval_handling.validator # pytype: disable=attribute-error
def _validate_interval_handling(self, *_):
"""Validates interval_handling mode and sentinel type compatibility."""
if self.interval_handling not in ['midpoint', 'sample', 'interval']:
raise ValueError(
'interval_handling must be "midpoint", "sample", or "interval",'
f' got {self.interval_handling}.'
)
if self.sentinel is not None:
if self.interval_handling == 'interval':
if not isinstance(self.sentinel, str):
raise ValueError(
"interval_handling='interval' requires a string sentinel, got"
f' sentinel={self.sentinel!r}.'
)
elif not isinstance(self.sentinel, (int, float, np.integer, np.floating)):
raise ValueError(
'sentinel must be numeric when'
f' interval_handling={self.interval_handling!r}, got'
f' sentinel={self.sentinel!r}.'
)

@property
def resolved_sentinel(self) -> float | int | str:
"""Returns the effective sentinel, with mode-appropriate defaults."""
if self.sentinel is not None:
return self.sentinel
if self.interval_handling == 'interval':
return ''
return np.nan

@property
def exclusive_min_value(self) -> float:
Expand Down
10 changes: 4 additions & 6 deletions dpsynth/local_mode/vectorized_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,24 @@ def undiscretize(
A 1-D array. For ``'midpoint'`` and ``'sample'`` the dtype is float
(or int when ``dtype == 'int'`` and all values are in-domain). For
``'interval'`` the dtype is ``object`` (strings). Out-of-domain bins
(index 0 when ``clip_to_range`` is ``False``) map to ``NaN`` or ``""``.
(index 0 when ``clip_to_range`` is ``False``) map to
``attribute_domain.resolved_sentinel``.
"""
rng = np.random.default_rng(rng)
min_, max_ = attribute_domain.exclusive_min_value, attribute_domain.max_value
_validate_bin_edges(bin_edges, attribute_domain)
full_edges = np.r_[min_, bin_edges, max_]
lefts, rights = full_edges[:-1], full_edges[1:]
handling = attribute_domain.interval_handling
sentinel = attribute_domain.resolved_sentinel

if handling == 'interval':
values = np.array([f'({l}, {r}]' for l, r in zip(lefts, rights)], dtype=str)
if not attribute_domain.clip_to_range:
sentinel = np.array('', dtype=str)
values = np.r_[sentinel, values]
values = np.r_[np.array(sentinel, dtype=str), values]
return values[bin_indices]
elif handling == 'sample':
if not attribute_domain.clip_to_range:
sentinel = np.nan
ood = bin_indices == 0
idx = bin_indices - 1
result = np.where(ood, sentinel, rng.uniform(lefts[idx], rights[idx]))
Expand All @@ -188,14 +188,12 @@ def undiscretize(
elif handling == 'midpoint':
midpoints = (lefts + rights) / 2.0
if not attribute_domain.clip_to_range:
sentinel = np.nan
midpoints = np.r_[sentinel, midpoints]
result = midpoints[bin_indices]
else:
raise ValueError(f'Unsupported interval_handling: {handling}')

if attribute_domain.dtype == 'int' and attribute_domain.clip_to_range:
# If clip_to_range=False, then NaNs are possible so we don't cast to int.
result = np.ceil(result).astype(int)
return result

Expand Down
8 changes: 7 additions & 1 deletion dpsynth/pydantic_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,10 @@ def dp_synthetic_data_generation(
discrete_config=mechanism_config,
)

return [cls(**user) for _, user in synthetic.iterrows()]
return [
cls(**{
k: None if isinstance(v, float) and math.isnan(v) else v
for k, v in user.items()
})
for _, user in synthetic.iterrows()
]
3 changes: 2 additions & 1 deletion dpsynth/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def create_discretize_transformation(
]
intervals = pd.IntervalIndex.from_breaks(bin_edges)
maybe_none = [] if attribute_domain.clip_to_range else [None]
sentinel = attribute_domain.resolved_sentinel
possible_values = maybe_none + list(intervals)

def transform(value: Any) -> pd.Interval | None:
Expand All @@ -190,7 +191,7 @@ def _resolve_finite(interval: pd.Interval) -> float:

def reverse(value: pd.Interval | None) -> float | pd.Interval | None:
if value is None:
return None
return sentinel
if attribute_domain.interval_handling == 'interval':
return value
if attribute_domain.interval_handling == 'sample':
Expand Down
50 changes: 50 additions & 0 deletions tests/domain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from absl.testing import absltest
from dpsynth import domain
import numpy as np


class TestDomain(absltest.TestCase):
Expand Down Expand Up @@ -100,6 +101,55 @@ def test_standardize_numerical(self):
for value in ood_values:
self.assertIsNone(attribute.standardize(value))

def test_numerical_attribute_default_sentinel(self):
attribute = domain.NumericalAttribute(0, 10)
self.assertIsNone(attribute.sentinel)
self.assertTrue(np.isnan(attribute.resolved_sentinel))

def test_numerical_attribute_custom_sentinel(self):
attribute = domain.NumericalAttribute(0, 10, sentinel=-1)
self.assertEqual(attribute.sentinel, -1)
self.assertEqual(attribute.resolved_sentinel, -1)

def test_sentinel_yaml_roundtrip(self):
original = {
'num': domain.NumericalAttribute(
min_value=0, max_value=10, clip_to_range=False, sentinel=-1
),
}
temp_file = self.create_tempfile('temp.yaml', mode='w+')
domain.to_yaml_file(original, temp_file.full_path)
loaded = domain.from_yaml_file(temp_file.full_path)
self.assertEqual(loaded['num'].sentinel, -1)

def test_string_sentinel_allowed_with_interval_handling(self):
attr = domain.NumericalAttribute(
0, 10, sentinel='MISSING', interval_handling='interval'
)
self.assertEqual(attr.sentinel, 'MISSING')
self.assertEqual(attr.resolved_sentinel, 'MISSING')

def test_interval_handling_default_sentinel(self):
attr = domain.NumericalAttribute(0, 10, interval_handling='interval')
self.assertIsNone(attr.sentinel)
self.assertEqual(attr.resolved_sentinel, '')

def test_string_sentinel_rejected_with_midpoint_handling(self):
with self.assertRaises(ValueError):
domain.NumericalAttribute(0, 10, sentinel='MISSING')

def test_numeric_sentinel_rejected_with_interval_handling(self):
with self.assertRaises(ValueError):
domain.NumericalAttribute(
0, 10, sentinel=42, interval_handling='interval'
)

def test_numpy_numeric_sentinel_accepted(self):
attr = domain.NumericalAttribute(0, 10, sentinel=np.int32(-1))
self.assertEqual(attr.sentinel, -1)
attr = domain.NumericalAttribute(0, 10, sentinel=np.float32(0.5))
self.assertAlmostEqual(attr.sentinel, 0.5, places=5)

def test_freeform_text_defaults(self):
attribute = domain.FreeFormTextAttribute()
self.assertEqual(attribute.max_tokens, 256)
Expand Down
30 changes: 28 additions & 2 deletions tests/local_mode/vectorized_transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,18 @@ def test_interval_mode(self):
self.assertIn('(', result[0])
self.assertIn(']', result[0])

def test_interval_mode_ood_empty_string(self):
def test_string_sentinel_interval(self):
attr = domain.NumericalAttribute(
min_value=0,
max_value=10,
clip_to_range=False,
sentinel='MISSING',
interval_handling='interval',
)
result = vectorized_transformations.undiscretize(
np.array([0, 1]), np.array([5.0]), attr
)
self.assertEqual(result[0], '')
self.assertEqual(result[0], 'MISSING')
self.assertIn('(', result[1])

def test_invalid_bin_edges_raises(self):
Expand All @@ -266,6 +267,31 @@ def test_invalid_bin_edges_raises(self):
np.array([1]), np.array([5.0, 3.0]), attr
)

def test_custom_sentinel_midpoint(self):
attr = domain.NumericalAttribute(
min_value=0, max_value=10, clip_to_range=False, sentinel=-1
)
result = vectorized_transformations.undiscretize(
np.array([0, 1, 2]), np.array([5.0]), attr
)
self.assertEqual(result[0], -1)
self.assertBetween(result[1], 0, 5)
self.assertBetween(result[2], 5, 10)

def test_custom_sentinel_sample(self):
rng = np.random.default_rng(0)
attr = domain.NumericalAttribute(
min_value=0,
max_value=10,
clip_to_range=False,
sentinel=-1,
interval_handling='sample',
)
result = vectorized_transformations.undiscretize(
np.array([0, 1]), np.array([5.0]), attr, rng=rng
)
self.assertEqual(result[0], -1)


class MergeRareValuesTest(absltest.TestCase):

Expand Down
24 changes: 21 additions & 3 deletions tests/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,25 @@ def test_valid_discretization_no_clip_to_range_inverse(self):

self.assertBetween(transform_fn.inverse(interval1), 0, 5)
self.assertBetween(transform_fn.inverse(interval2), 5, 10)
self.assertIsNone(transform_fn.inverse(None))
self.assertTrue(np.isnan(transform_fn.inverse(None)))

def test_discretize_inverse_sentinel_default_nan(self):
attr = domain.NumericalAttribute(
min_value=0, max_value=10, clip_to_range=False
)
_, transform_fn = transformations.create_discretize_transformation(
attr, [5]
)
self.assertTrue(np.isnan(transform_fn.inverse(None)))

def test_discretize_inverse_custom_sentinel(self):
attr = domain.NumericalAttribute(
min_value=0, max_value=10, clip_to_range=False, sentinel=-1
)
_, transform_fn = transformations.create_discretize_transformation(
attr, [5]
)
self.assertEqual(transform_fn.inverse(None), -1)

def test_valid_discretization_for_int_attribute(self):
attr = domain.NumericalAttribute(min_value=0, max_value=10, dtype='int')
Expand Down Expand Up @@ -180,7 +198,7 @@ def test_discretize_interval_handling_sample(self):
values.add(value)
# Sample mode should produce non-constant output (unlike midpoint).
self.assertGreater(len(values), 1)
self.assertIsNone(transform_fn.inverse(None))
self.assertTrue(np.isnan(transform_fn.inverse(None)))

def test_discretize_interval_handling_interval(self):
attr = domain.NumericalAttribute(
Expand All @@ -191,7 +209,7 @@ def test_discretize_interval_handling_interval(self):
)
interval = pd.Interval(5, 10)
self.assertEqual(transform_fn.inverse(interval), interval)
self.assertIsNone(transform_fn.inverse(None))
self.assertEqual(transform_fn.inverse(None), '')

def test_discretize_reverse_semi_infinite_intervals(self):
# Midpoint mode: semi-infinite intervals should return the finite endpoint.
Expand Down
Loading