From f04199c6aa90b965f7e9348a4579aa1182d386de Mon Sep 17 00:00:00 2001 From: Vincent Leroy Date: Thu, 1 Jan 2026 14:33:17 +0200 Subject: [PATCH] Add pattern match support for attribute values and keys This commit adds pattern match support for attribute values (in `AttrSchema`) and keys (in `AttrsSchema`). The test suite refactoring continues as well: `AttrsSchema` tests are now scoped in a dedicated `TestAttrsSchema` class to improve comprehensiveness. --- src/xarray_validate/components.py | 59 +++++-- tests/test_components.py | 257 +++++++++++++++++++++++++++--- 2 files changed, 283 insertions(+), 33 deletions(-) diff --git a/src/xarray_validate/components.py b/src/xarray_validate/components.py index 60b1d86..e530de9 100644 --- a/src/xarray_validate/components.py +++ b/src/xarray_validate/components.py @@ -1,13 +1,13 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import Any, Dict, Hashable, Optional, Tuple, Union +from typing import Any, Dict, Hashable, Optional, Tuple, Type, Union import attrs as _attrs import numpy as np from numpy.typing import DTypeLike -from . import converters +from . import _match, converters from .base import BaseSchema, SchemaError, ValidationContext from .types import ChunksT, DimsT, ShapeT @@ -580,7 +580,7 @@ class AttrSchema(BaseSchema): Attribute value definition. ``None`` may be used as a wildcard. """ - type: Optional[str] = _attrs.field( + type: Optional[Type] = _attrs.field( default=None, validator=_attrs.validators.optional(_attrs.validators.instance_of(type)), ) @@ -633,12 +633,28 @@ def validate(self, attr: Any, context: ValidationContext | None = None): raise error if self.value is not None: - if self.value is not None and self.value != attr: - error = SchemaError(f"name {attr} != {self.value}") - if context: - context.handle_error(error) - else: - raise error + # Check if schema value is a string pattern + if isinstance(self.value, str) and _match.is_pattern_key(self.value): + # Convert attribute to string for pattern matching + attr_str = str(attr) + pattern = _match.pattern_to_regex(self.value) + if not pattern.fullmatch(attr_str): + error = SchemaError( + f"attribute value {attr!r} does not match pattern " + f"{self.value!r}" + ) + if context: + context.handle_error(error) + else: + raise error + else: + # Exact match for non-pattern values + if self.value != attr: + error = SchemaError(f"name {attr} != {self.value}") + if context: + context.handle_error(error) + else: + raise error @_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate]) @@ -714,8 +730,12 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None If validation fails. """ + # Separate exact keys from pattern keys and compile patterns + exact_keys, pattern_keys, compiled_patterns = _match.separate_keys(self.attrs) + if self.require_all_keys: - missing_keys = set(self.attrs) - set(attrs) + # Only check exact keys for require_all_keys + missing_keys = set(exact_keys) - set(attrs) if missing_keys: error = SchemaError(f"attrs has missing keys: {missing_keys}") if context: @@ -724,7 +744,11 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None raise error if not self.allow_extra_keys: - extra_keys = set(attrs) - set(self.attrs) + # Check that all attributes match either exact or pattern keys + matched_attrs = _match.find_matched_keys( + attrs, exact_keys, compiled_patterns + ) + extra_keys = set(attrs) - matched_attrs if extra_keys: error = SchemaError(f"attrs has extra keys: {extra_keys}") if context: @@ -732,7 +756,8 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None else: raise error - for key, attr_schema in self.attrs.items(): + # Validate attributes matching exact keys + for key, attr_schema in exact_keys.items(): if key not in attrs: error = SchemaError(f"key {key} not in attrs") if context: @@ -742,3 +767,13 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None else: child_context = context.push(f"attrs.{key}") if context else None attr_schema.validate(attrs[key], child_context) + + # Validate attributes matching pattern keys + for pattern_key, attr_schema in pattern_keys.items(): + regex = compiled_patterns[pattern_key] + for attr_name in attrs: + if regex.fullmatch(attr_name) and attr_name not in exact_keys: + child_context = ( + context.push(f"attrs.{attr_name}") if context else None + ) + attr_schema.validate(attrs[attr_name], child_context) diff --git a/tests/test_components.py b/tests/test_components.py index e868477..a9916a4 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -15,6 +15,7 @@ NameSchema, SchemaError, ShapeSchema, + ValidationContext, testing, ) @@ -28,11 +29,245 @@ class TestAttrSchema: ({"type": str, "value": "foo"}, "foo", {"type": str, "value": "foo"}), ], ) - def test_attr_schema(self, kwargs, validate, json): + def test_attr_schema_basic(self, kwargs, validate, json): schema = AttrSchema(**kwargs) schema.validate(validate) assert schema.serialize() == json + def test_exact_value_match(self): + """Test that exact matching works for non-pattern values.""" + schema = AttrSchema(value="meters") + + # Exact value matches + schema.validate("meters") + + # Should not match different value + with pytest.raises(SchemaError, match="name .* != .*"): + schema.validate("kilometers") + + def test_glob_pattern_value_matching(self): + """Test that glob patterns match attribute values.""" + schema = AttrSchema(value="CF-*") + + # Values starting with CF- match + schema.validate("CF-1.8") + schema.validate("CF-1.9") + + # Values not starting with CF- do not match + with pytest.raises(SchemaError, match="does not match pattern"): + schema.validate("ACDD-1.3") + + def test_regex_pattern_value_matching(self): + """Test that regex patterns match attribute values.""" + schema = AttrSchema(value=r"{CF-\d+\.\d+}") + + # CF version strings match + schema.validate("CF-1.8") + schema.validate("CF-1.10") + + # Invalid formats do not match + with pytest.raises(SchemaError, match="does not match pattern"): + schema.validate("CF-1") + with pytest.raises(SchemaError, match="does not match pattern"): + schema.validate("ACDD-1.3") + + def test_pattern_value_with_numeric_conversion(self): + """Test that numeric values are converted to strings for pattern matching.""" + schema = AttrSchema(value=r"{\d+\.\d+}") + + # Numeric values are converted to strings for matching + schema.validate("1.8") + schema.validate("2.0") + + def test_pattern_value_in_attrs_schema(self): + """Test pattern matching for values in AttrsSchema.""" + schema = AttrsSchema.deserialize( + { + "Conventions": "CF-*", # Glob pattern + "units": "{(meters|kilometers)}", # Regex pattern + "comment": None, # Wildcard (only check if key exists) + } + ) + + # Validates matching patterns + schema.validate( + {"Conventions": "CF-1.8", "units": "meters", "comment": "any value here"} + ) + + # Validates other matching values + schema.validate( + { + "Conventions": "CF-2.0", + "units": "kilometers", + "comment": "different value", + } + ) + + # Fails on non-matching patterns + ctx = ValidationContext(mode="lazy") + schema.validate( + {"Conventions": "ACDD-1.3", "units": "meters", "comment": "test"}, + context=ctx, + ) + errors = ctx.result.errors + assert len(errors) == 1 + assert errors[0][0] == "attrs.Conventions" + assert "does not match pattern" in str(errors[0][1]) + + +class TestAttrsSchema: + """Tests for AttrsSchema class.""" + + @pytest.mark.parametrize( + "schema_args, validate, json", + [ + ( + {"foo": AttrSchema(value="bar")}, + [{"foo": "bar"}], + { + "allow_extra_keys": True, + "require_all_keys": True, + "attrs": {"foo": {"type": None, "value": "bar"}}, + }, + ), + ( + {"foo": AttrSchema(value=1)}, + [{"foo": 1}], + { + "allow_extra_keys": True, + "require_all_keys": True, + "attrs": {"foo": {"type": None, "value": 1}}, + }, + ), + ], + ) + def test_attrs_schema_basic(self, schema_args, validate, json): + schema = testing.assert_construct(AttrsSchema, schema_args) + + for v in validate: + schema.validate(v) + + testing.assert_json(schema, json) + + def test_glob_pattern_matching_keys(self): + """Test that glob patterns match attribute keys.""" + schema = AttrsSchema.deserialize({"valid_*": "pass"}) + + # Validates attributes matching the pattern + schema.validate({"valid_min": "pass", "valid_max": "pass", "other": "ignored"}) + + # Fails to validate attributes that do not match the pattern + with pytest.raises(SchemaError, match="fail"): + schema.validate({"valid_min": "pass", "valid_max": "fail"}) + + def test_regex_pattern_matching_keys(self): + """Test that regex patterns match attribute keys.""" + schema = AttrsSchema.deserialize({"{valid_(min|max)}": "pass"}) + + # Validates attributes matching the regex + schema.validate({"valid_min": "pass", "valid_max": "pass", "other": "ignore"}) + + # Fails to validate attributes that do not match the pattern + with pytest.raises(SchemaError, match="fail"): + schema.validate({"valid_min": "pass", "valid_max": "fail"}) + + def test_mixed_exact_and_pattern_keys(self): + """Test mixing exact and pattern keys.""" + schema = AttrsSchema.deserialize( + {"units": "meters", "valid_*": 0.0, "long_name": "Distance"} + ) + + # Validates with exact and pattern matches + # Note: All attributes matching valid_* must have value 0.0 + schema.validate( + { + "units": "meters", + "valid_min": 0.0, + "valid_max": 0.0, + "long_name": "Distance", + } + ) + + def test_exact_key_takes_precedence(self): + """Test that exact keys take precedence over pattern keys.""" + schema = AttrsSchema.deserialize({"valid_min": -10.0, "valid_*": 0.0}) + + # valid_min matches exact schema (-10.0), not pattern schema (0.0) + schema.validate({"valid_min": -10.0, "valid_max": 0.0}) + + def test_pattern_with_require_all_keys_false(self): + """Test pattern matching with optional keys.""" + schema = AttrsSchema.deserialize( + { + "attrs": {"valid_*": 0.0}, + "require_all_keys": False, + "allow_extra_keys": True, + } + ) + + # Should validate even without pattern matches + schema.validate({"other_attr": "ignored"}) + + def test_pattern_with_allow_extra_keys_false(self): + """Test pattern matching with strict key checking.""" + schema = AttrsSchema.deserialize( + { + "attrs": {"valid_*": 0.0, "units": "meters"}, + "require_all_keys": False, + "allow_extra_keys": False, + } + ) + + # Validates when all keys match schema + # Note: All attributes matching valid_* must have value 0.0 + schema.validate({"valid_min": 0.0, "valid_max": 0.0, "units": "meters"}) + + # Raises when there are extra keys + with pytest.raises(SchemaError, match="attrs has extra keys"): + schema.validate( + {"valid_min": 0.0, "units": "meters", "unexpected": "value"} + ) + + def test_multiple_patterns(self): + """Test multiple pattern keys.""" + schema = AttrsSchema.deserialize({"valid_*": 0.0, r"{flag_\d+}": True}) + + # Validates attributes matching different patterns + # Note: All attributes matching valid_* must have value 0.0 + schema.validate( + {"valid_min": 0.0, "valid_max": 0.0, "flag_0": True, "flag_1": True} + ) + + def test_pattern_validation_failure(self): + """Test that pattern validation catches value mismatches.""" + schema = AttrsSchema.deserialize({"valid_*": 0.0}) + + # Raises when pattern-matched values don't validate + with pytest.raises(SchemaError, match="name .* != .*"): + schema.validate({"valid_min": "wrong_type"}) + + def test_empty_pattern_matches_nothing(self): + """Test that schema with only patterns doesn't require any keys.""" + schema = AttrsSchema.deserialize( + {"attrs": {"valid_*": 0.0}, "require_all_keys": False} + ) + + # Validates empty attrs when no exact keys are required + schema.validate({}) + + def test_complex_regex_pattern(self): + """Test complex regex pattern with character classes.""" + schema = AttrsSchema.deserialize({"{[a-z]+_[0-9]{2}}": 100}) + + # Should match attributes following the pattern + schema.validate({"foo_12": 100, "bar_99": 100}) + + # Should not match attributes not following the pattern + with pytest.raises(SchemaError, match="attrs has extra keys"): + AttrsSchema.deserialize( + {"attrs": {"{[a-z]+_[0-9]{2}}": 100}, "allow_extra_keys": False} + ).validate({"foo_1": 100}) + class TestDTypeSchema: VALIDATION_VALUES = { @@ -114,26 +349,6 @@ def test_dtype_schema_generic(self): [(((2, 2), (10,)), ("x", "y"), (4, 10))], {"x": 2, "y": -1}, ), - ( - AttrsSchema, - {"foo": AttrSchema(value="bar")}, - [{"foo": "bar"}], - { - "allow_extra_keys": True, - "require_all_keys": True, - "attrs": {"foo": {"type": None, "value": "bar"}}, - }, - ), - ( - AttrsSchema, - {"foo": AttrSchema(value=1)}, - [{"foo": 1}], - { - "allow_extra_keys": True, - "require_all_keys": True, - "attrs": {"foo": {"type": None, "value": 1}}, - }, - ), ( CoordsSchema, {"x": DataArraySchema(name="x")},