diff --git a/dbzero/dbzero/initialization.py b/dbzero/dbzero/initialization.py index b0fb443b..4ad908cf 100644 --- a/dbzero/dbzero/initialization.py +++ b/dbzero/dbzero/initialization.py @@ -2,8 +2,9 @@ # Copyright (c) 2025 DBZero Software sp. z o.o. """dbzero initialization functions""" +from collections.abc import Mapping from typing import Any -from .dbzero import _init, open as dbzero_open +from .dbzero import _init, _init_data_masking, open as dbzero_open def init(dbzero_root: str, **kwargs: Any) -> None: """Initialize the dbzero environment in a specified directory and apply global configurations. @@ -29,6 +30,7 @@ def init(dbzero_root: str, **kwargs: Any) -> None: * cache_size (int, default 2 GiB) for main object cache size in bytes * lang_cache_size (int, default 1024) for language model data cache size * lock_flags (dict) to configure locking behavior when opening the prefix in read-write mode + * data_masking (dict) keyword arguments forwarded to _init_data_masking (e.g. context_var, prefix, mode), applied after the prefix, if given, has been opened; a non-mapping value raises TypeError Lock flags (dict): * blocking (bool, default False) wait when trying to acquire the lock @@ -55,3 +57,9 @@ def init(dbzero_root: str, **kwargs: Any) -> None: if "prefix" in kwargs: open_mode = "rw" if kwargs.get("read_write", True) else "r" dbzero_open(kwargs["prefix"], open_mode=open_mode) + + if "data_masking" in kwargs: + data_masking = kwargs["data_masking"] + if not isinstance(data_masking, Mapping): + raise TypeError("data_masking must be a mapping") + _init_data_masking(**data_masking) diff --git a/python_tests/test_data_masking.py b/python_tests/test_data_masking.py index 67ed5db2..529cdd81 100644 --- a/python_tests/test_data_masking.py +++ b/python_tests/test_data_masking.py @@ -2,16 +2,24 @@ # Copyright (c) 2025 DBZero Software sp. z o.o. 
from contextvars import ContextVar +from dataclasses import dataclass import pytest import dbzero as db0 +from .conftest import DB0_DIR account_id = ContextVar("account_id") missing_value = object() +@db0.memo(protect_fields=True) +@dataclass +class InitDataMaskingProtectedClass: + value: str + + def test_init_data_masking_prefix_scoped_lifecycle(db0_fixture): current_prefix = db0.get_current_prefix() @@ -141,3 +149,42 @@ def test_init_data_masking_allows_different_bindings_for_different_prefixes(db0_ missing_value_placeholder=object(), mode="RELEASE", ) + + +def test_init_can_initialize_workspace_data_masking(db0_fixture): + db0.close() + init_account_id = ContextVar("init_workspace_data_masking_account_id") + + db0.init( + DB0_DIR, + data_masking={ + "context_var": init_account_id, + "mode": "DEBUG", + }, + ) + db0.open("init-workspace-data-masking") + init_account_id.set(-2) + + obj = InitDataMaskingProtectedClass("visible") + + assert obj.value == "visible" + + +def test_init_can_initialize_prefix_data_masking_after_opening_prefix(db0_fixture): + db0.close() + init_account_id = ContextVar("init_prefix_data_masking_account_id") + + db0.init( + DB0_DIR, + prefix="init-prefix-data-masking", + data_masking={ + "context_var": init_account_id, + "prefix": "init-prefix-data-masking", + "mode": "DEBUG", + }, + ) + init_account_id.set(-2) + + obj = InitDataMaskingProtectedClass("visible") + + assert obj.value == "visible" diff --git a/python_tests/test_load.py b/python_tests/test_load.py index 8dadc394..7c3d3359 100644 --- a/python_tests/test_load.py +++ b/python_tests/test_load.py @@ -3,10 +3,37 @@ import pytest import dbzero as db0 +from contextvars import ContextVar +from dataclasses import dataclass from .memo_test_types import (MemoTestClass, MemoTestThreeParamsClass, MemoTestCustomLoadClass, MemoTestCustomLoadClassWithParams, MemoTestSingleton) +@db0.enum(values=["CREATE", "READ", "UPDATE", "DELETE"]) +class LoadFieldAccess: + pass + + 
+@db0.memo(protect_fields=True) +@dataclass +class LoadProtectedClass: + name: str + value: int + note: str + + +@db0.memo(protect_fields=True) +@dataclass +class LoadProtectedBaseClass: + base_value: str + + +@db0.memo +@dataclass +class LoadProtectedDerivedClass(LoadProtectedBaseClass): + derived_value: str + + def test_load_py_string(): assert db0.load("abc") == "abc" @@ -48,6 +75,36 @@ def test_load_memo_types(db0_fixture): assert db0.load(memo) == {"value": "string"} +def test_load_protected_memo_only_loads_readable_fields(db0_fixture): + account_id = ContextVar("load_protected_account_id") + memo = LoadProtectedClass("alpha", 7, "private") + db0.set_field_access(LoadProtectedClass, 123, (LoadFieldAccess.READ,), "name", "value") + db0._init_data_masking(account_id) + account_id.set(123) + + assert db0.load(memo) == {"name": "alpha", "value": 7} + + +def test_load_protected_memo_uses_inherited_read_masks(db0_fixture): + account_id = ContextVar("load_protected_inherited_account_id") + memo = LoadProtectedDerivedClass("base", "derived") + db0.set_field_access(LoadProtectedBaseClass, 123, (LoadFieldAccess.READ,), "base_value") + db0._init_data_masking(account_id) + account_id.set(123) + + assert db0.load(memo) == {"base_value": "base"} + + +def test_load_protected_memo_applies_exclude_after_read_mask(db0_fixture): + account_id = ContextVar("load_protected_exclude_account_id") + memo = LoadProtectedClass("alpha", 7, "private") + db0.set_field_access(LoadProtectedClass, 123, (LoadFieldAccess.READ,), "name", "value") + db0._init_data_masking(account_id) + account_id.set(123) + + assert db0.load(memo, exclude=["value"]) == {"name": "alpha"} + + def test_load_db0_list(db0_fixture): Colors = db0.enum("Colors", ["RED", "GREEN", "BLUE"]) cut = ["1", 2 , Colors.GREEN] @@ -312,11 +369,10 @@ def test_load_all_shallow(db0_fixture): loaded = db0.load_all(obj) # NOTE: load_all ignores custom __load__ method at the topmost level assert loaded == {"first": "one", "second": "two", 
"third": "three"} - - + + def test_load_all_deep(db0_fixture): obj = ObjectWithCustomLoad("one", "two", ObjectWithCustomLoad("three", "four", "five")) loaded = db0.load_all(obj) # NOTE: load_all ignores custom __load__ method at the topmost level only assert loaded == {"first": "one", "second": "two", "third": {"first": "three"}} - \ No newline at end of file diff --git a/src/dbzero/bindings/python/Memo.cpp b/src/dbzero/bindings/python/Memo.cpp index 93501c62..4a920a4c 100644 --- a/src/dbzero/bindings/python/Memo.cpp +++ b/src/dbzero/bindings/python/Memo.cpp @@ -1389,20 +1389,48 @@ namespace db0::python (const std::string &key, PyTypes::ObjectSharedPtr) { auto key_obj = Py_OWN(PyUnicode_FromString(key.c_str())); - auto attr = Py_OWN(PyAPI_MemoObject_getattro(memo_obj, *key_obj)); - if (!attr) { + if (!key_obj) { has_error = true; return false; } - - if (py_exclude == nullptr || py_exclude == Py_None || PySequence_Contains(py_exclude, *key_obj) == 0) { - auto res = Py_OWN(tryLoad(*attr, kwargs, nullptr, load_stack_ptr, false)); - if (!res) { + + if (py_exclude != nullptr && py_exclude != Py_None) { + int contains = PySequence_Contains(py_exclude, *key_obj); + if (contains < 0) { has_error = true; - } else { - PySafeDict_SetItemString(*py_result, key.c_str(), res); + return false; + } + if (contains == 1) { + return true; + } + } + + auto &memo_type = memo_obj->ext().getType(); + if (memo_type.isProtectFields()) { + auto member_loc = memo_obj->ext().findField(key.c_str()); + if (!checkProtectedFieldAccess( + memo_obj, db0::object_model::FieldMaskOptions::READ, member_loc, key.c_str() + )) { + if (PyErr_Occurred()) { + has_error = true; + return false; + } + return true; } } + + auto attr = Py_OWN(PyAPI_MemoObject_getattro(memo_obj, *key_obj)); + if (!attr) { + has_error = true; + return false; + } + + auto res = Py_OWN(tryLoad(*attr, kwargs, nullptr, load_stack_ptr, false)); + if (!res) { + has_error = true; + } else { + PySafeDict_SetItemString(*py_result, 
key.c_str(), res); + } return !has_error; }); if (has_error) {