Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbzero/dbzero/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2025 DBZero Software sp. z o.o.

from .dbzero import *
from .dbzero import _check_interned, _init_data_filter, _init_data_masking
from .dbzero import _check_interned, _in_read_only, _init_data_filter, _init_data_masking
from .memo import *
from .enum import *
from .fast_query import *
Expand Down
19 changes: 17 additions & 2 deletions dbzero/dbzero/atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .interfaces import Memo
from .dbzero import begin_atomic, begin_async_atomic, assign
from .dbzero import _in_read_only, begin_atomic, begin_async_atomic, assign


_async_atomic_locks: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock] = weakref.WeakKeyDictionary()
Expand All @@ -30,6 +30,17 @@ class _AsyncAtomicState:
)


class _NoOpAtomicContext:
def close(self):
pass

def cancel(self):
pass


_NO_OP_ATOMIC_CONTEXT = _NoOpAtomicContext()


def _current_async_task() -> Optional[asyncio.Task]:
try:
return asyncio.current_task()
Expand Down Expand Up @@ -68,7 +79,7 @@ def __exit__(self, exc_type, exc_value, traceback):
def begin(self):
"""Begin the atomic context"""
if self.__ctx is None:
self.__ctx = begin_atomic()
self.__ctx = _NO_OP_ATOMIC_CONTEXT if _in_read_only() else begin_atomic()

def close(self):
"""Close the atomic context, staging the changes for commit"""
Expand Down Expand Up @@ -155,6 +166,10 @@ def __init__(self):
self.__release_lock = False

async def __aenter__(self) -> AsyncAtomicManager:
if _in_read_only():
self.__ctx = _NO_OP_ATOMIC_CONTEXT
return self

task = _current_async_task()
if task is None:
raise RuntimeError("db0.async_atomic requires a running asyncio task")
Expand Down
19 changes: 16 additions & 3 deletions dbzero/dbzero/dbzero.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ def read_only() -> ContextManager[Any]:
"""Open a context manager that rejects dbzero mutations in its block."""
...

def _in_read_only() -> bool:
"""Return whether the current execution is inside a dbzero read-only block."""
...

def in_read_only() -> bool:
"""Return whether the current execution is inside a dbzero read-only block."""
...

def open(prefix_name: str, open_mode: str = "rw", **kwargs: Any) -> None:
"""Open a data prefix and set it as the current working context.

Expand Down Expand Up @@ -970,13 +978,13 @@ def bytearray(source: Union[bytes, Iterable[int]] = b'', /) -> ByteArrayObject:

# Tag and query functions

def tags(*objects: Memo) -> ObjectTagManager:
def tags(*objects: Union[Memo, QueryObject]) -> ObjectTagManager:
"""Get a tag manager interface for given Memo objects.

Parameters
----------
*objects : Memo
One or more Memo objects to manage tags for.
*objects : Memo or QueryObject
One or more Memo objects or query result sets to manage tags for.

Returns
-------
Expand All @@ -1001,6 +1009,11 @@ def tags(*objects: Memo) -> ObjectTagManager:
>>> dbzero.tags(product1, product2).add("sale")
>>> dbzero.tags(product1, product2).remove("sale")

Batch operations on query results:

>>> dbzero.tags(dbzero.find("token-a", "token-b")).add("token-c")
>>> dbzero.tags(dbzero.find("token-a")).remove("token-a")

Chain operations:

>>> dbzero.tags(obj).add("tag1").remove("old-tag").add("tag2")
Expand Down
4 changes: 2 additions & 2 deletions dbzero/dbzero/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class ByteArrayObject(bytearray):
...

class ObjectTagManager:
"""Manages tags of one or more Memo instances."""
"""Manages tags of one or more Memo instances or query result sets."""

def add(self, *tag: Union[Tag, Iterable[Tag]]) -> None:
"""Add one or more tags to the managed objects.
Expand Down Expand Up @@ -286,4 +286,4 @@ def get_state_num(self) -> int:
int
State number of a snapshot.
"""
...
...
7 changes: 6 additions & 1 deletion dbzero/dbzero/read_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from .dbzero import begin_read_only
from .dbzero import _in_read_only, begin_read_only


