Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions src/xarray_validate/components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import Any, Dict, Hashable, Optional, Tuple, Union
from typing import Any, Dict, Hashable, Optional, Tuple, Type, Union

import attrs as _attrs
import numpy as np
from numpy.typing import DTypeLike

from . import converters
from . import _match, converters
from .base import BaseSchema, SchemaError, ValidationContext
from .types import ChunksT, DimsT, ShapeT

Expand Down Expand Up @@ -580,7 +580,7 @@ class AttrSchema(BaseSchema):
Attribute value definition. ``None`` may be used as a wildcard.
"""

type: Optional[str] = _attrs.field(
type: Optional[Type] = _attrs.field(
default=None,
validator=_attrs.validators.optional(_attrs.validators.instance_of(type)),
)
Expand Down Expand Up @@ -633,12 +633,28 @@ def validate(self, attr: Any, context: ValidationContext | None = None):
raise error

if self.value is not None:
if self.value is not None and self.value != attr:
error = SchemaError(f"name {attr} != {self.value}")
if context:
context.handle_error(error)
else:
raise error
# Check if schema value is a string pattern
if isinstance(self.value, str) and _match.is_pattern_key(self.value):
# Convert attribute to string for pattern matching
attr_str = str(attr)
pattern = _match.pattern_to_regex(self.value)
if not pattern.fullmatch(attr_str):
error = SchemaError(
f"attribute value {attr!r} does not match pattern "
f"{self.value!r}"
)
if context:
context.handle_error(error)
else:
raise error
else:
# Exact match for non-pattern values
if self.value != attr:
error = SchemaError(f"name {attr} != {self.value}")
if context:
context.handle_error(error)
else:
raise error


@_attrs.define(on_setattr=[_attrs.setters.convert, _attrs.setters.validate])
Expand Down Expand Up @@ -714,8 +730,12 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None
If validation fails.
"""

# Separate exact keys from pattern keys and compile patterns
exact_keys, pattern_keys, compiled_patterns = _match.separate_keys(self.attrs)

if self.require_all_keys:
missing_keys = set(self.attrs) - set(attrs)
# Only check exact keys for require_all_keys
missing_keys = set(exact_keys) - set(attrs)
if missing_keys:
error = SchemaError(f"attrs has missing keys: {missing_keys}")
if context:
Expand All @@ -724,15 +744,20 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None
raise error

if not self.allow_extra_keys:
extra_keys = set(attrs) - set(self.attrs)
# Check that all attributes match either exact or pattern keys
matched_attrs = _match.find_matched_keys(
attrs, exact_keys, compiled_patterns
)
extra_keys = set(attrs) - matched_attrs
if extra_keys:
error = SchemaError(f"attrs has extra keys: {extra_keys}")
if context:
context.handle_error(error)
else:
raise error

for key, attr_schema in self.attrs.items():
# Validate attributes matching exact keys
for key, attr_schema in exact_keys.items():
if key not in attrs:
error = SchemaError(f"key {key} not in attrs")
if context:
Expand All @@ -742,3 +767,13 @@ def validate(self, attrs: Any, context: ValidationContext | None = None) -> None
else:
child_context = context.push(f"attrs.{key}") if context else None
attr_schema.validate(attrs[key], child_context)

# Validate attributes matching pattern keys
for pattern_key, attr_schema in pattern_keys.items():
regex = compiled_patterns[pattern_key]
for attr_name in attrs:
if regex.fullmatch(attr_name) and attr_name not in exact_keys:
child_context = (
context.push(f"attrs.{attr_name}") if context else None
)
attr_schema.validate(attrs[attr_name], child_context)
257 changes: 236 additions & 21 deletions tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NameSchema,
SchemaError,
ShapeSchema,
ValidationContext,
testing,
)

Expand All @@ -28,11 +29,245 @@ class TestAttrSchema:
({"type": str, "value": "foo"}, "foo", {"type": str, "value": "foo"}),
],
)
def test_attr_schema(self, kwargs, validate, json):
def test_attr_schema_basic(self, kwargs, validate, json):
schema = AttrSchema(**kwargs)
schema.validate(validate)
assert schema.serialize() == json

def test_exact_value_match(self):
    """A plain (non-pattern) schema value must be matched verbatim."""
    expected, mismatched = "meters", "kilometers"
    meters_schema = AttrSchema(value=expected)

    # The identical value validates without raising.
    meters_schema.validate(expected)

    # Any other value is rejected through the exact-comparison branch.
    with pytest.raises(SchemaError, match="name .* != .*"):
        meters_schema.validate(mismatched)

def test_glob_pattern_value_matching(self):
    """A glob-style schema value accepts every attribute value it matches."""
    cf_schema = AttrSchema(value="CF-*")

    # Anything carrying the ``CF-`` prefix is accepted.
    for accepted in ("CF-1.8", "CF-1.9"):
        cf_schema.validate(accepted)

    # A value without the prefix is reported as a pattern mismatch.
    with pytest.raises(SchemaError, match="does not match pattern"):
        cf_schema.validate("ACDD-1.3")

