Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions iceaxe/__tests__/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ async def db_connection(docker_postgres):
"demomodela",
"demomodelb",
"jsondemo",
"pydanticjsondemo",
"complextypedemo",
]
known_types = ["statusenum", "employeestatus"]
Expand Down
21 changes: 20 additions & 1 deletion iceaxe/__tests__/schemas/test_db_memory_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from uuid import UUID

import pytest
from pydantic import create_model
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo

from iceaxe import Field, TableBase
Expand Down Expand Up @@ -1569,3 +1569,22 @@ class TestModel(TableBase):
id_column = next(c for c in columns if c.column_name == "id")
assert id_column.column_type == ColumnType.INTEGER
assert not id_column.nullable


def test_pydantic_model_json_field(clear_all_database_objects):
    """A Pydantic model field marked ``is_json=True`` maps to a non-nullable JSON column."""

    class SettingsModel(BaseModel):
        theme: str
        notifications: bool

    class TestModel(TableBase):
        id: int = Field(primary_key=True)
        settings: SettingsModel = Field(is_json=True)

    serializer = DatabaseMemorySerializer()
    generated_objects = list(serializer.delegate([TestModel]))

    # Index the emitted column objects by name for direct lookup.
    columns_by_name = {
        obj.column_name: obj
        for obj, _ in generated_objects
        if isinstance(obj, DBColumn)
    }
    settings_column = columns_by_name["settings"]

    assert settings_column.column_type == ColumnType.JSON
    assert not settings_column.nullable
118 changes: 118 additions & 0 deletions iceaxe/__tests__/test_session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from contextlib import asynccontextmanager
from enum import StrEnum
from json import dumps as json_dumps, loads as json_loads
from typing import Any, Type
from unittest.mock import AsyncMock, patch

import asyncpg
import pytest
from asyncpg.connection import Connection
from pydantic import BaseModel

