Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions scim2_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,42 @@ def handle_extension(resource: Resource, scim_name: str) -> tuple[BaseModel, str
return resource, scim_name


def model_validate_from_dict(field_root_type: type[BaseModel], value: dict) -> Any:
"""Workaround for some of the "special" requirements for MS Entra, mixing display and displayName in some cases."""
if (
"display" not in value
and "display" in field_root_type.model_fields
and "displayName" in value
):
value["display"] = value["displayName"]
del value["displayName"]
return field_root_type.model_validate(value)
def parse_value(field_root_type: type, value: Any) -> Any:
"""Parse a PATCH value according to the target field root type."""
if isinstance(value, dict):
if not hasattr(field_root_type, "model_fields"):
raise TypeError

# Work around mixed display/displayName payloads emitted by MS Entra.
if (
"display" not in value
and "display" in field_root_type.model_fields
and "displayName" in value
):
value = value.copy()
value["display"] = value["displayName"]
del value["displayName"]
return field_root_type.model_validate(value)

if field_root_type is bool and isinstance(value, str):
return not value.lower() == "false"

if field_root_type is datetime.datetime and isinstance(value, str):
# ISO 8601 datetime format (notably with the Z suffix) are only supported from Python 3.11
if sys.version_info < (3, 11): # pragma: no cover
return datetime.datetime.fromisoformat(re.sub(r"Z$", "+00:00", value))
return datetime.datetime.fromisoformat(value)

if field_root_type is EmailStr and isinstance(value, str):
return value

if hasattr(field_root_type, "model_fields"):
primary_value = get_by_alias(field_root_type, "value", True)
if primary_value is not None:
return field_root_type(value=value)
raise TypeError

return field_root_type(value)


def parse_new_value(model: BaseModel, attribute_name: str, value: Any) -> Any:
Expand All @@ -164,31 +190,10 @@ def parse_new_value(model: BaseModel, attribute_name: str, value: Any) -> Any:
"""
field_root_type = model.get_field_root_type(attribute_name)
try:
if isinstance(value, dict):
new_value = model_validate_from_dict(field_root_type, value)
elif isinstance(value, list):
new_value = [model_validate_from_dict(field_root_type, v) for v in value]
if isinstance(value, list):
new_value = [parse_value(field_root_type, v) for v in value]
else:
if field_root_type is bool and isinstance(value, str):
new_value = not value.lower() == "false"
elif field_root_type is datetime.datetime and isinstance(value, str):
# ISO 8601 datetime format (notably with the Z suffix) are only supported from Python 3.11
if sys.version_info < (3, 11): # pragma: no cover
new_value = datetime.datetime.fromisoformat(
re.sub(r"Z$", "+00:00", value)
)
else:
new_value = datetime.datetime.fromisoformat(value)
elif field_root_type is EmailStr and isinstance(value, str):
new_value = value
elif hasattr(field_root_type, "model_fields"):
primary_value = get_by_alias(field_root_type, "value", True)
if primary_value is not None:
new_value = field_root_type(value=value)
else:
raise TypeError
else:
new_value = field_root_type(value)
new_value = parse_value(field_root_type, value)
except (AttributeError, TypeError, ValueError, ValidationError) as e:
raise InvalidValueException() from e
return new_value
23 changes: 23 additions & 0 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from scim2_models import URN
from scim2_models import MutabilityException
from scim2_models import PatchOperation
from scim2_models.resources.resource import Resource

from scim2_server.operators import patch_resource

Expand Down Expand Up @@ -200,3 +202,24 @@ def test_patch_operation_add_multi_valued(self, provider):
},
],
}

def test_patch_replace_multivalued_primitive_attribute(self):
"""Replace a multi-valued primitive attribute."""

class MyResource(Resource):
__schema__ = URN("urn:example:schemas:MyResource")

tags: list[str] | None = None

resource = MyResource(id="123")

patch_resource(
resource,
PatchOperation(
op=PatchOperation.Op.replace_,
path="tags",
value=["tag1", "tag2"],
),
)

assert resource.tags == ["tag1", "tag2"]
Loading