def test_regex_pattern_value_matching(self):
    """A brace-delimited schema value is applied as a regular expression."""
    version_schema = AttrSchema(value=r"{CF-\d+\.\d+}")

    # Well-formed ``CF-<major>.<minor>`` strings validate.
    for accepted in ("CF-1.8", "CF-1.10"):
        version_schema.validate(accepted)

    # A missing minor version or a different prefix is rejected.
    for rejected in ("CF-1", "ACDD-1.3"):
        with pytest.raises(SchemaError, match="does not match pattern"):
            version_schema.validate(rejected)

def test_pattern_value_with_numeric_conversion(self):
    """Test that numeric values are converted to strings for pattern matching."""
    schema = AttrSchema(value=r"{\d+\.\d+}")

    # String inputs that already have the expected shape match directly.
    schema.validate("1.8")
    schema.validate("2.0")

    # Genuinely numeric inputs must also match: the validator stringifies
    # the attribute before applying the pattern (str(1.8) == "1.8",
    # str(2.0) == "2.0"). Without these calls the conversion path that this
    # test is named after would never be exercised.
    schema.validate(1.8)
    schema.validate(2.0)

    # An integer stringifies without a decimal point and must be rejected.
    with pytest.raises(SchemaError, match="does not match pattern"):
        schema.validate(2)

def test_pattern_value_in_attrs_schema(self):
    """Pattern-valued entries are honoured when nested in an ``AttrsSchema``."""
    spec = {
        "Conventions": "CF-*",  # glob-style value
        "units": "{(meters|kilometers)}",  # regex-style value
        "comment": None,  # wildcard: only the key's presence matters
    }
    attrs_schema = AttrsSchema.deserialize(spec)

    # Every attribute set below satisfies all three entries.
    passing_cases = (
        {"Conventions": "CF-1.8", "units": "meters", "comment": "any value here"},
        {
            "Conventions": "CF-2.0",
            "units": "kilometers",
            "comment": "different value",
        },
    )
    for candidate in passing_cases:
        attrs_schema.validate(candidate)

    # In lazy mode the mismatch is collected on the context, not raised.
    lazy_ctx = ValidationContext(mode="lazy")
    failing_case = {"Conventions": "ACDD-1.3", "units": "meters", "comment": "test"}
    attrs_schema.validate(failing_case, context=lazy_ctx)

    collected = lazy_ctx.result.errors
    assert len(collected) == 1
    first = collected[0]
    assert first[0] == "attrs.Conventions"
    assert "does not match pattern" in str(first[1])


