Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/xarray_validate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,35 @@ def has_errors(self) -> bool:
return self.result.has_errors


def raise_or_handle(
error: SchemaError,
context: ValidationContext | None = None,
from_exc: Exception | None = None,
) -> None:
"""
Raise error or handle it via context if available.

Parameters
----------
error : SchemaError
The error to raise or handle.

context : ValidationContext or None
Validation context. If provided, error is handled via context.
Otherwise, error is raised.

from_exc : Exception or None
Optional exception to chain from when raising.
"""
if context:
context.handle_error(error)
else:
if from_exc is not None:
raise error from from_exc
else:
raise error


class SchemaError(Exception):
"""Custom schema error."""

Expand Down
179 changes: 38 additions & 141 deletions src/xarray_validate/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.typing import DTypeLike

from . import _match, converters
from .base import BaseSchema, SchemaError, ValidationContext
from .base import BaseSchema, SchemaError, ValidationContext, raise_or_handle
from .types import ChunksT, DimsT, ShapeT


Expand Down Expand Up @@ -130,11 +130,7 @@ def validate(
else f"one of {repr(self_dtypes)}"
)
error = SchemaError(msg)

if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -218,10 +214,7 @@ def validate(self, dims: DimsT, context: ValidationContext | None = None) -> Non
error = SchemaError(
f"dimension number mismatch: got {len(dims)}, expected {len(self.dims)}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

if self.ordered:
for i, (actual, expected) in enumerate(zip(dims, self.dims)):
Expand All @@ -230,21 +223,15 @@ def validate(self, dims: DimsT, context: ValidationContext | None = None) -> Non
f"dimension mismatch in axis {i}: got {actual}, "
f"expected {expected}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
else:
for i, expected in enumerate(self.dims):
if expected is not None and expected not in dims:
error = SchemaError(
f"dimension mismatch: expected {expected} is missing "
f"from actual dimension list {dims}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -306,20 +293,14 @@ def validate(self, shape: tuple, context: ValidationContext | None = None) -> No
"dimension count mismatch: "
f"got {len(shape)}, expected {len(self.shape)}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

for i, (actual, expected) in enumerate(zip(shape, self.shape)):
if expected is not None and actual != expected:
error = SchemaError(
f"shape mismatch in axis {i}: got {actual}, expected {expected}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -374,10 +355,7 @@ def validate(self, name: str, context: ValidationContext | None = None) -> None:
# - https://docs.python.org/3.9/library/re.html
if self.name != name:
error = SchemaError(f"name mismatch: got {name}, expected {self.name}")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -455,23 +433,14 @@ def validate(
if isinstance(self.chunks, bool):
if self.chunks and not chunks:
error = SchemaError("expected array to be chunked but it is not")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
elif not self.chunks and chunks:
error = SchemaError("expected unchunked array but it is chunked")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
elif isinstance(self.chunks, dict):
if chunks is None:
error = SchemaError("expected array to be chunked but it is not")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
dim_chunks = dict(zip(dims, chunks))
dim_sizes = dict(zip(dims, shape))
# Check whether chunk sizes are regular because we assume the first
Expand All @@ -487,21 +456,15 @@ def validate(
error = SchemaError(
f"chunk mismatch for {key}: got {ac}, expected {ec}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

else: # assumes ec is an iterable
ac = dim_chunks[key]
if ec is not None and tuple(ac) != tuple(ec):
error = SchemaError(
f"chunk mismatch for {key}: got {ac}, expected {ec}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
else:
raise ValueError(f"got unknown chunks type: {type(self.chunks)}")

Expand Down Expand Up @@ -560,10 +523,7 @@ def validate(self, array: Any, context: ValidationContext | None = None) -> None
error = SchemaError(
f"array type mismatch: got {type(array)}, expected {self.array_type}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -647,10 +607,7 @@ def validate(self, attr: Any, context: ValidationContext | None = None):
error = SchemaError(
f"attribute type mismatch {attr} is not of type {self.type}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

if self.value is not None:
# Check if schema value is a string pattern
Expand All @@ -663,18 +620,12 @@ def validate(self, attr: Any, context: ValidationContext | None = None):
f"attribute value {attr!r} does not match pattern "
f"{self.value!r}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
else:
# Exact match for non-pattern values
if self.value != attr:
error = SchemaError(f"name {attr} != {self.value}")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

# Unit validation
if self.units is not None or self.units_compatible is not None:
Expand All @@ -684,83 +635,41 @@ def validate(self, attr: Any, context: ValidationContext | None = None):
"Unit validation requires attribute to be a string, got "
f"{type(attr).__name__}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
return

# Try to import pint
try:
import pint
except ImportError as e:
error = SchemaError(
"Unit validation requires the pint library. "
"Install with: pip install pint"
)
if context:
context.handle_error(error)
else:
raise error from e
return

# Get the application registry
ureg = pint.get_application_registry()
# Local import of units submodule will trigger a pint import and
# raise if the dependency is missing
from . import units

# Parse the attribute value as a unit
try:
attr_unit = ureg.Unit(attr)
except (
pint.UndefinedUnitError,
pint.errors.DefinitionSyntaxError,
) as e:
error = SchemaError(f"Invalid unit '{attr}': {e}")
if context:
context.handle_error(error)
else:
raise error from e
attr_unit = units.parse(attr, context=context)
if attr_unit is None:
return

# Validate exact unit match
if self.units is not None:
try:
expected_unit = ureg.Unit(self.units)
except (
pint.UndefinedUnitError,
pint.errors.DefinitionSyntaxError,
) as e:
error = SchemaError(f"Invalid expected unit '{self.units}': {e}")
if context:
context.handle_error(error)
else:
raise error from e
expected_unit = units.parse(
self.units, context=context, error_prefix="Invalid expected unit"
)
if expected_unit is None:
return

if attr_unit != expected_unit:
error = SchemaError(
f"Unit mismatch: expected '{self.units}' "
f"(or equivalent like '{expected_unit:~}'), got '{attr}'"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

# Validate compatible units
if self.units_compatible is not None:
try:
expected_unit = ureg.Unit(self.units_compatible)
except (
pint.UndefinedUnitError,
pint.errors.DefinitionSyntaxError,
) as e:
error = SchemaError(
f"Invalid expected unit '{self.units_compatible}': {e}"
)
if context:
context.handle_error(error)
else:
raise error from e
expected_unit = units.parse(
self.units_compatible,
context=context,
error_prefix="Invalid expected unit",
)
if expected_unit is None:
return

if not attr_unit.is_compatible_with(expected_unit):
Expand All @@ -770,10 +679,7 @@ def validate(self, attr: Any, context: ValidationContext | None = None):
f"Expected dimensionality: {expected_unit.dimensionality}, "
f"got: {attr_unit.dimensionality}"
)
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -857,10 +763,7 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None
missing_keys = set(exact_keys) - set(attrs)
if missing_keys:
error = SchemaError(f"attrs has missing keys: {missing_keys}")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

if not self.allow_extra_keys:
# Check that all attributes match either exact or pattern keys
Expand All @@ -870,19 +773,13 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None
extra_keys = set(attrs) - matched_attrs
if extra_keys:
error = SchemaError(f"attrs has extra keys: {extra_keys}")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)

# Validate attributes matching exact keys
for key, attr_schema in exact_keys.items():
if key not in attrs:
error = SchemaError(f"key {key} not in attrs")
if context:
context.handle_error(error)
else:
raise error
raise_or_handle(error, context)
else:
child_context = context.push(f"attrs.{key}") if context else None
attr_schema.validate(attrs[key], child_context)
Expand Down
Loading