Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion dbzero/dbzero/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# Copyright (c) 2025 DBZero Software sp. z o.o.

"""dbzero initialization functions"""
from collections.abc import Mapping
from typing import Any
from .dbzero import _init, open as dbzero_open
from .dbzero import _init, _init_data_masking, open as dbzero_open

def init(dbzero_root: str, **kwargs: Any) -> None:
"""Initialize the dbzero environment in a specified directory and apply global configurations.
Expand All @@ -29,6 +30,7 @@ def init(dbzero_root: str, **kwargs: Any) -> None:
* cache_size (int, default 2 GiB) for main object cache size in bytes
* lang_cache_size (int, default 1024) for language model data cache size
* lock_flags (dict) to configure locking behavior when opening the prefix in read-write mode
* data_masking (dict) to initialize data masking via _init_data_masking

Lock flags (dict):
* blocking (bool, default False) wait when trying to acquire the lock
Expand All @@ -55,3 +57,9 @@ def init(dbzero_root: str, **kwargs: Any) -> None:
if "prefix" in kwargs:
open_mode = "rw" if kwargs.get("read_write", True) else "r"
dbzero_open(kwargs["prefix"], open_mode=open_mode)

if "data_masking" in kwargs:
data_masking = kwargs["data_masking"]
if not isinstance(data_masking, Mapping):
raise TypeError("data_masking must be a mapping")
_init_data_masking(**data_masking)
47 changes: 47 additions & 0 deletions python_tests/test_data_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@
# Copyright (c) 2025 DBZero Software sp. z o.o.

from contextvars import ContextVar
from dataclasses import dataclass

import pytest

import dbzero as db0
from .conftest import DB0_DIR


account_id = ContextVar("account_id")
missing_value = object()


@db0.memo(protect_fields=True)
@dataclass
class InitDataMaskingProtectedClass:
value: str


def test_init_data_masking_prefix_scoped_lifecycle(db0_fixture):
current_prefix = db0.get_current_prefix()

Expand Down Expand Up @@ -141,3 +149,42 @@ def test_init_data_masking_allows_different_bindings_for_different_prefixes(db0_
missing_value_placeholder=object(),
mode="RELEASE",
)


def test_init_can_initialize_workspace_data_masking(db0_fixture):
db0.close()
init_account_id = ContextVar("init_workspace_data_masking_account_id")

db0.init(
DB0_DIR,
data_masking={
"context_var": init_account_id,
"mode": "DEBUG",
},
)
db0.open("init-workspace-data-masking")
init_account_id.set(-2)

obj = InitDataMaskingProtectedClass("visible")

assert obj.value == "visible"


def test_init_can_initialize_prefix_data_masking_after_opening_prefix(db0_fixture):
db0.close()
init_account_id = ContextVar("init_prefix_data_masking_account_id")

db0.init(
DB0_DIR,
prefix="init-prefix-data-masking",
data_masking={
"context_var": init_account_id,
"prefix": "init-prefix-data-masking",
"mode": "DEBUG",
},
)
init_account_id.set(-2)

obj = InitDataMaskingProtectedClass("visible")

assert obj.value == "visible"
62 changes: 59 additions & 3 deletions python_tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,37 @@

import pytest
import dbzero as db0
from contextvars import ContextVar
from dataclasses import dataclass
from .memo_test_types import (MemoTestClass, MemoTestThreeParamsClass, MemoTestCustomLoadClass,
MemoTestCustomLoadClassWithParams, MemoTestSingleton)


@db0.enum(values=["CREATE", "READ", "UPDATE", "DELETE"])
class LoadFieldAccess:
pass


@db0.memo(protect_fields=True)
@dataclass
class LoadProtectedClass:
name: str
value: int
note: str


@db0.memo(protect_fields=True)
@dataclass
class LoadProtectedBaseClass:
base_value: str


@db0.memo
@dataclass
class LoadProtectedDerivedClass(LoadProtectedBaseClass):
derived_value: str


def test_load_py_string():
assert db0.load("abc") == "abc"

Expand Down Expand Up @@ -48,6 +75,36 @@ def test_load_memo_types(db0_fixture):
assert db0.load(memo) == {"value": "string"}


def test_load_protected_memo_only_loads_readable_fields(db0_fixture):
account_id = ContextVar("load_protected_account_id")
memo = LoadProtectedClass("alpha", 7, "private")
db0.set_field_access(LoadProtectedClass, 123, (LoadFieldAccess.READ,), "name", "value")
db0._init_data_masking(account_id)
account_id.set(123)

assert db0.load(memo) == {"name": "alpha", "value": 7}


def test_load_protected_memo_uses_inherited_read_masks(db0_fixture):
account_id = ContextVar("load_protected_inherited_account_id")
memo = LoadProtectedDerivedClass("base", "derived")
db0.set_field_access(LoadProtectedBaseClass, 123, (LoadFieldAccess.READ,), "base_value")
db0._init_data_masking(account_id)
account_id.set(123)

assert db0.load(memo) == {"base_value": "base"}


def test_load_protected_memo_applies_exclude_after_read_mask(db0_fixture):
account_id = ContextVar("load_protected_exclude_account_id")
memo = LoadProtectedClass("alpha", 7, "private")
db0.set_field_access(LoadProtectedClass, 123, (LoadFieldAccess.READ,), "name", "value")
db0._init_data_masking(account_id)
account_id.set(123)

assert db0.load(memo, exclude=["value"]) == {"name": "alpha"}


def test_load_db0_list(db0_fixture):
Colors = db0.enum("Colors", ["RED", "GREEN", "BLUE"])
cut = ["1", 2 , Colors.GREEN]
Expand Down Expand Up @@ -312,11 +369,10 @@ def test_load_all_shallow(db0_fixture):
loaded = db0.load_all(obj)
# NOTE: load_all ignores custom __load__ method at the topmost level
assert loaded == {"first": "one", "second": "two", "third": "three"}


def test_load_all_deep(db0_fixture):
obj = ObjectWithCustomLoad("one", "two", ObjectWithCustomLoad("three", "four", "five"))
loaded = db0.load_all(obj)
# NOTE: load_all ignores custom __load__ method at the topmost level only
assert loaded == {"first": "one", "second": "two", "third": {"first": "three"}}

44 changes: 36 additions & 8 deletions src/dbzero/bindings/python/Memo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1389,20 +1389,48 @@ namespace db0::python
(const std::string &key, PyTypes::ObjectSharedPtr)
{
auto key_obj = Py_OWN(PyUnicode_FromString(key.c_str()));
auto attr = Py_OWN(PyAPI_MemoObject_getattro(memo_obj, *key_obj));
if (!attr) {
if (!key_obj) {
has_error = true;
return false;
}
if (py_exclude == nullptr || py_exclude == Py_None || PySequence_Contains(py_exclude, *key_obj) == 0) {
auto res = Py_OWN(tryLoad(*attr, kwargs, nullptr, load_stack_ptr, false));
if (!res) {

if (py_exclude != nullptr && py_exclude != Py_None) {
int contains = PySequence_Contains(py_exclude, *key_obj);
if (contains < 0) {
has_error = true;
} else {
PySafeDict_SetItemString(*py_result, key.c_str(), res);
return false;
}
if (contains == 1) {
return true;
}
}

auto &memo_type = memo_obj->ext().getType();
if (memo_type.isProtectFields()) {
auto member_loc = memo_obj->ext().findField(key.c_str());
if (!checkProtectedFieldAccess(
memo_obj, db0::object_model::FieldMaskOptions::READ, member_loc, key.c_str()
)) {
if (PyErr_Occurred()) {
has_error = true;
return false;
}
return true;
}
}

auto attr = Py_OWN(PyAPI_MemoObject_getattro(memo_obj, *key_obj));
if (!attr) {
has_error = true;
return false;
}

auto res = Py_OWN(tryLoad(*attr, kwargs, nullptr, load_stack_ptr, false));
if (!res) {
has_error = true;
} else {
PySafeDict_SetItemString(*py_result, key.c_str(), res);
}
return !has_error;
});
if (has_error) {
Expand Down
Loading