from iceaxe.__tests__.conf_models import (
ArtifactDemo,
Expand Down Expand Up @@ -544,6 +546,122 @@ async def test_refresh(db_connection: DBConnection):
assert user.name == "Jane Doe"


@pytest.mark.asyncio
async def test_pydantic_json_round_trip(db_connection: DBConnection):
    """Insert, select, update, and refresh a row whose column is a Pydantic model stored as JSON."""

    class Preferences(BaseModel):
        theme: str
        notifications: bool

    class PydanticJsonDemo(TableBase):
        id: int | None = Field(primary_key=True, default=None)
        payload: Preferences = Field(is_json=True)

    await db_connection.conn.execute("DROP TABLE IF EXISTS pydanticjsondemo")
    await create_all(db_connection, [PydanticJsonDemo])

    demo = PydanticJsonDemo(
        payload=Preferences(theme="dark", notifications=True),
    )
    await db_connection.insert([demo])

    # Full-row select should hydrate the JSON column back into the model type.
    rows = await db_connection.exec(
        QueryBuilder().select(PydanticJsonDemo).where(PydanticJsonDemo.id == demo.id)
    )
    assert len(rows) == 1
    assert rows[0].payload == Preferences(theme="dark", notifications=True)

    # Selecting just the column should also yield the typed model, not raw JSON.
    payload_only = await db_connection.exec(
        QueryBuilder()
        .select(PydanticJsonDemo.payload)
        .where(PydanticJsonDemo.id == demo.id)
    )
    assert payload_only == [Preferences(theme="dark", notifications=True)]

    # An ORM update must re-serialize the replacement model.
    demo.payload = Preferences(theme="light", notifications=False)
    await db_connection.update([demo])

    after_update = await db_connection.exec(
        QueryBuilder().select(PydanticJsonDemo).where(PydanticJsonDemo.id == demo.id)
    )
    assert after_update == [
        PydanticJsonDemo(
            id=demo.id,
            payload=Preferences(theme="light", notifications=False),
        )
    ]

    # Mutate the row with raw SQL, then refresh() should re-hydrate the model.
    await db_connection.conn.execute(
        "UPDATE pydanticjsondemo SET payload = $1::json WHERE id = $2",
        json_dumps({"theme": "system", "notifications": True}),
        demo.id,
    )
    await db_connection.refresh([demo])
    assert demo.payload == Preferences(theme="system", notifications=True)


@pytest.mark.asyncio
async def test_pydantic_json_serialization_to_database(db_connection: DBConnection):
    """The value stored in Postgres is the plain JSON serialization of the Pydantic model."""

    class Preferences(BaseModel):
        theme: str
        notifications: bool

    class PydanticJsonDemo(TableBase):
        id: int | None = Field(primary_key=True, default=None)
        payload: Preferences = Field(is_json=True)

    await db_connection.conn.execute("DROP TABLE IF EXISTS pydanticjsondemo")
    await create_all(db_connection, [PydanticJsonDemo])

    record = PydanticJsonDemo(
        payload=Preferences(theme="dark", notifications=True),
    )
    await db_connection.insert([record])

    # Read the raw JSON text back, bypassing the ORM deserialization path.
    raw_row = await db_connection.conn.fetchrow(
        "SELECT payload::text AS payload FROM pydanticjsondemo WHERE id = $1",
        record.id,
    )
    assert raw_row is not None
    expected_payload = {
        "theme": "dark",
        "notifications": True,
    }
    assert json_loads(raw_row["payload"]) == expected_payload


@pytest.mark.asyncio
async def test_pydantic_json_deserialization_from_database(
    db_connection: DBConnection,
):
    """Raw JSON seeded via SQL is hydrated into the Pydantic model on select."""

    class Preferences(BaseModel):
        theme: str
        notifications: bool

    class PydanticJsonDemo(TableBase):
        id: int | None = Field(primary_key=True, default=None)
        payload: Preferences = Field(is_json=True)

    await db_connection.conn.execute("DROP TABLE IF EXISTS pydanticjsondemo")
    await create_all(db_connection, [PydanticJsonDemo])

    # Seed the row with raw JSON, bypassing the ORM serialization path entirely.
    seeded_row = await db_connection.conn.fetchrow(
        "INSERT INTO pydanticjsondemo (payload) VALUES ($1::json) RETURNING id",
        json_dumps({"theme": "system", "notifications": False}),
    )
    assert seeded_row is not None

    fetched = await db_connection.exec(
        QueryBuilder()
        .select(PydanticJsonDemo)
        .where(PydanticJsonDemo.id == seeded_row["id"])
    )
    assert fetched == [
        PydanticJsonDemo(
            id=seeded_row["id"],
            payload=Preferences(theme="system", notifications=False),
        )
    ]


@pytest.mark.asyncio
async def test_get(db_connection: DBConnection):
"""
Expand Down
19 changes: 17 additions & 2 deletions iceaxe/field.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from json import dumps as json_dumps
from json import dumps as json_dumps, loads as json_loads
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -12,7 +12,7 @@
cast,
)

from pydantic import Field as PydanticField
from pydantic import Field as PydanticField, TypeAdapter
from pydantic.fields import FieldInfo, _FieldInfoInputs
from pydantic_core import PydanticUndefined

Expand Down Expand Up @@ -105,6 +105,8 @@ class DBFieldInfo(FieldInfo):
When set, this type takes precedence over automatic type inference.
"""

_json_type_adapter: TypeAdapter[Any] | None = None

def __init__(self, **kwargs: Unpack[DBFieldInputs]):
"""
Initialize a new DBFieldInfo instance with the given field configuration.
Expand Down Expand Up @@ -163,6 +165,19 @@ def to_db_value(self, value: Any):
return json_dumps(value)
return value

def from_db_value(self, value: Any):
if not self.is_json or value is None:
return value

parsed_value = json_loads(value) if isinstance(value, str) else value
if self.annotation is None:
return parsed_value

if self._json_type_adapter is None:
self._json_type_adapter = TypeAdapter(self.annotation)

return self._json_type_adapter.validate_python(parsed_value)


def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # type: ignore
"""
Expand Down
3 changes: 3 additions & 0 deletions iceaxe/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from functools import wraps
from typing import Any, Generic, Literal, Type, TypeVar, TypeVarTuple, cast, overload

from pydantic import BaseModel

from iceaxe.alias_values import Alias
from iceaxe.base import (
DBFieldClassDefinition,
Expand Down Expand Up @@ -46,6 +48,7 @@
| PRIMITIVE_WRAPPER_TYPES
| DATE_TYPES
| JSON_WRAPPER_FALLBACK
| BaseModel
| None
)

Expand Down
11 changes: 11 additions & 0 deletions iceaxe/schemas/db_memory_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Generator, Sequence, Type, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel
from pydantic_core import PydanticUndefined

from iceaxe.base import (
Expand Down Expand Up @@ -489,6 +490,16 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase]
)
else:
raise ValueError(f"Unsupported date type: {annotation}")
elif is_type_compatible(annotation, BaseModel):
if info.is_json:
return TypeDeclarationResponse(
primitive_type=ColumnType.JSON,
)
else:
raise ValueError(
f"Pydantic model fields must have Field(is_json=True) specified: {annotation}\n"
f"Column: {table.__name__}.{key}"
)
elif is_type_compatible(annotation, JSON_WRAPPER_FALLBACK):
if info.is_json:
return TypeDeclarationResponse(
Expand Down
19 changes: 10 additions & 9 deletions iceaxe/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
from contextlib import asynccontextmanager
from inspect import isclass
from json import loads as json_loads
from math import ceil
from typing import (
Any,
Expand Down Expand Up @@ -558,13 +557,9 @@ async def upsert(
processed_values = []
for field in returning_fields_cols:
value = row[field.key]
if (
value is not None
and field.root_model.model_fields[
field.key
].is_json
):
value = json_loads(value)
value = field.root_model.model_fields[
field.key
].from_db_value(value)
processed_values.append(value)
results.append(tuple(processed_values))
else:
Expand Down Expand Up @@ -747,7 +742,13 @@ async def refresh(self, objects: Sequence[TableBase]):
if obj_id in results:
# Update field-by-field
for field in fields:
setattr(obj, field, results[obj_id][field])
setattr(
obj,
field,
model.model_fields[field].from_db_value(
results[obj_id][field]
),
)
else:
LOGGER.error(
f"Object {obj} with primary key {obj_id} not found in database"
Expand Down
5 changes: 3 additions & 2 deletions iceaxe/session_optimized.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ from typing import Any, List, Tuple
from iceaxe.base import TableBase
from iceaxe.queries import FunctionMetadata
from iceaxe.alias_values import Alias
from json import loads as json_loads
from cpython.ref cimport PyObject
from cpython.object cimport PyObject_GetItem
from libc.stdlib cimport malloc, free
Expand Down Expand Up @@ -132,7 +131,7 @@ cdef list process_values(
if field_value is not None:
all_none = False
if fields[j][k].is_json:
field_value = json_loads(field_value)
field_value = select_raw.model_fields[field_name].from_db_value(field_value)

obj_dict[field_name] = field_value

Expand All @@ -153,6 +152,8 @@ cdef list process_values(
item = value[f"{table_name}_{column_name}"]
except KeyError:
raise KeyError(f"Key '{table_name}_{column_name}' not found in value.")
if item is not None and select_raw.field_definition.is_json:
item = select_raw.field_definition.from_db_value(item)
result_value[j] = <PyObject*>item
Py_INCREF(item)

Expand Down
Loading