class TestAttrsSchema:
    """Tests for AttrsSchema class.

    Covers construction/serialization round-trips and glob / regex pattern
    matching on attribute *keys*, including how patterns interact with the
    ``require_all_keys`` and ``allow_extra_keys`` options.
    """

    @pytest.mark.parametrize(
        "schema_args, validate, json",
        [
            (
                {"foo": AttrSchema(value="bar")},
                [{"foo": "bar"}],
                {
                    "allow_extra_keys": True,
                    "require_all_keys": True,
                    "attrs": {"foo": {"type": None, "value": "bar"}},
                },
            ),
            (
                {"foo": AttrSchema(value=1)},
                [{"foo": 1}],
                {
                    "allow_extra_keys": True,
                    "require_all_keys": True,
                    "attrs": {"foo": {"type": None, "value": 1}},
                },
            ),
        ],
    )
    def test_attrs_schema_basic(self, schema_args, validate, json):
        """Construct, validate and serialize a schema with exact keys only."""
        schema = testing.assert_construct(AttrsSchema, schema_args)

        for v in validate:
            schema.validate(v)

        testing.assert_json(schema, json)

    def test_glob_pattern_matching_keys(self):
        """Test that glob patterns match attribute keys."""
        schema = AttrsSchema.deserialize({"valid_*": "pass"})

        # Validates attributes matching the pattern
        schema.validate({"valid_min": "pass", "valid_max": "pass", "other": "ignored"})

        # Fails to validate attributes that do not match the pattern
        with pytest.raises(SchemaError, match="fail"):
            schema.validate({"valid_min": "pass", "valid_max": "fail"})

    def test_regex_pattern_matching_keys(self):
        """Test that regex patterns match attribute keys."""
        schema = AttrsSchema.deserialize({"{valid_(min|max)}": "pass"})

        # Validates attributes matching the regex
        schema.validate({"valid_min": "pass", "valid_max": "pass", "other": "ignore"})

        # Fails to validate attributes that do not match the pattern
        with pytest.raises(SchemaError, match="fail"):
            schema.validate({"valid_min": "pass", "valid_max": "fail"})

    def test_mixed_exact_and_pattern_keys(self):
        """Test mixing exact and pattern keys."""
        schema = AttrsSchema.deserialize(
            {"units": "meters", "valid_*": 0.0, "long_name": "Distance"}
        )

        # Validates with exact and pattern matches
        # Note: All attributes matching valid_* must have value 0.0
        schema.validate(
            {
                "units": "meters",
                "valid_min": 0.0,
                "valid_max": 0.0,
                "long_name": "Distance",
            }
        )

        # A pattern-matched attribute with the wrong value must be rejected
        # even though every exact key is satisfied; otherwise this test would
        # pass with the pattern entry silently ignored.
        with pytest.raises(SchemaError, match="name .* != .*"):
            schema.validate(
                {
                    "units": "meters",
                    "valid_min": 0.0,
                    "valid_max": 1.0,
                    "long_name": "Distance",
                }
            )

    def test_exact_key_takes_precedence(self):
        """Test that exact keys take precedence over pattern keys."""
        schema = AttrsSchema.deserialize({"valid_min": -10.0, "valid_*": 0.0})

        # valid_min matches exact schema (-10.0), not pattern schema (0.0)
        schema.validate({"valid_min": -10.0, "valid_max": 0.0})

        # Conversely, a value that would satisfy the pattern schema (0.0) is
        # still rejected for valid_min because the exact schema applies.
        with pytest.raises(SchemaError, match="name .* != .*"):
            schema.validate({"valid_min": 0.0, "valid_max": 0.0})

    def test_pattern_with_require_all_keys_false(self):
        """Test pattern matching with optional keys."""
        schema = AttrsSchema.deserialize(
            {
                "attrs": {"valid_*": 0.0},
                "require_all_keys": False,
                "allow_extra_keys": True,
            }
        )

        # Should validate even without pattern matches
        schema.validate({"other_attr": "ignored"})

    def test_pattern_with_allow_extra_keys_false(self):
        """Test pattern matching with strict key checking."""
        schema = AttrsSchema.deserialize(
            {
                "attrs": {"valid_*": 0.0, "units": "meters"},
                "require_all_keys": False,
                "allow_extra_keys": False,
            }
        )

        # Validates when all keys match schema
        # Note: All attributes matching valid_* must have value 0.0
        schema.validate({"valid_min": 0.0, "valid_max": 0.0, "units": "meters"})

        # Raises when there are extra keys
        with pytest.raises(SchemaError, match="attrs has extra keys"):
            schema.validate(
                {"valid_min": 0.0, "units": "meters", "unexpected": "value"}
            )

    def test_multiple_patterns(self):
        """Test multiple pattern keys."""
        schema = AttrsSchema.deserialize({"valid_*": 0.0, r"{flag_\d+}": True})

        # Validates attributes matching different patterns
        # Note: All attributes matching valid_* must have value 0.0
        schema.validate(
            {"valid_min": 0.0, "valid_max": 0.0, "flag_0": True, "flag_1": True}
        )

        # Each pattern keeps its own value schema: a mismatch under the
        # second pattern is reported even when the first one is satisfied.
        with pytest.raises(SchemaError, match="name .* != .*"):
            schema.validate({"valid_min": 0.0, "flag_0": False})

    def test_pattern_validation_failure(self):
        """Test that pattern validation catches value mismatches."""
        schema = AttrsSchema.deserialize({"valid_*": 0.0})

        # Raises when pattern-matched values don't validate
        with pytest.raises(SchemaError, match="name .* != .*"):
            schema.validate({"valid_min": "wrong_type"})

    def test_empty_pattern_matches_nothing(self):
        """Test that schema with only patterns doesn't require any keys."""
        schema = AttrsSchema.deserialize(
            {"attrs": {"valid_*": 0.0}, "require_all_keys": False}
        )

        # Validates empty attrs when no exact keys are required
        schema.validate({})

    def test_complex_regex_pattern(self):
        """Test complex regex pattern with character classes."""
        schema = AttrsSchema.deserialize({"{[a-z]+_[0-9]{2}}": 100})

        # Should match attributes following the pattern
        schema.validate({"foo_12": 100, "bar_99": 100})

        # Should not match attributes not following the pattern: with extra
        # keys disallowed, "foo_1" (one digit only) is reported as an extra key
        with pytest.raises(SchemaError, match="attrs has extra keys"):
            AttrsSchema.deserialize(
                {"attrs": {"{[a-z]+_[0-9]{2}}": 100}, "allow_extra_keys": False}
            ).validate({"foo_1": 100})


class TestDTypeSchema:
VALIDATION_VALUES = {
Expand Down Expand Up @@ -114,26 +349,6 @@ def test_dtype_schema_generic(self):
[(((2, 2), (10,)), ("x", "y"), (4, 10))],
{"x": 2, "y": -1},
),
(
AttrsSchema,
{"foo": AttrSchema(value="bar")},
[{"foo": "bar"}],
{
"allow_extra_keys": True,
"require_all_keys": True,
"attrs": {"foo": {"type": None, "value": "bar"}},
},
),
(
AttrsSchema,
{"foo": AttrSchema(value=1)},
[{"foo": 1}],
{
"allow_extra_keys": True,
"require_all_keys": True,
"attrs": {"foo": {"type": None, "value": 1}},
},
),
(
CoordsSchema,
{"x": DataArraySchema(name="x")},
Expand Down