Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions iceaxe/__tests__/schemas/test_db_memory_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,74 @@ class ExampleModel2(TableBase):
]


def test_simple_uuid_subclass_column_assignment(clear_all_database_objects):
class CustomUUID(UUID):
pass

class ExampleModel(TableBase):
id: CustomUUID = Field(primary_key=True)
values: list[CustomUUID]

migrator = DatabaseMemorySerializer()
db_objects = list(migrator.delegate([ExampleModel]))
assert db_objects == [
(
DBTable(table_name="examplemodel"),
[],
),
(
DBColumn(
table_name="examplemodel",
column_name="id",
column_type=ColumnType.UUID,
column_is_list=False,
nullable=False,
),
[
DBTable(table_name="examplemodel"),
],
),
(
DBColumn(
table_name="examplemodel",
column_name="values",
column_type=ColumnType.UUID,
column_is_list=True,
nullable=False,
),
[
DBTable(table_name="examplemodel"),
],
),
(
DBConstraint(
table_name="examplemodel",
constraint_name="examplemodel_pkey",
columns=frozenset({"id"}),
constraint_type=ConstraintType.PRIMARY_KEY,
foreign_key_constraint=None,
),
[
DBTable(table_name="examplemodel"),
DBColumn(
table_name="examplemodel",
column_name="id",
column_type=ColumnType.UUID,
column_is_list=False,
nullable=False,
),
DBColumn(
table_name="examplemodel",
column_name="values",
column_type=ColumnType.UUID,
column_is_list=True,
nullable=False,
),
],
),
]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"field_name, annotation, field_info, expected_db_objects",
Expand Down
26 changes: 26 additions & 0 deletions iceaxe/__tests__/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Annotated, Any, Generic, TypeVar, cast
from uuid import UUID

from iceaxe.base import (
DBModelMetaclass,
Expand Down Expand Up @@ -92,3 +93,28 @@ class Event(TableBase, autodetect=False):
assert field.annotation == dict[str, Any] | None
assert field.default is None
assert field.is_json is True


def test_model_fields_with_simple_uuid_subclass():
class CustomUUID(UUID):
pass

class Event(TableBase, autodetect=False):
id: CustomUUID
maybe_id: CustomUUID | None = None
ids: list[CustomUUID]

raw_uuid = UUID("12345678-1234-5678-1234-567812345678")
event = cast(
Any,
Event,
)(
id=raw_uuid,
maybe_id=str(raw_uuid),
ids=[raw_uuid, str(raw_uuid)],
)

assert event.model_fields["id"].annotation == CustomUUID
assert isinstance(event.id, CustomUUID)
assert isinstance(event.maybe_id, CustomUUID)
assert all(isinstance(value, CustomUUID) for value in event.ids)
39 changes: 39 additions & 0 deletions iceaxe/__tests__/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from json import dumps as json_dumps, loads as json_loads
from typing import Any, Type
from unittest.mock import AsyncMock, patch
from uuid import UUID

import asyncpg
import pytest
Expand Down Expand Up @@ -66,6 +67,44 @@ async def test_db_connection_update(db_connection: DBConnection):
assert user.get_modified_attributes() == {}


@pytest.mark.asyncio
async def test_db_connection_uuid_subclass_round_trip(
db_connection: DBConnection,
clear_all_database_objects,
):
class CustomUUID(UUID):
pass

class UUIDSubclassDemo(TableBase):
id: CustomUUID = Field(primary_key=True)

await db_connection.conn.execute("DROP TABLE IF EXISTS uuidsubclassdemo")
await create_all(db_connection, [UUIDSubclassDemo])

row_id = CustomUUID("12345678-1234-5678-1234-567812345678")
demo = UUIDSubclassDemo(id=row_id)
await db_connection.insert([demo])

raw_row = await db_connection.conn.fetchrow(
"SELECT id FROM uuidsubclassdemo WHERE id = $1",
UUID(str(row_id)),
)
assert raw_row is not None
assert isinstance(raw_row["id"], UUID)

result = await db_connection.exec(
QueryBuilder().select(UUIDSubclassDemo).where(UUIDSubclassDemo.id == row_id)
)
assert len(result) == 1
assert isinstance(result[0].id, CustomUUID)

selected_ids = await db_connection.exec(
QueryBuilder().select(UUIDSubclassDemo.id).where(UUIDSubclassDemo.id == row_id)
)
assert selected_ids == [row_id]
assert isinstance(selected_ids[0], CustomUUID)


@pytest.mark.asyncio
async def test_db_obj_mixin_track_modifications():
user = UserDemo(name="John Doe", email="john@example.com")
Expand Down
16 changes: 16 additions & 0 deletions iceaxe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from pydantic.main import _model_construction
from pydantic_core import PydanticUndefined

from iceaxe.custom_typehints import wrap_simple_subclass_annotation
from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field
from iceaxe.typing import transform_typehint


@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,))
Expand Down Expand Up @@ -51,6 +53,20 @@ def __new__(mcs, name, bases, namespace, **kwargs):
Create a new database model class with proper field tracking.
Handles registration of the model and processes any table-specific arguments.
"""
# Pydantic must see these normalized annotations up front; otherwise simple
# subclasses like `CustomUUID(UUID)` are treated as unknown types.
namespace = dict(namespace)
raw_annotations = namespace.get("__annotations__", {})
if raw_annotations:
namespace["__annotations__"] = {
key: (
transform_typehint(annotation, wrap_simple_subclass_annotation)
if not isinstance(annotation, str)
else annotation
)
for key, annotation in raw_annotations.items()
}

raw_kwargs = {**kwargs}

mcs.is_constructing = True
Expand Down
Loading
Loading