Skip to content
This repository was archived by the owner on Mar 13, 2026. It is now read-only.

Commit 6950785

Browse files
committed
feat: support json_serializer parameter in create_engine()
SpannerDialect now accepts `json_serializer` and `json_deserializer` kwargs, matching the standard SQLAlchemy convention used by PostgreSQL and other dialects. Previously, passing `json_serializer` to `create_engine()` raised `TypeError` because the dialect's `__init__` did not declare these parameters. The implementation uses a serialize-then-wrap strategy: the user's `json_serializer` function pre-serializes values (handling custom types like `datetime`), then the result is parsed back into a `JsonObject` via `from_str()`. This preserves the existing Spanner client pipeline (`_helpers.py` expects `JsonObject` instances) while allowing custom type handling — without subclassing or modifying `JsonObject` itself. Example usage: engine = create_engine( "spanner:///...", json_serializer=lambda obj: json.dumps(obj, cls=MyEncoder), ) Made-with: Cursor
1 parent 9602646 commit 6950785

2 files changed

Lines changed: 279 additions & 0 deletions

File tree

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,40 @@ def visit_JSON(self, type_, **kw):
820820
return "JSON"
821821

822822

823+
def _make_json_serializer(json_serializer):
824+
"""Build a ``_json_serializer`` callable from a user-supplied function.
825+
826+
SQLAlchemy's ``create_engine(json_serializer=fn)`` convention expects a
827+
callable that replaces ``json.dumps`` entirely — it takes a Python object
828+
and returns a JSON string. The Spanner pipeline is different: it wraps
829+
values in a :class:`JsonObject` first, and serialization happens later in
830+
``_helpers._make_param_value_pb`` via ``obj.serialize()``.
831+
832+
To bridge this gap we use a **serialize-then-wrap** strategy:
833+
834+
1. Call the user's ``json_serializer(value)`` to produce a JSON string
835+
with all custom types (``datetime``, etc.) already handled.
836+
2. Feed that string into ``JsonObject.from_str()`` which parses it back
837+
into a ``JsonObject`` containing only native Python types.
838+
3. When ``_helpers.py`` later calls ``obj.serialize()``, the standard
839+
``json.dumps`` works because no custom types remain.
840+
841+
This avoids subclassing or monkey-patching ``JsonObject`` and requires
842+
no changes to the core ``google-cloud-spanner`` library.
843+
844+
If *json_serializer* is already a ``JsonObject`` subclass (e.g. the
845+
default class-level value), it is returned directly.
846+
"""
847+
if isinstance(json_serializer, type) and issubclass(json_serializer, JsonObject):
848+
return json_serializer
849+
850+
def _factory(value):
851+
json_str = json_serializer(value)
852+
return JsonObject.from_str(json_str)
853+
854+
return _factory
855+
856+
823857
class SpannerDialect(DefaultDialect):
824858
"""Cloud Spanner dialect.
825859
@@ -869,6 +903,13 @@ class SpannerDialect(DefaultDialect):
869903
_json_serializer = JsonObject
870904
_json_deserializer = JsonObject
871905

906+
def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
907+
super().__init__(**kwargs)
908+
if json_serializer is not None:
909+
self._json_serializer = _make_json_serializer(json_serializer)
910+
if json_deserializer is not None:
911+
self._json_deserializer = json_deserializer
912+
872913
@classmethod
873914
def dbapi(cls):
874915
"""A pointer to the Cloud Spanner DB API package.

