diff --git a/src/xarray_validate/base.py b/src/xarray_validate/base.py index 300f382..bf2d7e3 100644 --- a/src/xarray_validate/base.py +++ b/src/xarray_validate/base.py @@ -123,6 +123,35 @@ def has_errors(self) -> bool: return self.result.has_errors +def raise_or_handle( + error: SchemaError, + context: ValidationContext | None = None, + from_exc: Exception | None = None, +) -> None: + """ + Raise error or handle it via context if available. + + Parameters + ---------- + error : SchemaError + The error to raise or handle. + + context : ValidationContext or None + Validation context. If provided, error is handled via context. + Otherwise, error is raised. + + from_exc : Exception or None + Optional exception to chain from when raising; unused if context handles. + """ + if context: + context.handle_error(error) + else: + if from_exc is not None: + raise error from from_exc + else: + raise error + + class SchemaError(Exception): """Custom schema error.""" diff --git a/src/xarray_validate/components.py b/src/xarray_validate/components.py index 2a805af..efc1865 100644 --- a/src/xarray_validate/components.py +++ b/src/xarray_validate/components.py @@ -8,7 +8,7 @@ from numpy.typing import DTypeLike from . 
import _match, converters -from .base import BaseSchema, SchemaError, ValidationContext +from .base import BaseSchema, SchemaError, ValidationContext, raise_or_handle from .types import ChunksT, DimsT, ShapeT @@ -130,11 +130,7 @@ def validate( else f"one of {repr(self_dtypes)}" ) error = SchemaError(msg) - - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -218,10 +214,7 @@ def validate(self, dims: DimsT, context: ValidationContext | None = None) -> Non error = SchemaError( f"dimension number mismatch: got {len(dims)}, expected {len(self.dims)}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) if self.ordered: for i, (actual, expected) in enumerate(zip(dims, self.dims)): @@ -230,10 +223,7 @@ def validate(self, dims: DimsT, context: ValidationContext | None = None) -> Non f"dimension mismatch in axis {i}: got {actual}, " f"expected {expected}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) else: for i, expected in enumerate(self.dims): if expected is not None and expected not in dims: @@ -241,10 +231,7 @@ def validate(self, dims: DimsT, context: ValidationContext | None = None) -> Non f"dimension mismatch: expected {expected} is missing " f"from actual dimension list {dims}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -306,20 +293,14 @@ def validate(self, shape: tuple, context: ValidationContext | None = None) -> No "dimension count mismatch: " f"got {len(shape)}, expected {len(self.shape)}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) for i, (actual, expected) in enumerate(zip(shape, self.shape)): if expected is not None and actual != 
expected: error = SchemaError( f"shape mismatch in axis {i}: got {actual}, expected {expected}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -374,10 +355,7 @@ def validate(self, name: str, context: ValidationContext | None = None) -> None: # - https://docs.python.org/3.9/library/re.html if self.name != name: error = SchemaError(f"name mismatch: got {name}, expected {self.name}") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -455,23 +433,14 @@ def validate( if isinstance(self.chunks, bool): if self.chunks and not chunks: error = SchemaError("expected array to be chunked but it is not") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) elif not self.chunks and chunks: error = SchemaError("expected unchunked array but it is chunked") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) elif isinstance(self.chunks, dict): if chunks is None: error = SchemaError("expected array to be chunked but it is not") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) dim_chunks = dict(zip(dims, chunks)) dim_sizes = dict(zip(dims, shape)) # Check whether chunk sizes are regular because we assume the first @@ -487,10 +456,7 @@ def validate( error = SchemaError( f"chunk mismatch for {key}: got {ac}, expected {ec}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) else: # assumes ec is an iterable ac = dim_chunks[key] @@ -498,10 +464,7 @@ def validate( error = SchemaError( f"chunk mismatch for {key}: got {ac}, expected {ec}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) 
else: raise ValueError(f"got unknown chunks type: {type(self.chunks)}") @@ -560,10 +523,7 @@ def validate(self, array: Any, context: ValidationContext | None = None) -> None error = SchemaError( f"array type mismatch: got {type(array)}, expected {self.array_type}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -647,10 +607,7 @@ def validate(self, attr: Any, context: ValidationContext | None = None): error = SchemaError( f"attribute type mismatch {attr} is not of type {self.type}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) if self.value is not None: # Check if schema value is a string pattern @@ -663,18 +620,12 @@ def validate(self, attr: Any, context: ValidationContext | None = None): f"attribute value {attr!r} does not match pattern " f"{self.value!r}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) else: # Exact match for non-pattern values if self.value != attr: error = SchemaError(f"name {attr} != {self.value}") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) # Unit validation if self.units is not None or self.units_compatible is not None: @@ -684,56 +635,24 @@ def validate(self, attr: Any, context: ValidationContext | None = None): "Unit validation requires attribute to be a string, got " f"{type(attr).__name__}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) return - # Try to import pint - try: - import pint - except ImportError as e: - error = SchemaError( - "Unit validation requires the pint library. 
" - "Install with: pip install pint" - ) - if context: - context.handle_error(error) - else: - raise error from e - return - - # Get the application registry - ureg = pint.get_application_registry() + # Local import of units submodule will trigger a pint import and + # raise if the dependency is missing + from . import units # Parse the attribute value as a unit - try: - attr_unit = ureg.Unit(attr) - except ( - pint.UndefinedUnitError, - pint.errors.DefinitionSyntaxError, - ) as e: - error = SchemaError(f"Invalid unit '{attr}': {e}") - if context: - context.handle_error(error) - else: - raise error from e + attr_unit = units.parse(attr, context=context) + if attr_unit is None: return # Validate exact unit match if self.units is not None: - try: - expected_unit = ureg.Unit(self.units) - except ( - pint.UndefinedUnitError, - pint.errors.DefinitionSyntaxError, - ) as e: - error = SchemaError(f"Invalid expected unit '{self.units}': {e}") - if context: - context.handle_error(error) - else: - raise error from e + expected_unit = units.parse( + self.units, context=context, error_prefix="Invalid expected unit" + ) + if expected_unit is None: return if attr_unit != expected_unit: @@ -741,26 +660,16 @@ def validate(self, attr: Any, context: ValidationContext | None = None): f"Unit mismatch: expected '{self.units}' " f"(or equivalent like '{expected_unit:~}'), got '{attr}'" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) # Validate compatible units if self.units_compatible is not None: - try: - expected_unit = ureg.Unit(self.units_compatible) - except ( - pint.UndefinedUnitError, - pint.errors.DefinitionSyntaxError, - ) as e: - error = SchemaError( - f"Invalid expected unit '{self.units_compatible}': {e}" - ) - if context: - context.handle_error(error) - else: - raise error from e + expected_unit = units.parse( + self.units_compatible, + context=context, + error_prefix="Invalid expected unit", + ) + if expected_unit is 
None: return if not attr_unit.is_compatible_with(expected_unit): @@ -770,10 +679,7 @@ def validate(self, attr: Any, context: ValidationContext | None = None): f"Expected dimensionality: {expected_unit.dimensionality}, " f"got: {attr_unit.dimensionality}" ) - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -857,10 +763,7 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None missing_keys = set(exact_keys) - set(attrs) if missing_keys: error = SchemaError(f"attrs has missing keys: {missing_keys}") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) if not self.allow_extra_keys: # Check that all attributes match either exact or pattern keys @@ -870,19 +773,13 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None extra_keys = set(attrs) - matched_attrs if extra_keys: error = SchemaError(f"attrs has extra keys: {extra_keys}") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) # Validate attributes matching exact keys for key, attr_schema in exact_keys.items(): if key not in attrs: error = SchemaError(f"key {key} not in attrs") - if context: - context.handle_error(error) - else: - raise error + raise_or_handle(error, context) else: child_context = context.push(f"attrs.{key}") if context else None attr_schema.validate(attrs[key], child_context) diff --git a/src/xarray_validate/units.py b/src/xarray_validate/units.py new file mode 100644 index 0000000..35b02d5 --- /dev/null +++ b/src/xarray_validate/units.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .base import SchemaError, ValidationContext, raise_or_handle + +if TYPE_CHECKING: + import pint + +_REGISTRY: pint.UnitRegistry | None = None + + +try: + import pint +except ImportError as e: 
+ raise ImportError( + "Unit validation requires the pint library. Install with pip install pint" + ) from e + + +def set_registry(ureg: pint.UnitRegistry | None = None) -> None: + global _REGISTRY + _REGISTRY = ureg if ureg is not None else pint.get_application_registry() + + +def get_registry() -> pint.UnitRegistry: + """ + Get the default unit registry. + + If unset (see :func:`.set_registry`), pint's application registry is used.""" + if _REGISTRY is None: + set_registry() + return _REGISTRY + + +def parse( + unit_string: str, + ureg: pint.UnitRegistry | None = None, + context: ValidationContext | None = None, + error_prefix: str = "Invalid units", +): + """ + Parse a unit string with pint, handling errors appropriately. + + Parameters + ---------- + unit_string : str + The unit string to parse. + + ureg : pint.UnitRegistry, optional + The pint unit registry to use for parsing. If not passed, the default + registry is used. + + context : ValidationContext, optional + Validation context for error handling. + + error_prefix : str, default: "Invalid units" + Prefix for error messages. + + Returns + ------- + pint.Unit or None + The parsed unit, or None if parsing failed and context handled the error. + """ + if ureg is None: + ureg = get_registry() + + try: + return ureg.Unit(unit_string) + except (pint.UndefinedUnitError, pint.errors.DefinitionSyntaxError) as e: + error = SchemaError(f"{error_prefix} '{unit_string}': {e}") + raise_or_handle(error, context, from_exc=e) + return None diff --git a/tests/test_components.py b/tests/test_components.py index a791d6e..c86e10b 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -138,7 +138,7 @@ def test_attr_schema_unit_validation_no_pint(self): sys.modules["pint"] = None try: - with pytest.raises(SchemaError, match="requires the pint library"): + with pytest.raises(ImportError, match="requires the pint library"): schema.validate("metre") finally: # Restore pint module