diff --git a/dbzero/dbzero/__init__.py b/dbzero/dbzero/__init__.py index 31ec18fa..3a391c0b 100644 --- a/dbzero/dbzero/__init__.py +++ b/dbzero/dbzero/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) 2025 DBZero Software sp. z o.o. from .dbzero import * -from .dbzero import _check_interned, _init_data_filter, _init_data_masking +from .dbzero import _check_interned, _in_read_only, _init_data_filter, _init_data_masking from .memo import * from .enum import * from .fast_query import * diff --git a/dbzero/dbzero/atomic.py b/dbzero/dbzero/atomic.py index 0b1994bc..9f5c914b 100644 --- a/dbzero/dbzero/atomic.py +++ b/dbzero/dbzero/atomic.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional from .interfaces import Memo -from .dbzero import begin_atomic, begin_async_atomic, assign +from .dbzero import _in_read_only, begin_atomic, begin_async_atomic, assign _async_atomic_locks: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock] = weakref.WeakKeyDictionary() @@ -30,6 +30,17 @@ class _AsyncAtomicState: ) +class _NoOpAtomicContext: + def close(self): + pass + + def cancel(self): + pass + + +_NO_OP_ATOMIC_CONTEXT = _NoOpAtomicContext() + + def _current_async_task() -> Optional[asyncio.Task]: try: return asyncio.current_task() @@ -68,7 +79,7 @@ def __exit__(self, exc_type, exc_value, traceback): def begin(self): """Begin the atomic context""" if self.__ctx is None: - self.__ctx = begin_atomic() + self.__ctx = _NO_OP_ATOMIC_CONTEXT if _in_read_only() else begin_atomic() def close(self): """Close the atomic context, staging the changes for commit""" @@ -155,6 +166,10 @@ def __init__(self): self.__release_lock = False async def __aenter__(self) -> AsyncAtomicManager: + if _in_read_only(): + self.__ctx = _NO_OP_ATOMIC_CONTEXT + return self + task = _current_async_task() if task is None: raise RuntimeError("db0.async_atomic requires a running asyncio task") diff --git a/dbzero/dbzero/dbzero.pyi b/dbzero/dbzero/dbzero.pyi index 9e7b4747..600cfd74 100644 --- a/dbzero/dbzero/dbzero.pyi +++ b/dbzero/dbzero/dbzero.pyi @@ -15,6 +15,14 @@ def read_only() -> ContextManager[Any]: """Open a context manager that rejects dbzero mutations in its block.""" ... +def _in_read_only() -> bool: + """Return whether the current execution is inside a dbzero read-only block.""" + ... + +def in_read_only() -> bool: + """Return whether the current execution is inside a dbzero read-only block.""" + ... + def open(prefix_name: str, open_mode: str = "rw", **kwargs: Any) -> None: """Open a data prefix and set it as the current working context. @@ -970,13 +978,13 @@ def bytearray(source: Union[bytes, Iterable[int]] = b'', /) -> ByteArrayObject: # Tag and query functions -def tags(*objects: Memo) -> ObjectTagManager: +def tags(*objects: Union[Memo, QueryObject]) -> ObjectTagManager: """Get a tag manager interface for given Memo objects. Parameters ---------- - *objects : Memo - One or more Memo objects to manage tags for. + *objects : Memo or QueryObject + One or more Memo objects or query result sets to manage tags for. Returns ------- @@ -1001,6 +1009,11 @@ def tags(*objects: Memo) -> ObjectTagManager: >>> dbzero.tags(product1, product2).add("sale") >>> dbzero.tags(product1, product2).remove("sale") + Batch operations on query results: + + >>> dbzero.tags(dbzero.find("token-a", "token-b")).add("token-c") + >>> dbzero.tags(dbzero.find("token-a")).remove("token-a") + Chain operations: >>> dbzero.tags(obj).add("tag1").remove("old-tag").add("tag2") diff --git a/dbzero/dbzero/interfaces.py b/dbzero/dbzero/interfaces.py index dccff08a..007cbb2e 100644 --- a/dbzero/dbzero/interfaces.py +++ b/dbzero/dbzero/interfaces.py @@ -143,7 +143,7 @@ class ByteArrayObject(bytearray): ... class ObjectTagManager: - """Manages tags of one or more Memo instances.""" + """Manages tags of one or more Memo instances or query result sets.""" def add(self, *tag: Union[Tag, Iterable[Tag]]) -> None: """Add one or more tags to the managed objects. @@ -286,4 +286,4 @@ def get_state_num(self) -> int: int State number of a snapshot. """ - ... \ No newline at end of file + ... diff --git a/dbzero/dbzero/read_only.py b/dbzero/dbzero/read_only.py index a9947137..9b039acc 100644 --- a/dbzero/dbzero/read_only.py +++ b/dbzero/dbzero/read_only.py @@ -3,7 +3,7 @@ from __future__ import annotations -from .dbzero import begin_read_only +from .dbzero import _in_read_only, begin_read_only class ReadOnlyManager: @@ -25,3 +25,8 @@ def __exit__(self, _exc_type, _exc_value, _traceback): def read_only() -> ReadOnlyManager: """Open a context manager that rejects dbzero mutations in its block.""" return ReadOnlyManager() + + +def in_read_only() -> bool: + """Return whether the current execution is inside a dbzero read-only block.""" + return _in_read_only() diff --git a/python_tests/test_find.py b/python_tests/test_find.py index a8acff1b..eb98dea6 100644 --- a/python_tests/test_find.py +++ b/python_tests/test_find.py @@ -232,6 +232,69 @@ def test_remove_tags_then_find_typed(db0_fixture): db0.tags(objects[4], objects[2]).remove("one") assert len(list(db0.find(MemoTestClass, "one"))) == 0 + +def test_add_tags_to_query_result(db0_fixture): + objects = [MemoTestClass(i) for i in range(6)] + db0.tags(objects[0]).add(["token-a", "token-b"]) + db0.tags(objects[1]).add("token-a") + db0.tags(objects[2]).add(["token-a", "token-b"]) + db0.tags(objects[3]).add("token-b") + + db0.tags(db0.find("token-a", "token-b")).add("token-c") + + assert {item.value for item in db0.find("token-c")} == {0, 2} + + +def test_remove_tags_from_query_result(db0_fixture): + objects = [MemoTestClass(i) for i in range(5)] + db0.tags(objects[0]).add(["token-a", "keep"]) + db0.tags(objects[1]).add("token-a") + db0.tags(objects[2]).add(["token-a", "other"]) + db0.tags(objects[3]).add("other") + + db0.tags(db0.find("token-a")).remove("token-a") + + assert list(db0.find("token-a")) == [] + assert {item.value for item in db0.find("keep")} == {0} + assert {item.value for item in db0.find("other")} == {2, 3} + + +def test_query_tag_target_can_be_empty_or_mixed_with_memo_target(db0_fixture): + objects = [MemoTestClass(i) for i in range(4)] + db0.tags(objects[0]).add("source") + db0.tags(objects[1]).add("source") + + db0.tags(db0.find("missing")).add("batch") + assert list(db0.find("batch")) == [] + + db0.tags(objects[2], db0.find("source")).add(["batch", "extra"]) + + assert {item.value for item in db0.find("batch")} == {0, 1, 2} + assert {item.value for item in db0.find("extra")} == {0, 1, 2} + + +def test_batched_query_tags_rejected_in_read_only_context(db0_fixture): + obj = MemoTestClass(1) + db0.tags(obj).add("source") + + with db0.read_only(): + with pytest.raises(RuntimeError, match="read_only|read-only"): + db0.tags(db0.find("source")).add("blocked") + + assert list(db0.find("blocked")) == [] + + +def test_batched_query_tags_reject_snapshot_query_target(db0_fixture): + obj = MemoTestClass(1) + db0.tags(obj).add("source") + db0.commit() + + with db0.snapshot() as snap: + with pytest.raises(RuntimeError, match="read-only"): + db0.tags(snap.find("source")).remove("source") + + assert list(db0.find("source")) == [obj] + def test_query_by_non_existing_tag(db0_fixture): assert len(list(db0.find("tag1"))) == 0 diff --git a/python_tests/test_init.py b/python_tests/test_init.py index eb3acb50..3965c205 100644 --- a/python_tests/test_init.py +++ b/python_tests/test_init.py @@ -78,6 +78,7 @@ def guarded_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", guarded_import) db0.init(DB0_DIR) + assert db0.in_read_only() is False def test_init_propagates_rpc_init_error_and_allows_recovery(db0_fixture, monkeypatch): diff --git a/python_tests/test_read_only.py b/python_tests/test_read_only.py index 7f8ece1a..dbbd7ee1 100644 --- a/python_tests/test_read_only.py +++ b/python_tests/test_read_only.py @@ -2,6 +2,7 @@ # Copyright (c) 2025 DBZero Software sp. z o.o. import asyncio +import importlib import threading import pytest @@ -23,6 +24,18 @@ def test_read_only_allows_reads(db0_fixture): assert db0.fetch(db0.uuid(obj)) == obj +def test_in_read_only_public_predicate(db0_fixture): + assert not db0.in_read_only() + + with db0.read_only(): + assert db0.in_read_only() + with db0.read_only(): + assert db0.in_read_only() + assert db0.in_read_only() + + assert not db0.in_read_only() + + def test_read_only_rejects_memo_field_assignment(db0_fixture): obj = MemoTestClass(123) @@ -125,14 +138,22 @@ def test_read_only_inside_atomic_rejects_mutation(db0_fixture): assert obj.value == 123 -def test_atomic_inside_read_only_starts_normally_but_mutation_is_rejected(db0_fixture): +def test_atomic_inside_read_only_is_optimized_out(db0_fixture, monkeypatch): obj = MemoTestClass(123) + atomic_module = importlib.import_module("dbzero.atomic") with db0.read_only(): + assert db0._in_read_only() + monkeypatch.setattr( + atomic_module, + "begin_atomic", + lambda: pytest.fail("begin_atomic should not run inside read_only"), + ) with db0.atomic(): with pytest.raises(RuntimeError, match="read_only.*mutation|mutation.*read_only"): obj.value = 456 + assert not db0._in_read_only() obj.value = 789 assert obj.value == 789 @@ -292,6 +313,24 @@ async def mutate_in_child_task(): assert obj.value == 789 +async def test_read_only_child_task_stale_context_is_not_reactivated_by_unrelated_context(db0_fixture): + obj = MemoTestClass(123) + child_can_run = asyncio.Event() + + async def mutate_in_child_task(): + await child_can_run.wait() + obj.value = 789 + + with db0.read_only(): + child_task = asyncio.create_task(mutate_in_child_task()) + + with db0.read_only(): + child_can_run.set() + await asyncio.wait_for(child_task, timeout=5) + + assert obj.value == 789 + + def test_read_only_fast_overhead_paths(db0_fixture): obj = MemoTestClass(0) iterations = 100 diff --git a/src/dbzero/bindings/python/PyObjectTagManager.cpp b/src/dbzero/bindings/python/PyObjectTagManager.cpp index 9fd6d8f8..6cdf5ac8 100644 --- a/src/dbzero/bindings/python/PyObjectTagManager.cpp +++ b/src/dbzero/bindings/python/PyObjectTagManager.cpp @@ -5,6 +5,9 @@ #include "Memo.hpp" #include "PyInternalAPI.hpp" #include "PyToolkit.hpp" +#include "iter/PyObjectIterable.hpp" +#include +#include namespace db0::python @@ -102,10 +105,19 @@ namespace db0::python PyObjectTagManager *tryMakeObjectTagManager(PyObject *, PyObject *const *args, Py_ssize_t nargs) { - // all arguments must be Memo objects + std::vector memo_args; + std::vector > query_targets; + memo_args.reserve(nargs); + query_targets.reserve(nargs); + for (Py_ssize_t i = 0; i < nargs; ++i) { + if (PyObjectIterable_Check(args[i])) { + auto *query = reinterpret_cast(args[i]); + query_targets.push_back(query->getSharedPtr()); + continue; + } if (!PyToolkit::isAnyMemoObject(args[i])) { - THROWF(db0::InputException) << "All arguments must be dbzero memo objects"; + THROWF(db0::InputException) << "All arguments must be dbzero memo objects or object queries"; } if (PyMemo_Check(args[i])) { auto *memoObject = reinterpret_cast(args[i]); @@ -118,10 +130,16 @@ namespace db0::python auto materialized = Py_OWN(getMaterializedMemoObject(memoObject)); } } + memo_args.push_back(args[i]); } auto tags_obj = Py_OWN(PyObjectTagManager_new(&PyObjectTagManagerType, NULL, NULL)); - ObjectTagManager::makeNew(&tags_obj->modifyExt(), args, nargs); + ObjectTagManager::makeNew( + &tags_obj->modifyExt(), + memo_args.data(), + memo_args.size(), + std::move(query_targets) + ); return tags_obj.steal(); } diff --git a/src/dbzero/bindings/python/PyReadOnly.cpp b/src/dbzero/bindings/python/PyReadOnly.cpp index 01d63e75..e9a9fe4c 100644 --- a/src/dbzero/bindings/python/PyReadOnly.cpp +++ b/src/dbzero/bindings/python/PyReadOnly.cpp @@ -4,6 +4,7 @@ #include "PyReadOnly.hpp" #include "PyInternalAPI.hpp" +#include #include namespace db0::python @@ -13,54 +14,170 @@ namespace db0::python namespace { PyObject *s_read_only_depth_var = nullptr; + std::atomic_uint64_t s_active_read_only_generation = 0; thread_local std::uint64_t s_read_only_generation = 0; + constexpr const char *READ_ONLY_STATE_CAPSULE_NAME = "dbzero.ReadOnlyState"; + + struct ReadOnlyState + { + std::atomic_bool active = true; + }; struct ReadOnlyDepthCache { std::uint64_t thread_state_id = 0; std::uint64_t context_version = 0; std::uint64_t generation = 0; + std::uint64_t active_generation = 0; unsigned int depth = 0; bool valid = false; }; thread_local ReadOnlyDepthCache s_depth_cache; - unsigned int readOnlyDepthFromPythonContext() + void incrementActiveReadOnlyGeneration() + { + s_active_read_only_generation.fetch_add(1, std::memory_order_release); + } + + void destroyReadOnlyStateCapsule(PyObject *capsule) + { + if (!PyCapsule_IsValid(capsule, READ_ONLY_STATE_CAPSULE_NAME)) { + return; + } + auto *state = static_cast( + PyCapsule_GetPointer(capsule, READ_ONLY_STATE_CAPSULE_NAME) + ); + delete state; + } + + PyObject *makeReadOnlyStateCapsule() + { + auto *state = new ReadOnlyState(); + auto *capsule = PyCapsule_New( + state, + READ_ONLY_STATE_CAPSULE_NAME, + destroyReadOnlyStateCapsule + ); + if (!capsule) { + delete state; + return nullptr; + } + return capsule; + } + + ReadOnlyState *getReadOnlyState(PyObject *object) + { + if (!PyCapsule_IsValid(object, READ_ONLY_STATE_CAPSULE_NAME)) { + return nullptr; + } + return static_cast( + PyCapsule_GetPointer(object, READ_ONLY_STATE_CAPSULE_NAME) + ); + } + + bool readOnlyStateIsActive(PyObject *object) + { + auto *state = getReadOnlyState(object); + return state && state->active.load(std::memory_order_acquire); + } + + unsigned int activeReadOnlyDepthFromPythonContext() { if (!s_read_only_depth_var) { return 0; } - auto *thread_state = PyThreadState_Get(); - if (s_depth_cache.valid - && s_depth_cache.thread_state_id == thread_state->id - && s_depth_cache.context_version == thread_state->context_ver - && s_depth_cache.generation == s_read_only_generation) { - return s_depth_cache.depth; + PyObject *py_states = nullptr; + if (PyContextVar_Get(s_read_only_depth_var, NULL, &py_states) < 0) { + PyErr_Clear(); + return 0; } - PyObject *py_depth = nullptr; - if (PyContextVar_Get(s_read_only_depth_var, NULL, &py_depth) < 0) { - PyErr_Clear(); + if (!py_states) { return 0; } unsigned int depth = 0; - if (py_depth) { - auto long_depth = PyLong_AsUnsignedLong(py_depth); - Py_DECREF(py_depth); - if (PyErr_Occurred()) { - PyErr_Clear(); - long_depth = 0; + if (PyTuple_Check(py_states)) { + auto tuple_size = PyTuple_GET_SIZE(py_states); + for (Py_ssize_t i = 0; i < tuple_size; ++i) { + if (readOnlyStateIsActive(PyTuple_GET_ITEM(py_states, i))) { + ++depth; + } + } + } + + Py_DECREF(py_states); + return depth; + } + + PyObject *makeContextStatesTuple(PyObject *new_state_capsule) + { + PyObject *current_states = nullptr; + if (PyContextVar_Get(s_read_only_depth_var, NULL, ¤t_states) < 0) { + PyErr_Clear(); + return nullptr; + } + + Py_ssize_t active_count = 0; + if (current_states && PyTuple_Check(current_states)) { + auto tuple_size = PyTuple_GET_SIZE(current_states); + for (Py_ssize_t i = 0; i < tuple_size; ++i) { + if (readOnlyStateIsActive(PyTuple_GET_ITEM(current_states, i))) { + ++active_count; + } + } + } + + auto *tuple = PyTuple_New(active_count + 1); + if (!tuple) { + Py_XDECREF(current_states); + return nullptr; + } + + Py_ssize_t tuple_index = 0; + if (current_states && PyTuple_Check(current_states)) { + auto tuple_size = PyTuple_GET_SIZE(current_states); + for (Py_ssize_t i = 0; i < tuple_size; ++i) { + auto *state_capsule = PyTuple_GET_ITEM(current_states, i); + if (!readOnlyStateIsActive(state_capsule)) { + continue; + } + Py_INCREF(state_capsule); + PyTuple_SET_ITEM(tuple, tuple_index++, state_capsule); } - depth = static_cast(long_depth); } + Py_INCREF(new_state_capsule); + PyTuple_SET_ITEM(tuple, tuple_index, new_state_capsule); + Py_XDECREF(current_states); + return tuple; + } + + unsigned int readOnlyDepthFromPythonContext() + { + if (!s_read_only_depth_var) { + return 0; + } + + auto *thread_state = PyThreadState_Get(); + auto active_generation = s_active_read_only_generation.load(std::memory_order_acquire); + if (s_depth_cache.valid + && s_depth_cache.thread_state_id == thread_state->id + && s_depth_cache.context_version == thread_state->context_ver + && s_depth_cache.generation == s_read_only_generation + && s_depth_cache.active_generation == active_generation) { + return s_depth_cache.depth; + } + + auto depth = activeReadOnlyDepthFromPythonContext(); + s_depth_cache = { .thread_state_id = thread_state->id, .context_version = thread_state->context_ver, .generation = s_read_only_generation, + .active_generation = active_generation, .depth = depth, .valid = true, }; @@ -73,9 +190,9 @@ namespace db0::python s_depth_cache.valid = false; } - PyObject *makeDepthObject(unsigned int depth) + PyObject *makeEmptyContextIdsTuple() { - return PyLong_FromUnsignedLong(depth); + return PyTuple_New(0); } } @@ -85,17 +202,26 @@ namespace db0::python THROWF(db0::InternalException) << "read_only context support is not initialized"; } - auto current_depth = readOnlyDepthFromPythonContext(); - auto next_depth = Py_OWN(makeDepthObject(current_depth + 1)); - if (!next_depth) { + m_state_capsule = makeReadOnlyStateCapsule(); + if (!m_state_capsule) { + THROWF(db0::InputException) << "unable to enter read_only context"; + } + + auto next_states = Py_OWN(makeContextStatesTuple(m_state_capsule)); + if (!next_states) { + Py_DECREF(m_state_capsule); + m_state_capsule = nullptr; THROWF(db0::InputException) << "unable to enter read_only context"; } - m_token = PyContextVar_Set(s_read_only_depth_var, next_depth.get()); + m_token = PyContextVar_Set(s_read_only_depth_var, next_states.get()); if (!m_token) { + Py_DECREF(m_state_capsule); + m_state_capsule = nullptr; THROWF(db0::InputException) << "unable to enter read_only context"; } db0::ReadOnlyContext::enterExternal(); + incrementActiveReadOnlyGeneration(); invalidateReadOnlyDepthCache(); } @@ -121,7 +247,14 @@ namespace db0::python } Py_DECREF(m_token); m_token = nullptr; + auto *state = getReadOnlyState(m_state_capsule); + if (state) { + state->active.store(false, std::memory_order_release); + } + Py_DECREF(m_state_capsule); + m_state_capsule = nullptr; db0::ReadOnlyContext::exitExternal(); + incrementActiveReadOnlyGeneration(); invalidateReadOnlyDepthCache(); } @@ -180,6 +313,18 @@ namespace db0::python return runSafe(PyAPI_tryBeginReadOnly, self); } + PyObject *PyAPI_inReadOnly(PyObject *, PyObject *const *, Py_ssize_t nargs) + { + if (nargs != 0) { + PyErr_SetString(PyExc_TypeError, "_in_read_only requires no arguments"); + return NULL; + } + if (db0::ReadOnlyContext::isActive()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; + } + bool PyReadOnly_Check(PyObject *object) { return Py_TYPE(object) == &PyReadOnlyType; } @@ -199,11 +344,11 @@ namespace db0::python int initReadOnlyContextSupport() { if (!s_read_only_depth_var) { - auto default_depth = Py_OWN(makeDepthObject(0)); - if (!default_depth) { + auto default_ids = Py_OWN(makeEmptyContextIdsTuple()); + if (!default_ids) { return -1; } - s_read_only_depth_var = PyContextVar_New("dbzero_read_only_depth", default_depth.get()); + s_read_only_depth_var = PyContextVar_New("dbzero_read_only_context_ids", default_ids.get()); if (!s_read_only_depth_var) { return -1; } diff --git a/src/dbzero/bindings/python/PyReadOnly.hpp b/src/dbzero/bindings/python/PyReadOnly.hpp index 7e0e27f8..8999e2e2 100644 --- a/src/dbzero/bindings/python/PyReadOnly.hpp +++ b/src/dbzero/bindings/python/PyReadOnly.hpp @@ -22,6 +22,7 @@ namespace db0::python private: bool m_active = true; PyObject *m_token = nullptr; + PyObject *m_state_capsule = nullptr; }; using PyReadOnly = PyWrapper; @@ -36,6 +37,7 @@ namespace db0::python PyObject *PyAPI_PyReadOnly_close(PyObject *, PyObject *); PyObject *PyAPI_beginReadOnly(PyObject *self, PyObject *const *, Py_ssize_t nargs); + PyObject *PyAPI_inReadOnly(PyObject *self, PyObject *const *, Py_ssize_t nargs); int initReadOnlyContextSupport(); } diff --git a/src/dbzero/bindings/python/dbzero.cpp b/src/dbzero/bindings/python/dbzero.cpp index d61a6b3e..eead8d91 100644 --- a/src/dbzero/bindings/python/dbzero.cpp +++ b/src/dbzero/bindings/python/dbzero.cpp @@ -75,6 +75,7 @@ static PyMethodDef dbzero_methods[] = {"_in_async_task", (PyCFunction)&py::PyAPI_inAsyncTask, METH_FASTCALL, "Returns whether the current execution is an asyncio task"}, {"begin_locked", (PyCFunction)&py::PyAPI_beginLocked, METH_FASTCALL, "Enter a new locked section"}, {"begin_read_only", (PyCFunction)&py::PyAPI_beginReadOnly, METH_FASTCALL, "Enter a new read-only section"}, + {"_in_read_only", (PyCFunction)&py::PyAPI_inReadOnly, METH_FASTCALL, "Returns whether the current execution is in a read-only section"}, {"describe", &py::describeObject, METH_VARARGS, "Get dbzero object's description"}, {"rename_field", (PyCFunction)&py::renameField, METH_VARARGS | METH_KEYWORDS, "Get snapshot of dbzero state"}, {"_init_data_masking", (PyCFunction)&py::initDataMasking, METH_VARARGS | METH_KEYWORDS, "Initialize data masking for specific prefixes"}, diff --git a/src/dbzero/object_model/tags/ObjectTagManager.cpp b/src/dbzero/object_model/tags/ObjectTagManager.cpp index 18c6545f..a5af0723 100644 --- a/src/dbzero/object_model/tags/ObjectTagManager.cpp +++ b/src/dbzero/object_model/tags/ObjectTagManager.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025 DBZero Software sp. z o.o. #include "ObjectTagManager.hpp" +#include "ObjectIterator.hpp" #include #include #include @@ -49,14 +50,21 @@ namespace db0::object_model } - ObjectTagManager::ObjectTagManager(ObjectPtr const *memo_ptr, std::size_t nargs) - : m_info(memo_ptr[0]) + ObjectTagManager::ObjectTagManager(ObjectPtr const *memo_ptr, std::size_t nargs, + std::vector > &&query_targets) + : m_empty(nargs == 0 && query_targets.empty()) , m_info_vec_ptr((nargs > 1) ? (new ObjectInfo[nargs - 1]) : nullptr) - , m_info_vec_size(nargs - 1) - , m_access_mode(m_info.m_access_mode) - , m_fixtures(m_info.getFixture()) + , m_info_vec_size(nargs > 0 ? nargs - 1 : 0) + , m_query_targets(std::move(query_targets)) { - assert(nargs > 0); + if (m_empty) { + return; + } + if (nargs > 0) { + m_info = ObjectInfo(memo_ptr[0]); + m_access_mode = m_info.m_access_mode; + m_fixtures.add(m_info.getFixture()); + } for (std::size_t i = 1; i < nargs; ++i) { m_info_vec_ptr[i - 1] = ObjectInfo(memo_ptr[i]); m_fixtures.add(m_info_vec_ptr[i - 1].getFixture()); @@ -64,6 +72,13 @@ namespace db0::object_model m_access_mode = AccessType::READ_ONLY; } } + for (const auto &query_target: m_query_targets) { + auto fixture = query_target->getFixture(); + m_fixtures.add(fixture); + if (fixture->getAccessType() != AccessType::READ_WRITE) { + m_access_mode = AccessType::READ_ONLY; + } + } } ObjectTagManager::ObjectTagManager() @@ -80,13 +95,14 @@ namespace db0::object_model } } - ObjectTagManager *ObjectTagManager::makeNew(void *at_ptr, ObjectPtr const *memo_ptr, std::size_t nargs) + ObjectTagManager *ObjectTagManager::makeNew(void *at_ptr, ObjectPtr const *memo_ptr, std::size_t nargs, + std::vector > &&query_targets) { - if (nargs == 0) { + if (nargs == 0 && query_targets.empty()) { // construct as empty return new (at_ptr) ObjectTagManager(); } - return new (at_ptr) ObjectTagManager(memo_ptr, nargs); + return new (at_ptr) ObjectTagManager(memo_ptr, nargs, std::move(query_targets)); } ObjectTagManager::ObjectInfo::ObjectInfo(ObjectPtr memo_ptr) @@ -227,10 +243,16 @@ namespace db0::object_model if (db0::ReadOnlyContext::isActive()) { THROWF(db0::InputException) << "dbzero read_only context forbids mutation"; } - m_info.add(args, nargs); + validateQueryTargets(); + if (!!m_info.m_lang_ptr) { + m_info.add(args, nargs); + } for (std::size_t i = 0; i < m_info_vec_size; ++i) { m_info_vec_ptr[i].add(args, nargs); } + forEachQueryTarget([&](ObjectInfo &object_info) { + object_info.add(args, nargs); + }); onUpdated(); } @@ -246,10 +268,16 @@ namespace db0::object_model if (db0::ReadOnlyContext::isActive()) { THROWF(db0::InputException) << "dbzero read_only context forbids mutation"; } - m_info.remove(args, nargs); + validateQueryTargets(); + if (!!m_info.m_lang_ptr) { + m_info.remove(args, nargs); + } for (std::size_t i = 0; i < m_info_vec_size; ++i) { m_info_vec_ptr[i].remove(args, nargs); - } + } + forEachQueryTarget([&](ObjectInfo &object_info) { + object_info.remove(args, nargs); + }); onUpdated(); } @@ -271,4 +299,34 @@ namespace db0::object_model } } + void ObjectTagManager::validateQueryTargets() const + { + for (const auto &query_target: m_query_targets) { + if (!query_target) { + THROWF(db0::InputException) << "ObjectTagManager: invalid query target"; + } + if (query_target->isPredicateOnly()) { + THROWF(db0::InputException) << "ObjectTagManager: predicate queries cannot be used as tag targets"; + } + if (query_target->getFixture()->getAccessType() != AccessType::READ_WRITE) { + THROWF(db0::InputException) << "ObjectTagManager: cannot update tags through read-only query target"; + } + } + } + + void ObjectTagManager::forEachQueryTarget(std::function callback) + { + for (const auto &query_target: m_query_targets) { + auto iter = query_target->iter(); + while (true) { + auto object = iter->next(); + if (!object) { + break; + } + ObjectInfo object_info(object.get()); + callback(object_info); + } + } + } + } diff --git a/src/dbzero/object_model/tags/ObjectTagManager.hpp b/src/dbzero/object_model/tags/ObjectTagManager.hpp index 7a92b321..348dafb5 100644 --- a/src/dbzero/object_model/tags/ObjectTagManager.hpp +++ b/src/dbzero/object_model/tags/ObjectTagManager.hpp @@ -3,7 +3,11 @@ #pragma once +#include +#include +#include #include +#include #include #include #include @@ -29,7 +33,8 @@ namespace db0::object_model // construct as empty ObjectTagManager(); - ObjectTagManager(ObjectPtr const *memo_ptr, std::size_t nargs); + ObjectTagManager(ObjectPtr const *memo_ptr, std::size_t nargs, + std::vector > &&query_targets = {}); ~ObjectTagManager(); /** @@ -39,7 +44,8 @@ namespace db0::object_model void remove(ObjectPtr const *args, Py_ssize_t nargs); - static ObjectTagManager *makeNew(void *at_ptr, ObjectPtr const *memo_ptr, std::size_t nargs); + static ObjectTagManager *makeNew(void *at_ptr, ObjectPtr const *memo_ptr, std::size_t nargs, + std::vector > &&query_targets = {}); private: // Memo object to be assigned tags to (language specific) @@ -70,14 +76,17 @@ namespace db0::object_model // first object's info ObjectInfo m_info; // optional additional objects' info - ObjectInfo *m_info_vec_ptr; + ObjectInfo *m_info_vec_ptr = nullptr; std::size_t m_info_vec_size = 0; - AccessType m_access_mode; + AccessType m_access_mode = AccessType::READ_WRITE; + std::vector > m_query_targets; // fixtures of the tagged objects (to mark as updated) db0::WeakFixtureVector m_fixtures; bool m_on_updated = false; void onUpdated(); + void validateQueryTargets() const; + void forEachQueryTarget(std::function); }; }