test/unit/test_json_serializer.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright 2026 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import json
17+
import unittest
18+
19+
from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import (
20+
SpannerDialect,
21+
_make_json_serializer,
22+
)
23+
from google.cloud.spanner_v1.data_types import JsonObject
24+
25+
26+
def _custom_serializer(obj):
27+
"""Sample json_serializer that handles datetime objects."""
28+
return json.dumps(obj, default=_datetime_default)
29+
30+
31+
def _datetime_default(obj):
32+
"""Sample default handler for json.dumps."""
33+
if hasattr(obj, "isoformat"):
34+
return obj.isoformat()
35+
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
36+
37+
38+
class TestMakeJsonSerializer(unittest.TestCase):
39+
"""Tests for _make_json_serializer factory."""
40+
41+
def test_json_object_subclass_returned_directly(self):
42+
result = _make_json_serializer(JsonObject)
43+
assert result is JsonObject
44+
45+
def test_custom_subclass_returned_directly(self):
46+
class MyJsonObject(JsonObject):
47+
pass
48+
49+
result = _make_json_serializer(MyJsonObject)
50+
assert result is MyJsonObject
51+
52+
def test_callable_produces_json_object(self):
53+
factory = _make_json_serializer(_custom_serializer)
54+
obj = factory({"key": "value", "num": 42})
55+
assert isinstance(obj, JsonObject)
56+
parsed = json.loads(obj.serialize())
57+
assert parsed == {"key": "value", "num": 42}
58+
59+
def test_callable_handles_datetime(self):
60+
factory = _make_json_serializer(_custom_serializer)
61+
dt = datetime.datetime(2023, 6, 15)
62+
obj = factory({"ts": dt})
63+
assert isinstance(obj, JsonObject)
64+
parsed = json.loads(obj.serialize())
65+
assert parsed["ts"] == "2023-06-15T00:00:00"
66+
67+
def test_callable_handles_nested_datetimes(self):
68+
factory = _make_json_serializer(_custom_serializer)
69+
obj = factory({
70+
"events": [
71+
{"ts": datetime.datetime(2023, 1, 1), "action": "created"},
72+
{"ts": datetime.datetime(2023, 6, 15), "action": "updated"},
73+
]
74+
})
75+
parsed = json.loads(obj.serialize())
76+
assert parsed["events"][0]["ts"] == "2023-01-01T00:00:00"
77+
assert parsed["events"][1]["ts"] == "2023-06-15T00:00:00"
78+
79+
def test_callable_handles_arrays(self):
80+
factory = _make_json_serializer(_custom_serializer)
81+
obj = factory([1, 2, 3])
82+
assert isinstance(obj, JsonObject)
83+
assert json.loads(obj.serialize()) == [1, 2, 3]
84+
85+
def test_callable_handles_null(self):
86+
factory = _make_json_serializer(lambda v: json.dumps(v))
87+
obj = factory(None)
88+
assert isinstance(obj, JsonObject)
89+
assert obj.serialize() is None
90+
91+
def test_no_custom_types_remain_in_json_object(self):
92+
"""After serialize-then-wrap, the JsonObject contains only native types."""
93+
factory = _make_json_serializer(_custom_serializer)
94+
dt = datetime.datetime(2023, 6, 15, 9, 30, 0)
95+
obj = factory({"ts": dt, "name": "test"})
96+
assert isinstance(obj["ts"], str)
97+
assert obj["ts"] == "2023-06-15T09:30:00"
98+
99+
100+
class TestSpannerDialectJsonSerializer(unittest.TestCase):
101+
"""Tests for json_serializer/json_deserializer support in SpannerDialect."""
102+
103+
def test_default_json_serializer_is_json_object(self):
104+
dialect = SpannerDialect()
105+
assert dialect._json_serializer is JsonObject
106+
107+
def test_default_json_deserializer_is_json_object(self):
108+
dialect = SpannerDialect()
109+
assert dialect._json_deserializer is JsonObject
110+
111+
def test_custom_json_serializer_produces_factory(self):
112+
dialect = SpannerDialect(json_serializer=_custom_serializer)
113+
assert dialect._json_serializer is not JsonObject
114+
obj = dialect._json_serializer({"ts": datetime.datetime(2023, 1, 1)})
115+
assert isinstance(obj, JsonObject)
116+
parsed = json.loads(obj.serialize())
117+
assert parsed["ts"] == "2023-01-01T00:00:00"
118+
119+
def test_json_object_subclass_used_directly(self):
120+
dialect = SpannerDialect(json_serializer=JsonObject)
121+
assert dialect._json_serializer is JsonObject
122+
123+
def test_custom_json_deserializer(self):
124+
custom = lambda x: json.loads(x)
125+
dialect = SpannerDialect(json_deserializer=custom)
126+
assert dialect._json_deserializer is custom
127+
128+
def test_class_attribute_unchanged_after_instance_override(self):
129+
_ = SpannerDialect(json_serializer=_custom_serializer)
130+
assert SpannerDialect._json_serializer is JsonObject
131+
132+
def test_json_serializer_accepted_by_get_cls_kwargs(self):
133+
from sqlalchemy.util import get_cls_kwargs
134+
135+
kwargs = get_cls_kwargs(SpannerDialect)
136+
assert "json_serializer" in kwargs
137+
assert "json_deserializer" in kwargs
138+
139+
140+
class TestEndToEndJsonSerialization(unittest.TestCase):
141+
"""End-to-end: SQLAlchemy JSON bind_processor -> serialize-then-wrap -> JsonObject.
142+
143+
Simulates the full pipeline that occurs during a DML INSERT/UPDATE
144+
with a JSON column containing non-natively-serializable types.
145+
"""
146+
147+
def test_bind_processor_with_custom_serializer(self):
148+
"""Simulate SQLAlchemy's JSON.bind_processor using the dialect."""
149+
from sqlalchemy import types as sa_types
150+
151+
dialect = SpannerDialect(json_serializer=_custom_serializer)
152+
processor = sa_types.JSON().bind_processor(dialect)
153+
154+
dt = datetime.datetime(2023, 6, 15, 9, 30, 0)
155+
value = {"event": "deploy", "timestamp": dt, "count": 42}
156+
157+
result = processor(value)
158+
159+
assert isinstance(result, JsonObject)
160+
serialized = result.serialize()
161+
parsed = json.loads(serialized)
162+
assert parsed["event"] == "deploy"
163+
assert parsed["timestamp"] == "2023-06-15T09:30:00"
164+
assert parsed["count"] == 42
165+
166+
def test_bind_processor_with_nested_datetimes(self):
167+
from sqlalchemy import types as sa_types
168+
169+
dialect = SpannerDialect(json_serializer=_custom_serializer)
170+
processor = sa_types.JSON().bind_processor(dialect)
171+
172+
value = {
173+
"history": [
174+
{"ts": datetime.datetime(2023, 1, 1), "action": "created"},
175+
{"ts": datetime.datetime(2023, 6, 15), "action": "updated"},
176+
]
177+
}
178+
result = processor(value)
179+
parsed = json.loads(result.serialize())
180+
assert parsed["history"][0]["ts"] == "2023-01-01T00:00:00"
181+
assert parsed["history"][1]["ts"] == "2023-06-15T00:00:00"
182+
183+
def test_bind_processor_with_null_default(self):
184+
"""With none_as_null=False (default), None becomes a null JsonObject."""
185+
from sqlalchemy import types as sa_types
186+
187+
dialect = SpannerDialect(json_serializer=_custom_serializer)
188+
processor = sa_types.JSON().bind_processor(dialect)
189+
190+
result = processor(None)
191+
assert isinstance(result, JsonObject)
192+
assert result.serialize() is None
193+
194+
def test_bind_processor_with_null_as_sql_null(self):
195+
"""With none_as_null=True, None becomes Python None (SQL NULL)."""
196+
from sqlalchemy import types as sa_types
197+
198+
dialect = SpannerDialect(json_serializer=_custom_serializer)
199+
processor = sa_types.JSON(none_as_null=True).bind_processor(dialect)
200+
201+
result = processor(None)
202+
assert result is None
203+
204+
def test_spanner_helpers_pipeline(self):
205+
"""Simulate _helpers._make_param_value_pb: isinstance check + bare serialize().
206+
207+
_helpers.py checks isinstance(value, JsonObject) then calls
208+
value.serialize() with no arguments. Verify this works after
209+
the serialize-then-wrap round-trip.
210+
"""
211+
dialect = SpannerDialect(json_serializer=_custom_serializer)
212+
factory = dialect._json_serializer
213+
214+
dt = datetime.datetime(2023, 12, 25, 0, 0, 0)
215+
obj = factory({"holiday": "christmas", "date": dt})
216+
217+
assert isinstance(obj, JsonObject)
218+
serialized = obj.serialize()
219+
assert serialized is not None
220+
parsed = json.loads(serialized)
221+
assert parsed["date"] == "2023-12-25T00:00:00"
222+
223+
def test_default_dialect_unchanged(self):
224+
"""Without json_serializer, the dialect uses plain JsonObject (no round-trip)."""
225+
from sqlalchemy import types as sa_types
226+
227+
dialect = SpannerDialect()
228+
processor = sa_types.JSON().bind_processor(dialect)
229+
230+
value = {"name": "test", "count": 42}
231+
result = processor(value)
232+
assert type(result) is JsonObject
233+
parsed = json.loads(result.serialize())
234+
assert parsed == {"count": 42, "name": "test"}
235+
236+
237+
if __name__ == "__main__":
238+
unittest.main()

0 commit comments

Comments
 (0)