diff --git a/TODO.md b/TODO.md index 13a1840..5d20c10 100644 --- a/TODO.md +++ b/TODO.md @@ -5,4 +5,4 @@ - [x] DatasetSchema: Allow regex-based variable or coordinate name matching - [ ] AttrSchema: Support string input in the type field when deserializing - [ ] AttrSchema: Add regex-based string validation for attributes -- [ ] AttrSchema: Add pint-based unit validation system +- [x] AttrSchema: Add pint-based unit validation system diff --git a/pyproject.toml b/pyproject.toml index 0cdbf95..e8b25f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ Repository = "https://github.com/leroyvn/xarray-validate/" [project.optional-dependencies] dask = ["dask"] yaml = ["ruamel-yaml"] +units = ["pint"] [dependency-groups] lint = ["ruff>=0.14.0"] diff --git a/src/xarray_validate/components.py b/src/xarray_validate/components.py index e530de9..2a805af 100644 --- a/src/xarray_validate/components.py +++ b/src/xarray_validate/components.py @@ -578,6 +578,19 @@ class AttrSchema(BaseSchema): value : Any Attribute value definition. ``None`` may be used as a wildcard. + + units : str, optional + Exact unit validation (tolerates different spellings/abbreviations). + Uses pint to validate that the attribute value represents the same unit. + For example, ``units="metre"`` accepts "metre", "m", or "meter". + Requires pint to be installed. + + units_compatible : str, optional + Compatible units validation (allows unit conversions). + Uses pint to validate that the attribute value is compatible with the + specified unit. For example, ``units_compatible="metre"`` accepts + "meter", "kilometre", "millimetre", etc. + Requires pint to be installed. """ type: Optional[Type] = _attrs.field( @@ -585,10 +598,17 @@ class AttrSchema(BaseSchema): validator=_attrs.validators.optional(_attrs.validators.instance_of(type)), ) value: Optional[Any] = _attrs.field(default=None) + units: Optional[str] = _attrs.field(default=None) + units_compatible: Optional[str] = _attrs.field(default=None) def serialize(self) -> dict: # Inherit docstring - return {"type": self.type, "value": self.value} + return { + "type": self.type, + "value": self.value, + "units": self.units, + "units_compatible": self.units_compatible, + } @classmethod def deserialize(cls, obj): @@ -656,6 +676,105 @@ def validate(self, attr: Any, context: ValidationContext | None = None): else: raise error + # Unit validation + if self.units is not None or self.units_compatible is not None: + # Ensure attr is a string + if not isinstance(attr, str): + error = SchemaError( + "Unit validation requires attribute to be a string, got " + f"{type(attr).__name__}" + ) + if context: + context.handle_error(error) + else: + raise error + return + + # Try to import pint + try: + import pint + except ImportError as e: + error = SchemaError( + "Unit validation requires the pint library. " + "Install with: pip install pint" + ) + if context: + context.handle_error(error) + else: + raise error from e + return + + # Get the application registry + ureg = pint.get_application_registry() + + # Parse the attribute value as a unit + try: + attr_unit = ureg.Unit(attr) + except ( + pint.UndefinedUnitError, + pint.errors.DefinitionSyntaxError, + ) as e: + error = SchemaError(f"Invalid unit '{attr}': {e}") + if context: + context.handle_error(error) + else: + raise error from e + return + + # Validate exact unit match + if self.units is not None: + try: + expected_unit = ureg.Unit(self.units) + except ( + pint.UndefinedUnitError, + pint.errors.DefinitionSyntaxError, + ) as e: + error = SchemaError(f"Invalid expected unit '{self.units}': {e}") + if context: + context.handle_error(error) + else: + raise error from e + return + + if attr_unit != expected_unit: + error = SchemaError( + f"Unit mismatch: expected '{self.units}' " + f"(or equivalent like '{expected_unit:~}'), got '{attr}'" + ) + if context: + context.handle_error(error) + else: + raise error + + # Validate compatible units + if self.units_compatible is not None: + try: + expected_unit = ureg.Unit(self.units_compatible) + except ( + pint.UndefinedUnitError, + pint.errors.DefinitionSyntaxError, + ) as e: + error = SchemaError( + f"Invalid expected unit '{self.units_compatible}': {e}" + ) + if context: + context.handle_error(error) + else: + raise error from e + return + + if not attr_unit.is_compatible_with(expected_unit): + error = SchemaError( + f"Unit '{attr}' is not compatible with " + f"'{self.units_compatible}'. " + f"Expected dimensionality: {expected_unit.dimensionality}, " + f"got: {attr_unit.dimensionality}" + ) + if context: + context.handle_error(error) + else: + raise error + @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) class AttrsSchema(BaseSchema): diff --git a/tests/test_components.py b/tests/test_components.py index a9916a4..a791d6e 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -24,9 +24,21 @@ class TestAttrSchema: @pytest.mark.parametrize( "kwargs, validate, json", [ - ({"type": str, "value": None}, "foo", {"type": str, "value": None}), - ({"type": None, "value": "foo"}, "foo", {"type": None, "value": "foo"}), - ({"type": str, "value": "foo"}, "foo", {"type": str, "value": "foo"}), + ( + {"type": str, "value": None}, + "foo", + {"type": str, "value": None, "units": None, "units_compatible": None}, + ), + ( + {"type": None, "value": "foo"}, + "foo", + {"type": None, "value": "foo", "units": None, "units_compatible": None}, + ), + ( + {"type": str, "value": "foo"}, + "foo", + {"type": str, "value": "foo", "units": None, "units_compatible": None}, + ), ], ) def test_attr_schema_basic(self, kwargs, validate, json): @@ -43,7 +55,7 @@ def test_exact_value_match(self): # Should not match different value with pytest.raises(SchemaError, match="name .* != .*"): - schema.validate("kilometers") + schema.validate("kilometres") def test_glob_pattern_value_matching(self): """Test that glob patterns match attribute values.""" @@ -84,21 +96,21 @@ def test_pattern_value_in_attrs_schema(self): schema = AttrsSchema.deserialize( { "Conventions": "CF-*", # Glob pattern - "units": "{(meters|kilometers)}", # Regex pattern + "units": "{(metres|kilometres)}", # Regex pattern "comment": None, # Wildcard (only check if key exists) } ) # Validates matching patterns schema.validate( - {"Conventions": "CF-1.8", "units": "meters", "comment": "any value here"} + {"Conventions": "CF-1.8", "units": "metres", "comment": "any value here"} ) # Validates other matching values schema.validate( { "Conventions": "CF-2.0", - "units": "kilometers", + "units": "kilometres", "comment": "different value", } ) @@ -106,7 +118,7 @@ def test_pattern_value_in_attrs_schema(self): # Fails on non-matching patterns ctx = ValidationContext(mode="lazy") schema.validate( - {"Conventions": "ACDD-1.3", "units": "meters", "comment": "test"}, + {"Conventions": "ACDD-1.3", "units": "metres", "comment": "test"}, context=ctx, ) errors = ctx.result.errors @@ -114,6 +126,191 @@ def test_pattern_value_in_attrs_schema(self): assert errors[0][0] == "attrs.Conventions" assert "does not match pattern" in str(errors[0][1]) + def test_attr_schema_unit_validation_no_pint(self): + """Test that unit validation fails gracefully without pint.""" + # Create a schema with unit validation + schema = AttrSchema(units="metre") + + # Mock pint import failure + import sys + + pint_module = sys.modules.get("pint") + sys.modules["pint"] = None + + try: + with pytest.raises(SchemaError, match="requires the pint library"): + schema.validate("metre") + finally: + # Restore pint module + if pint_module: + sys.modules["pint"] = pint_module + else: + sys.modules.pop("pint", None) + + @pytest.mark.parametrize( + "schema_kwargs, valid_values, invalid_values", + [ + # Exact unit match - allows different spellings/abbreviations + ( + {"units": "metre"}, + ["metre", "m", "meter"], # All equivalent + ["kilometre", "cm", "foot", "not_a_unit"], + ), + ( + {"units": "nanometre"}, + ["nanometer", "nm", "nanometre"], + ["micrometer", "um", "angstrom", "metre"], + ), + ( + {"units": "kelvin"}, + ["kelvin", "K"], + ["celsius", "fahrenheit", "degC"], + ), + ( + {"units": "percent"}, + ["percent", "%"], + ["dimensionless", "1"], + ), + ], + ) + def test_attr_schema_exact_unit(self, schema_kwargs, valid_values, invalid_values): + """Test exact unit validation (tolerates different spellings).""" + pytest.importorskip("pint") + + schema = AttrSchema(**schema_kwargs) + + # Test valid values + for value in valid_values: + schema.validate(value) + + # Test invalid values + for value in invalid_values: + with pytest.raises(SchemaError, match="(Unit mismatch|Invalid unit)"): + schema.validate(value) + + @pytest.mark.parametrize( + "schema_kwargs, valid_values, invalid_values", + [ + # Compatible units - allows conversions + ( + {"units_compatible": "metre"}, + [ + "metre", + "m", + "kilometre", + "km", + "centimeter", + "cm", + "millimeter", + "mm", + "foot", + "mile", + ], + ["second", "kelvin", "pascal"], + ), + ( + {"units_compatible": "nanometer"}, + [ + "nanometer", + "nm", + "micrometer", + "um", + "angstrom", + "metre", + "kilometre", + ], + ["second", "kelvin"], + ), + ( + {"units_compatible": "kelvin"}, + ["kelvin", "K", "celsius", "degC", "fahrenheit"], + ["metre", "second"], + ), + ( + {"units_compatible": "pascal"}, + ["pascal", "Pa", "kPa", "hPa", "bar", "millibar", "atm", "psi"], + ["metre", "kelvin"], + ), + ( + {"units_compatible": "meter / second"}, + ["m/s", "meter/second", "km/h", "kilometre/hour", "mile/hour", "mph"], + ["metre", "second", "meter**2"], + ), + ( + {"units_compatible": "watt / meter**2"}, + ["W/m**2", "watt/meter**2", "W/m^2"], + ["watt", "metre"], + ), + ], + ) + def test_attr_schema_compatible_units( + self, schema_kwargs, valid_values, invalid_values + ): + """Test compatible unit validation (allows conversions).""" + pytest.importorskip("pint") + + schema = AttrSchema(**schema_kwargs) + + # Test valid values + for value in valid_values: + schema.validate(value) + + # Test invalid values + for value in invalid_values: + with pytest.raises(SchemaError, match="not compatible"): + schema.validate(value) + + def test_attr_schema_unit_validation_non_string(self): + """Test that unit validation requires string attributes.""" + pytest.importorskip("pint") + + schema = AttrSchema(units="metre") + + with pytest.raises(SchemaError, match="requires attribute to be a string"): + schema.validate(123) + + with pytest.raises(SchemaError, match="requires attribute to be a string"): + schema.validate(None) + + def test_attr_schema_invalid_unit_string(self): + """Test that invalid unit strings are caught.""" + pytest.importorskip("pint") + + schema = AttrSchema(units="metre") + + with pytest.raises(SchemaError, match="Invalid unit"): + schema.validate("not_a_real_unit") + + def test_attr_schema_serialize_with_units(self): + """Test serialization includes unit fields.""" + schema = AttrSchema(units="metre") + result = schema.serialize() + assert result == { + "type": None, + "value": None, + "units": "metre", + "units_compatible": None, + } + + schema = AttrSchema(units_compatible="kelvin") + result = schema.serialize() + assert result == { + "type": None, + "value": None, + "units": None, + "units_compatible": "kelvin", + } + + def test_attr_schema_deserialize_with_units(self): + """Test deserialization handles unit fields.""" + schema = AttrSchema.deserialize({"units": "metre"}) + assert schema.units == "metre" + assert schema.units_compatible is None + + schema = AttrSchema.deserialize({"units_compatible": "kelvin"}) + assert schema.units is None + assert schema.units_compatible == "kelvin" + class TestAttrsSchema: """Tests for AttrsSchema class.""" @@ -127,7 +324,14 @@ class TestAttrsSchema: { "allow_extra_keys": True, "require_all_keys": True, - "attrs": {"foo": {"type": None, "value": "bar"}}, + "attrs": { + "foo": { + "type": None, + "value": "bar", + "units": None, + "units_compatible": None, + } + }, }, ), ( @@ -136,7 +340,14 @@ class TestAttrsSchema: { "allow_extra_keys": True, "require_all_keys": True, - "attrs": {"foo": {"type": None, "value": 1}}, + "attrs": { + "foo": { + "type": None, + "value": 1, + "units": None, + "units_compatible": None, + } + }, }, ), ], diff --git a/uv.lock b/uv.lock index 2263228..5ad25b5 100644 --- a/uv.lock +++ b/uv.lock @@ -1025,6 +1025,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" }, ] +[[package]] +name = "flexcache" +version = "0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/b0/8a21e330561c65653d010ef112bf38f60890051d244ede197ddaa08e50c1/flexcache-0.3.tar.gz", hash = "sha256:18743bd5a0621bfe2cf8d519e4c3bfdf57a269c15d1ced3fb4b64e0ff4600656", size = 15816, upload-time = "2024-03-09T03:21:07.555Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/cd/c883e1a7c447479d6e13985565080e3fea88ab5a107c21684c813dba1875/flexcache-0.3-py3-none-any.whl", hash = "sha256:d43c9fea82336af6e0115e308d9d33a185390b8346a017564611f1466dcd2e32", size = 13263, upload-time = "2024-03-09T03:21:05.635Z" }, +] + +[[package]] +name = "flexparser" +version = "0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/99/b4de7e39e8eaf8207ba1a8fa2241dd98b2ba72ae6e16960d8351736d8702/flexparser-0.4.tar.gz", hash = "sha256:266d98905595be2ccc5da964fe0a2c3526fbbffdc45b65b3146d75db992ef6b2", size = 31799, upload-time = "2024-11-07T02:00:56.249Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/5e/3be305568fe5f34448807976dc82fc151d76c3e0e03958f34770286278c1/flexparser-0.4-py3-none-any.whl", hash = "sha256:3738b456192dcb3e15620f324c447721023c0293f6af9955b481e91d00179846", size = 27625, upload-time = "2024-11-07T02:00:54.523Z" }, +] + [[package]] name = "fsspec" version = "2025.3.0" @@ -2355,6 +2379,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/41/220f49aaea88bc6fa6cba8d05ecf24676326156c23b991e80b3f2fc24c77/pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56", size = 6877, upload-time = "2018-09-25T19:17:35.817Z" }, ] +[[package]] +name = "pint" +version = "0.21.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +sdist = { url = "https://files.pythonhosted.org/packages/70/f4/9e7cb8e65e36c0a5e832bf04e57ca2cb1f96ea1ae289f10b82e2e98a49c7/Pint-0.21.1.tar.gz", hash = "sha256:5d5b6b518d0c5a7ab03a776175db500f1ed1523ee75fb7fafe38af8149431c8d", size = 336147, upload-time = "2023-05-25T16:29:12.845Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/bc/7ef2a654754cc3179af8df837485931f0874d96e111005a6246c1ed695f2/Pint-0.21.1-py3-none-any.whl", hash = "sha256:230ebccc312693117ee925c6492b3631c772ae9f7851a4e86080a15e7be692d8", size = 290846, upload-time = "2023-05-25T16:29:08.851Z" }, +] + +[[package]] +name = "pint" +version = "0.24.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.10.*'", + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "flexcache", marker = "python_full_version >= '3.9' and python_full_version < '3.11'" }, + { name = "flexparser", marker = "python_full_version >= '3.9' and python_full_version < '3.11'" }, + { name = "platformdirs", version = "4.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "platformdirs", version = "4.5.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/bb/52b15ddf7b7706ed591134a895dbf6e41c8348171fb635e655e0a4bbb0ea/pint-0.24.4.tar.gz", hash = "sha256:35275439b574837a6cd3020a5a4a73645eb125ce4152a73a2f126bf164b91b80", size = 342225, upload-time = "2024-11-07T16:29:46.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/16/bd2f5904557265882108dc2e04f18abc05ab0c2b7082ae9430091daf1d5c/Pint-0.24.4-py3-none-any.whl", hash = "sha256:aa54926c8772159fcf65f82cc0d34de6768c151b32ad1deb0331291c38fe7659", size = 302029, upload-time = "2024-11-07T16:29:43.976Z" }, +] + +[[package]] +name = "pint" +version = "0.25.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "flexcache", marker = "python_full_version >= '3.11'" }, + { name = "flexparser", marker = "python_full_version >= '3.11'" }, + { name = "platformdirs", version = "4.5.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/74/bc3f671997158aef171194c3c4041e549946f4784b8690baa0626a0a164b/pint-0.25.2.tar.gz", hash = "sha256:85a45d1da8fe9c9f7477fed8aef59ad2b939af3d6611507e1a9cbdacdcd3450a", size = 254467, upload-time = "2025-11-06T22:08:09.184Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/88/550d41e81e6d43335603a960cd9c75c1d88f9cf01bc9d4ee8e86290aba7d/pint-0.25.2-py3-none-any.whl", hash = "sha256:ca35ab1d8eeeb6f7d9942b3cb5f34ca42b61cdd5fb3eae79531553dcca04dda7", size = 306762, upload-time = "2025-11-06T22:08:07.745Z" }, +] + [[package]] name = "pkgutil-resolve-name" version = "1.3.10" @@ -4521,6 +4596,11 @@ dask = [ { name = "dask", version = "2024.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, { name = "dask", version = "2025.12.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, ] +units = [ + { name = "pint", version = "0.21.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "pint", version = "0.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and python_full_version < '3.11'" }, + { name = "pint", version = "0.25.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] yaml = [ { name = "ruamel-yaml" }, ] @@ -4585,10 +4665,11 @@ requires-dist = [ { name = "attrs" }, { name = "dask", marker = "extra == 'dask'" }, { name = "numpy" }, + { name = "pint", marker = "extra == 'units'" }, { name = "ruamel-yaml", marker = "extra == 'yaml'" }, { name = "xarray" }, ] -provides-extras = ["dask", "yaml"] +provides-extras = ["dask", "yaml", "units"] [package.metadata.requires-dev] dev = [