Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

<!-- ## Unreleased -->
## Unreleased

### Deprecated

- `APIMixin.patch()` parameter `validate`. Has no effect and will be removed in a future release.

## [0.2.0](https://github.com/unioslo/mreg-api/releases/tag/0.2.0) - 2026-04-16

Expand Down
118 changes: 24 additions & 94 deletions mreg_api/models/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,108 +6,23 @@
from abc import abstractmethod
from datetime import datetime
from typing import Any
from typing import Callable
from typing import Self
from typing import cast
from typing import overload

from pydantic import AliasChoices
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic.fields import FieldInfo
from typing_extensions import deprecated

from mreg_api.endpoints import Endpoint
from mreg_api.exceptions import EntityAlreadyExists
from mreg_api.exceptions import EntityNotFound
from mreg_api.exceptions import GetError
from mreg_api.exceptions import InternalError
from mreg_api.exceptions import PatchError
from mreg_api.exceptions import PostError
from mreg_api.types import Json
from mreg_api.types import JsonMapping
from mreg_api.types import QueryParams


def get_field_aliases(field_info: FieldInfo) -> set[str]:
"""Get all aliases for a Pydantic field."""
aliases: set[str] = set()

if field_info.alias:
aliases.add(field_info.alias)

if field_info.validation_alias:
if isinstance(field_info.validation_alias, str):
aliases.add(field_info.validation_alias)
elif isinstance(field_info.validation_alias, AliasChoices):
for choice in field_info.validation_alias.choices:
if isinstance(choice, str):
aliases.add(choice)
return aliases


def get_model_aliases(model: BaseModel) -> dict[str, str]:
"""Get a mapping of aliases to field names for a Pydantic model.

Includes field names, alias, and validation alias(es).
"""
fields: dict[str, str] = {}
for field_name, field_info in model.__class__.model_fields.items():
aliases = get_field_aliases(field_info)
if model.model_config.get("populate_by_name"):
aliases.add(field_name)
# Assign aliases to field name in mapping
for alias in aliases:
fields[alias] = field_name
return fields


def validate_patched_model(model: BaseModel, fields: JsonMapping) -> None:
"""Validate that model fields were patched correctly."""
aliases = get_model_aliases(model)

validators: dict[type, Callable[[Any, Any], bool]] = {
list: _validate_lists,
dict: _validate_dicts,
}
for key, value in fields.items():
field_name = key
if key in aliases:
field_name = aliases[key]

try:
nval = getattr(model, field_name)
except AttributeError as e:
raise PatchError(f"Could not get value for {field_name} in patched object.") from e

# Ensure patched value is the one we tried to set
validator = validators.get(
type(nval), # pyright:ignore[reportUnknownArgumentType, reportAny] # dict.get call with unknown type (Any) is fine
_validate_default,
)
if not validator(nval, value):
raise PatchError(
f"Patch failure! Tried to set {key} to {value!r}, but server returned {nval!r}."
)


def _validate_lists(new: list[Json], old: list[Json]) -> bool:
"""Validate that two lists are equal."""
if len(new) != len(old):
return False
return all(x in old for x in new)


def _validate_dicts(new: JsonMapping, old: JsonMapping) -> bool:
"""Validate that two dictionaries are equal."""
if len(new) != len(old):
return False
return all(old.get(k) == v for k, v in new.items())


def _validate_default(new: Json, old: Json) -> bool:
"""Validate that two values are equal."""
return str(new) == str(old)


class FrozenModel(BaseModel):
"""Model for an immutable object."""

Expand Down Expand Up @@ -526,8 +441,28 @@ def refetch(self) -> Self:
raise GetError(f"Could not refresh {self.__class__.__name__} with ID {identifier}.")
return obj

@overload
@deprecated(
"APIMixin.patch() parameter 'validate' is deprecated and will be removed in a future version."
)
def patch(
self, data: JsonMapping, *, params: QueryParams | None = None, validate: bool = ...
) -> Self: ...

@overload
def patch(
self,
data: JsonMapping,
*,
params: QueryParams | None = ...,
) -> Self: ...