class ReadOnlyManager:
Expand All @@ -25,3 +25,8 @@ def __exit__(self, _exc_type, _exc_value, _traceback):
def read_only() -> ReadOnlyManager:
"""Open a context manager that rejects dbzero mutations in its block."""
return ReadOnlyManager()


def in_read_only() -> bool:
"""Return whether the current execution is inside a dbzero read-only block."""
return _in_read_only()
63 changes: 63 additions & 0 deletions python_tests/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,69 @@ def test_remove_tags_then_find_typed(db0_fixture):
db0.tags(objects[4], objects[2]).remove("one")
assert len(list(db0.find(MemoTestClass, "one"))) == 0


def test_add_tags_to_query_result(db0_fixture):
objects = [MemoTestClass(i) for i in range(6)]
db0.tags(objects[0]).add(["token-a", "token-b"])
db0.tags(objects[1]).add("token-a")
db0.tags(objects[2]).add(["token-a", "token-b"])
db0.tags(objects[3]).add("token-b")

db0.tags(db0.find("token-a", "token-b")).add("token-c")

assert {item.value for item in db0.find("token-c")} == {0, 2}


def test_remove_tags_from_query_result(db0_fixture):
objects = [MemoTestClass(i) for i in range(5)]
db0.tags(objects[0]).add(["token-a", "keep"])
db0.tags(objects[1]).add("token-a")
db0.tags(objects[2]).add(["token-a", "other"])
db0.tags(objects[3]).add("other")

db0.tags(db0.find("token-a")).remove("token-a")

assert list(db0.find("token-a")) == []
assert {item.value for item in db0.find("keep")} == {0}
assert {item.value for item in db0.find("other")} == {2, 3}


def test_query_tag_target_can_be_empty_or_mixed_with_memo_target(db0_fixture):
objects = [MemoTestClass(i) for i in range(4)]
db0.tags(objects[0]).add("source")
db0.tags(objects[1]).add("source")

db0.tags(db0.find("missing")).add("batch")
assert list(db0.find("batch")) == []

db0.tags(objects[2], db0.find("source")).add(["batch", "extra"])

assert {item.value for item in db0.find("batch")} == {0, 1, 2}
assert {item.value for item in db0.find("extra")} == {0, 1, 2}


def test_batched_query_tags_rejected_in_read_only_context(db0_fixture):
obj = MemoTestClass(1)
db0.tags(obj).add("source")

with db0.read_only():
with pytest.raises(RuntimeError, match="read_only|read-only"):
db0.tags(db0.find("source")).add("blocked")

assert list(db0.find("blocked")) == []


def test_batched_query_tags_reject_snapshot_query_target(db0_fixture):
obj = MemoTestClass(1)
db0.tags(obj).add("source")
db0.commit()

with db0.snapshot() as snap:
with pytest.raises(RuntimeError, match="read-only"):
db0.tags(snap.find("source")).remove("source")

assert list(db0.find("source")) == [obj]


def test_query_by_non_existing_tag(db0_fixture):
assert len(list(db0.find("tag1"))) == 0
Expand Down
1 change: 1 addition & 0 deletions python_tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def guarded_import(name, *args, **kwargs):
monkeypatch.setattr(builtins, "__import__", guarded_import)

db0.init(DB0_DIR)
assert db0.in_read_only() is False


def test_init_propagates_rpc_init_error_and_allows_recovery(db0_fixture, monkeypatch):
Expand Down
41 changes: 40 additions & 1 deletion python_tests/test_read_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2025 DBZero Software sp. z o.o.

import asyncio
import importlib
import threading

import pytest
Expand All @@ -23,6 +24,18 @@ def test_read_only_allows_reads(db0_fixture):
assert db0.fetch(db0.uuid(obj)) == obj