def patch(
self, data: JsonMapping, *, params: QueryParams | None = None, validate: bool = False
self,
data: JsonMapping,
*,
params: QueryParams | None = None,
validate: bool | None = None, # noqa: ARG002 # pyright: ignore[reportUnusedParameter]
) -> Self:
"""Patch the object with the given values.

Expand All @@ -539,7 +474,7 @@ def patch(
Args:
data: The values to patch.
params: Optional query parameters.
validate: Whether to validate the patched object.
validate: Whether to validate the response. (Deprecated and ignored)

Returns:
The object refetched from the server.
Expand All @@ -549,11 +484,6 @@ def patch(
MregClient().patch(self.endpoint().with_id(self.id_for_endpoint()), json=data, params=params)
new_object = self.refetch()

if validate:
# __init_subclass__ guarantees we inherit from BaseModel
# but we can't signal this to the type checker, so we cast here.
validate_patched_model(cast(BaseModel, new_object), data) # pyright: ignore[reportInvalidCast] # we know what we are doing here (...?!)

return new_object

def delete(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion mreg_api/models/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Custom field types for Pydantic models.

The types validate to basic types like str, int, etc., but with additional
The types resolve to basic types like str, int, etc., but with additional
validation added to them. The types are used in Pydantic models for consistent
validation of common fields such as hostnames, MAC addresses, etc.

Expand Down
17 changes: 8 additions & 9 deletions mreg_api/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,7 +2000,7 @@ def set_policy(self, policy: NetworkPolicy) -> Self:
Returns:
The updated Network object.
"""
return self.patch({"policy": policy.id}, validate=False)
return self.patch({"policy": policy.id})

def set_max_communities(self, max_communities: int) -> Self:
"""Set the maximum number of communities for the network.
Expand All @@ -2011,23 +2011,23 @@ def set_max_communities(self, max_communities: int) -> Self:
Returns:
The updated Network object.
"""
return self.patch({"max_communities": max_communities}, validate=False)
return self.patch({"max_communities": max_communities})

def unset_policy(self) -> Self:
"""Unset the network policy of the network.

Returns:
The updated Network object.
"""
return self.patch({"policy": None}, validate=False)
return self.patch({"policy": None})

def unset_max_communities(self) -> Self:
"""Unset the maximum number of communities for the network.

Returns:
The updated Network object.
"""
return self.patch({"max_communities": None}, validate=False)
return self.patch({"max_communities": None})


class NetworkPolicyAttribute(FrozenModelWithTimestamps, WithName):
Expand Down Expand Up @@ -2094,14 +2094,14 @@ def patch(
data: JsonMapping,
*,
params: QueryParams | None = None,
validate: bool = False, # noqa: ARG002, E501
validate: bool | None = None, # noqa: ARG002
) -> Self:
"""Patch the community.

Args:
data: The data to patch.
params: Optional query parameters.
validate: Whether to validate the response. (Not implemented)
validate: Whether to validate the response. (Deprecated and ignored)

Returns:
The updated Community object.
Expand Down Expand Up @@ -2286,7 +2286,6 @@ def _patch_attrs(self, attrs: list[NetworkPolicyAttributeValue]) -> None:
"""
self.patch(
{"attributes": [{"name": a.name, "value": a.value} for a in attrs]},
validate=False,
)
# NOTE: can return self.refetch() here if we need to refresh the object

Expand Down Expand Up @@ -2468,7 +2467,7 @@ def disassociate_mac(self) -> IPAddress:
A new IPAddress object fetched from the API with the MAC address removed.
"""
# Model converts empty string to None so we must validate this ourselves.
patched = self.patch(data={"macaddress": ""}, validate=False)
patched = self.patch(data={"macaddress": ""})
if patched.macaddress:
raise PatchError(f"Failed to disassociate MAC address from {self.ipaddress}")
return patched
Expand Down Expand Up @@ -3282,7 +3281,7 @@ def set_contacts(self, contacts: list[str]) -> Host:
Host: Updated Host object.
"""
# Uses non-atomic host update via PATCH to set the contacts list.
return self.patch(data={"contacts": contacts}, validate=False)
return self.patch(data={"contacts": contacts})

def add_contacts(self, contacts: list[str]) -> HostContactModification:
"""Add contacts to the host.
Expand Down
Loading