def test_in_read_only_public_predicate(db0_fixture):
assert not db0.in_read_only()

with db0.read_only():
assert db0.in_read_only()
with db0.read_only():
assert db0.in_read_only()
assert db0.in_read_only()

assert not db0.in_read_only()


def test_read_only_rejects_memo_field_assignment(db0_fixture):
obj = MemoTestClass(123)

Expand Down Expand Up @@ -125,14 +138,22 @@ def test_read_only_inside_atomic_rejects_mutation(db0_fixture):
assert obj.value == 123


def test_atomic_inside_read_only_starts_normally_but_mutation_is_rejected(db0_fixture):
def test_atomic_inside_read_only_is_optimized_out(db0_fixture, monkeypatch):
obj = MemoTestClass(123)
atomic_module = importlib.import_module("dbzero.atomic")

with db0.read_only():
assert db0._in_read_only()
monkeypatch.setattr(
atomic_module,
"begin_atomic",
lambda: pytest.fail("begin_atomic should not run inside read_only"),
)
with db0.atomic():
with pytest.raises(RuntimeError, match="read_only.*mutation|mutation.*read_only"):
obj.value = 456

assert not db0._in_read_only()
obj.value = 789
assert obj.value == 789

Expand Down Expand Up @@ -292,6 +313,24 @@ async def mutate_in_child_task():
assert obj.value == 789


async def test_read_only_child_task_stale_context_is_not_reactivated_by_unrelated_context(db0_fixture):
obj = MemoTestClass(123)
child_can_run = asyncio.Event()

async def mutate_in_child_task():
await child_can_run.wait()
obj.value = 789

with db0.read_only():
child_task = asyncio.create_task(mutate_in_child_task())

with db0.read_only():
child_can_run.set()
await asyncio.wait_for(child_task, timeout=5)

assert obj.value == 789


def test_read_only_fast_overhead_paths(db0_fixture):
obj = MemoTestClass(0)
iterations = 100
Expand Down
24 changes: 21 additions & 3 deletions src/dbzero/bindings/python/PyObjectTagManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "Memo.hpp"
#include "PyInternalAPI.hpp"
#include "PyToolkit.hpp"
#include "iter/PyObjectIterable.hpp"
#include <memory>
#include <vector>

namespace db0::python

Expand Down Expand Up @@ -102,10 +105,19 @@ namespace db0::python

PyObjectTagManager *tryMakeObjectTagManager(PyObject *, PyObject *const *args, Py_ssize_t nargs)
{
// all arguments must be Memo objects
std::vector<PyObject*> memo_args;
std::vector<std::shared_ptr<ObjectIterable> > query_targets;
memo_args.reserve(nargs);
query_targets.reserve(nargs);

for (Py_ssize_t i = 0; i < nargs; ++i) {
if (PyObjectIterable_Check(args[i])) {
auto *query = reinterpret_cast<PyObjectIterable*>(args[i]);
query_targets.push_back(query->getSharedPtr());
continue;
}
if (!PyToolkit::isAnyMemoObject(args[i])) {
THROWF(db0::InputException) << "All arguments must be dbzero memo objects";
THROWF(db0::InputException) << "All arguments must be dbzero memo objects or object queries";
}
if (PyMemo_Check<MemoObject>(args[i])) {
auto *memoObject = reinterpret_cast<MemoObject *>(args[i]);
Expand All @@ -118,10 +130,16 @@ namespace db0::python
auto materialized = Py_OWN(getMaterializedMemoObject(memoObject));
}
}
memo_args.push_back(args[i]);
}

auto tags_obj = Py_OWN(PyObjectTagManager_new(&PyObjectTagManagerType, NULL, NULL));
ObjectTagManager::makeNew(&tags_obj->modifyExt(), args, nargs);
ObjectTagManager::makeNew(
&tags_obj->modifyExt(),
memo_args.data(),
memo_args.size(),
std::move(query_targets)
);
return tags_obj.steal();
}

Expand Down
Loading
Loading