From ed636cecb8ed84dcbf8c99cb325fa743065207fc Mon Sep 17 00:00:00 2001 From: Wojtek Date: Thu, 14 May 2026 15:19:55 +0200 Subject: [PATCH 1/2] protecting masked fields + reset_protect_fields --- dbzero/dbzero/dbzero.pyi | 8 + python_tests/test_memo_protect_fields.py | 224 +++++++++++++++++- src/dbzero/bindings/python/DataMasking.hpp | 48 ++++ src/dbzero/bindings/python/Memo.cpp | 78 +++++- src/dbzero/bindings/python/PyAPI.cpp | 80 +++---- src/dbzero/bindings/python/PyAPI.hpp | 2 + src/dbzero/bindings/python/dbzero.cpp | 1 + src/dbzero/core/memory/config.cpp | 3 + src/dbzero/core/memory/config.hpp | 4 + src/dbzero/object_model/class/Class.cpp | 33 ++- src/dbzero/object_model/class/Class.hpp | 6 +- src/dbzero/object_model/class/MemberID.hpp | 5 +- .../object_model/object/ObjectImplBase.cpp | 47 ++-- .../object_model/object/ObjectImplBase.hpp | 9 +- src/dbzero/workspace/Workspace.cpp | 17 ++ src/dbzero/workspace/Workspace.hpp | 3 +- tests/unit_tests/WorkspaceTest.cpp | 29 +++ 17 files changed, 528 insertions(+), 69 deletions(-) create mode 100644 src/dbzero/bindings/python/DataMasking.hpp diff --git a/dbzero/dbzero/dbzero.pyi b/dbzero/dbzero/dbzero.pyi index 999324f7..dc2df94c 100644 --- a/dbzero/dbzero/dbzero.pyi +++ b/dbzero/dbzero/dbzero.pyi @@ -619,6 +619,14 @@ def get_field_access(class_obj: type, account_id: int) -> Iterable[Tuple[str, Tu """Return protected-field access flags for a memo class and account.""" ... +def reset_protect_fields(class_obj: type) -> None: + """Clear the persisted protect_fields flag for a memo class. + + The memo type must no longer be decorated with ``protect_fields=True``. + Remove the argument or set it to ``False`` before calling this function. + """ + ... + def _init_data_masking( context_var: Any, prefix: Union[str, Any, Sequence[Any], None] = None, diff --git a/python_tests/test_memo_protect_fields.py b/python_tests/test_memo_protect_fields.py index abdf5f30..2e49c4cb 100644 --- a/python_tests/test_memo_protect_fields.py +++ b/python_tests/test_memo_protect_fields.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: LGPL-2.1-or-later # Copyright (c) 2025 DBZero Software sp. z o.o. +import asyncio +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from contextvars import ContextVar import dbzero as db0 +import pytest from .conftest import DB0_DIR @@ -20,6 +24,17 @@ class MemoProtectedFieldsClass: value: int +@db0.memo(protect_fields=True) +@dataclass +class MemoProtectedComplexFieldsClass: + none_value: object + flag: bool + count: int + name: str + payload: bytes + pair: tuple + + @db0.memo @dataclass class MemoUnprotectedFieldsClass: @@ -145,12 +160,38 @@ class ProtectedAfter: assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True -def test_reset_protect_fields_clears_persisted_flag(db0_fixture): - obj = MemoProtectedFieldsClass("alpha", 1) +def test_reset_protect_fields_rejects_type_still_decorated_as_protected(db0_fixture): + MemoProtectedFieldsClass("alpha", 1) + + with pytest.raises(RuntimeError, match="decorated.*protect_fields"): + db0.reset_protect_fields(MemoProtectedFieldsClass) + + +def test_reset_protect_fields_clears_persisted_flag_after_decorator_removed(db0_fixture): + @db0.memo(id="dbzero-software/dbzero/tests/protected-reset", protect_fields=True) + @dataclass + class ProtectedBefore: + name: str + + obj = ProtectedBefore("alpha") + obj_id = db0.uuid(obj) + assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True + db0.commit() + + db0.close() + db0.init(DB0_DIR) + db0.open("my-test-prefix") + + @db0.memo(id="dbzero-software/dbzero/tests/protected-reset", protect_fields=False) + @dataclass + class ProtectedAfter: + name: str + + obj = db0.fetch(ProtectedAfter, obj_id) memo_class = get_memo_class_object(obj) assert memo_class.get_type_flags()["protect_fields"] is True - memo_class.reset_protect_fields() + assert db0.reset_protect_fields(ProtectedAfter) is None assert memo_class.get_type_flags()["protect_fields"] is False @@ -296,3 +337,180 @@ def test_describe_field_offset_range_counts_predeclared_access_fields(db0_fixtur materialized_range = db0.describe(obj)["field_offset_range"] assert materialized_range >= predeclared_range assert materialized_range < 120 * 2 + 128 + + +def test_protected_field_getter_requires_initialized_data_masking(db0_fixture): + obj = MemoProtectedFieldsClass("alpha", 1) + + with pytest.raises(RuntimeError, match="data masking"): + _ = obj.name + + +def test_protected_field_getter_requires_read_access(db0_fixture): + account_id = ContextVar("protected_read_account_id") + obj = MemoProtectedFieldsClass("alpha", 1) + db0.set_field_access(MemoProtectedFieldsClass, 123, (FieldAccess.READ,), "name") + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.name == "alpha" + + with pytest.raises(PermissionError, match="read"): + _ = obj.value + + +def test_protected_field_getter_returns_missing_value_placeholder(db0_fixture): + account_id = ContextVar("protected_placeholder_account_id") + missing_value = object() + obj = MemoProtectedFieldsClass("alpha", 1) + db0.set_field_access(MemoProtectedFieldsClass, 123, (FieldAccess.READ,), "name") + db0._init_data_masking(account_id, missing_value_placeholder=missing_value) + account_id.set(123) + + assert obj.value is missing_value + + +def test_protected_field_getter_allows_debug_super_account_read(db0_fixture): + account_id = ContextVar("protected_super_account_id") + obj = MemoProtectedFieldsClass("alpha", 1) + db0._init_data_masking(account_id, mode="DEBUG") + account_id.set(-1) + + assert obj.name == "alpha" + assert obj.value == 1 + + +def test_protected_field_getter_rejects_invalid_negative_account_id(db0_fixture): + account_id = ContextVar("protected_invalid_account_id") + obj = MemoProtectedFieldsClass("alpha", 1) + db0._init_data_masking(account_id, mode="DEBUG") + account_id.set(-3) + + with pytest.raises(RuntimeError, match="invalid.*account"): + _ = obj.name + + +def test_protected_field_getter_applies_complex_masks_for_mixed_field_types(db0_fixture): + account_id = ContextVar("protected_complex_account_id") + obj = MemoProtectedComplexFieldsClass(None, False, 7, "alpha", b"payload", ("x", 11)) + + db0.set_field_access( + MemoProtectedComplexFieldsClass, + 101, + (FieldAccess.READ,), + "none_value", + "flag", + "name", + ) + db0.set_field_access( + MemoProtectedComplexFieldsClass, + 202, + (FieldAccess.READ,), + "count", + "payload", + ) + db0.set_field_access( + MemoProtectedComplexFieldsClass, + 303, + (FieldAccess.READ,), + "none_value", + "flag", + "count", + "name", + "payload", + "pair", + ) + db0._init_data_masking(account_id) + + account_id.set(101) + assert obj.none_value is None + assert obj.flag is False + assert obj.name == "alpha" + with pytest.raises(PermissionError, match="read"): + _ = obj.count + with pytest.raises(PermissionError, match="read"): + _ = obj.payload + + account_id.set(202) + assert obj.count == 7 + assert obj.payload == b"payload" + with pytest.raises(PermissionError, match="read"): + _ = obj.none_value + with pytest.raises(PermissionError, match="read"): + _ = obj.flag + + account_id.set(303) + assert obj.none_value is None + assert obj.flag is False + assert obj.count == 7 + assert obj.name == "alpha" + assert obj.payload == b"payload" + assert obj.pair == ("x", 11) + + account_id.set(404) + for field_name in ("none_value", "flag", "count", "name", "payload", "pair"): + with pytest.raises(PermissionError, match="read"): + getattr(obj, field_name) + + +@pytest.mark.asyncio +async def test_protected_field_getter_uses_contextvar_per_async_task(db0_fixture): + account_id = ContextVar("protected_async_account_id") + obj = MemoProtectedFieldsClass("alpha", 1) + db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.READ,), "name") + db0.set_field_access(MemoProtectedFieldsClass, 202, (FieldAccess.READ,), "value") + db0._init_data_masking(account_id) + + async def read_as(account, allowed_field, denied_field): + account_id.set(account) + await asyncio.sleep(0) + allowed_value = getattr(obj, allowed_field) + await asyncio.sleep(0) + with pytest.raises(PermissionError, match="read"): + getattr(obj, denied_field) + return allowed_value + + name_value, int_value = await asyncio.gather( + read_as(101, "name", "value"), + read_as(202, "value", "name"), + ) + + assert name_value == "alpha" + assert int_value == 1 + + +def test_protected_field_getter_uses_contextvar_per_thread(db0_fixture): + account_id = ContextVar("protected_thread_account_id") + obj = MemoProtectedComplexFieldsClass(None, True, 7, "alpha", b"payload", ("x", 11)) + db0.set_field_access( + MemoProtectedComplexFieldsClass, + 101, + (FieldAccess.READ,), + "none_value", + "flag", + ) + db0.set_field_access( + MemoProtectedComplexFieldsClass, + 202, + (FieldAccess.READ,), + "count", + "name", + ) + db0._init_data_masking(account_id) + + def read_as(account): + account_id.set(account) + if account == 101: + with pytest.raises(PermissionError, match="read"): + _ = obj.count + return obj.none_value, obj.flag + with pytest.raises(PermissionError, match="read"): + _ = obj.flag + return obj.count, obj.name + + with ThreadPoolExecutor(max_workers=2) as pool: + account_101 = pool.submit(read_as, 101) + account_202 = pool.submit(read_as, 202) + + assert account_101.result() == (None, True) + assert account_202.result() == (7, "alpha") diff --git a/src/dbzero/bindings/python/DataMasking.hpp b/src/dbzero/bindings/python/DataMasking.hpp new file mode 100644 index 00000000..486c82f8 --- /dev/null +++ b/src/dbzero/bindings/python/DataMasking.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: LGPL-2.1-or-later +// Copyright (c) 2025 DBZero Software sp. z o.o. + +#pragma once + +#include + +namespace db0 + +{ + + enum class DataMaskingMode + { + RELEASE, + DEBUG + }; + + struct DataMaskingState + { + PyObject *contextVar = nullptr; + PyObject *missingValuePlaceholder = nullptr; + bool hasMissingValuePlaceholder = false; + DataMaskingMode mode = DataMaskingMode::RELEASE; + + DataMaskingState(PyObject *contextVar, PyObject *missingValuePlaceholder, + bool hasMissingValuePlaceholder, DataMaskingMode mode) + : contextVar(contextVar) + , missingValuePlaceholder(missingValuePlaceholder) + , hasMissingValuePlaceholder(hasMissingValuePlaceholder) + , mode(mode) + { + Py_INCREF(contextVar); + if (missingValuePlaceholder) { + Py_INCREF(missingValuePlaceholder); + } + } + + bool matches(PyObject *otherContextVar, PyObject *otherMissingValuePlaceholder, + bool otherHasMissingValuePlaceholder, DataMaskingMode otherMode) const + { + return contextVar == otherContextVar + && missingValuePlaceholder == otherMissingValuePlaceholder + && hasMissingValuePlaceholder == otherHasMissingValuePlaceholder + && mode == otherMode; + } + }; + +} diff --git a/src/dbzero/bindings/python/Memo.cpp b/src/dbzero/bindings/python/Memo.cpp index d334a729..6369df01 100644 --- a/src/dbzero/bindings/python/Memo.cpp +++ b/src/dbzero/bindings/python/Memo.cpp @@ -11,12 +11,15 @@ #include "Types.hpp" #include "Migration.hpp" #include "PyHash.hpp" +#include "DataMasking.hpp" #include #include +#include #include #include #include #include +#include #include #include #include @@ -356,6 +359,66 @@ namespace db0::python return !(attr_name[0] == '_' && attr_name[1] == 'X' && attr_name[2] == '_' && attr_name[3] == '_'); } + template + PyObject *checkProtectedFieldReadAccess(MemoImplT *memo_obj, const db0::object_model::MemberLoc &member_loc) + { + auto &type = memo_obj->ext().getType(); + if (!type.isProtectFields()) { + return nullptr; + } + + if (!Settings::m_data_masking_enabled) { + PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); + return nullptr; + } + + auto masking_state = memo_obj->ext().getFixture()->getMaskingState(); + if (!masking_state) { + PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); + return nullptr; + } + + PyObject *pyAccountId = nullptr; + if (PyContextVar_Get(masking_state->contextVar, NULL, &pyAccountId) < 0) { + PyErr_SetString(PyExc_RuntimeError, "unable to read data masking account context"); + return nullptr; + } + if (!pyAccountId) { + PyErr_SetString(PyExc_RuntimeError, "data masking account context is not set"); + return nullptr; + } + + auto accountId = PyLong_AsLongLong(pyAccountId); + Py_DECREF(pyAccountId); + if (PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "data masking account context must be an int"); + return nullptr; + } + + bool canRead = false; + if (accountId < -2) { + PyErr_SetString(PyExc_RuntimeError, "invalid data masking account id"); + return nullptr; + } else if (accountId == -1 || accountId == -2) { + canRead = masking_state->mode == DataMaskingMode::DEBUG; + } else if (accountId >= 0) { + auto mask = type.tryGetFieldAccess(static_cast(accountId), member_loc); + canRead = mask && (*mask)[db0::object_model::FieldMaskOptions::READ]; + } + + if (canRead) { + return nullptr; + } + + if (masking_state->hasMissingValuePlaceholder) { + Py_INCREF(masking_state->missingValuePlaceholder); + return masking_state->missingValuePlaceholder; + } + + PyErr_SetString(PyExc_PermissionError, "data masking denies read access to protected field"); + return nullptr; + } + template PyObject *tryMemoObject_getattro(MemoImplT *memo_obj, PyObject *attr) { @@ -375,9 +438,14 @@ namespace db0::python ObjectSharedPtr member; if (isPersistentAttrName(attr_name)) { memo_obj->ext().getFixture()->refreshIfUpdated(); - member = memo_obj->ext().tryGet(PyUnicode_AsUTF8(attr), &is_auto_generated); + auto member_loc = memo_obj->ext().findField(attr_name); + member = memo_obj->ext().tryGet(member_loc, &is_auto_generated); if (member.get() && !is_auto_generated) { + auto masked = checkProtectedFieldReadAccess(memo_obj, member_loc); + if (masked || PyErr_Occurred()) { + return masked; + } return member.steal(); } } @@ -1108,8 +1176,14 @@ namespace db0::python template PyObject *tryGetAttrAs(MemoImplT *memo_obj, PyObject *attr, PyTypeObject *py_type) { + const char *attr_name = PyUnicode_AsUTF8(attr); + if (!attr_name) { + PyErr_SetString(PyExc_AttributeError, "Invalid attribute name"); + return nullptr; + } + memo_obj->ext().getFixture()->refreshIfUpdated(); - auto member = memo_obj->ext().tryGetAs(PyUnicode_AsUTF8(attr), py_type); + auto member = memo_obj->ext().tryGetAs(memo_obj->ext().findField(attr_name), py_type); if (member.get()) { return member.steal(); } diff --git a/src/dbzero/bindings/python/PyAPI.cpp b/src/dbzero/bindings/python/PyAPI.cpp index 3ff06ba2..7fb86ebc 100644 --- a/src/dbzero/bindings/python/PyAPI.cpp +++ b/src/dbzero/bindings/python/PyAPI.cpp @@ -16,6 +16,7 @@ #include "PyReflectionAPI.hpp" #include "PyHash.hpp" #include "PyWeakProxy.hpp" +#include "DataMasking.hpp" #include #include #include @@ -43,48 +44,6 @@ #include #include -namespace db0 - -{ - - enum class DataMaskingMode - { - RELEASE, - DEBUG - }; - - struct DataMaskingState - { - PyObject *contextVar = nullptr; - PyObject *missingValuePlaceholder = nullptr; - bool hasMissingValuePlaceholder = false; - DataMaskingMode mode = DataMaskingMode::RELEASE; - - DataMaskingState(PyObject *contextVar, PyObject *missingValuePlaceholder, - bool hasMissingValuePlaceholder, DataMaskingMode mode) - : contextVar(contextVar) - , missingValuePlaceholder(missingValuePlaceholder) - , hasMissingValuePlaceholder(hasMissingValuePlaceholder) - , mode(mode) - { - Py_INCREF(contextVar); - if (missingValuePlaceholder) { - Py_INCREF(missingValuePlaceholder); - } - } - - bool matches(PyObject *otherContextVar, PyObject *otherMissingValuePlaceholder, - bool otherHasMissingValuePlaceholder, DataMaskingMode otherMode) const - { - return contextVar == otherContextVar - && missingValuePlaceholder == otherMissingValuePlaceholder - && hasMissingValuePlaceholder == otherHasMissingValuePlaceholder - && mode == otherMode; - } - }; - -} - namespace db0::python { @@ -1045,6 +1004,43 @@ namespace db0::python return runSafe(tryGetFieldAccess, args); } + PyObject *tryResetProtectFields(PyObject *args) + { + PyObject *py_type = nullptr; + if (!PyArg_ParseTuple(args, "O:reset_protect_fields", &py_type)) { + return nullptr; + } + if (!PyType_Check(py_type)) { + THROWF(db0::InputException) << "First argument must be a type"; + } + if (!PyAnyMemoType_Check(reinterpret_cast(py_type))) { + THROWF(db0::InputException) << "First argument must be a dbzero memo type"; + } + + auto memo_type = reinterpret_cast(py_type); + auto &decor = MemoTypeDecoration::get(memo_type); + if (decor.getFlags()[MemoOptions::PROTECT_FIELDS]) { + THROWF(db0::InputException) + << "Type is still decorated with protect_fields=True; remove it or set protect_fields=False first"; + } + + using ClassFactory = db0::object_model::ClassFactory; + auto fixture_uuid = decor.getFixtureUUID(AccessType::READ_WRITE); + auto fixture = PyToolkit::getPyWorkspace().getWorkspace().getFixture(fixture_uuid, AccessType::READ_WRITE); + auto &class_factory = fixture->get(); + auto type = class_factory.getExistingType(memo_type); + + db0::FixtureLock lock(fixture); + type->resetProtectFields(); + Py_RETURN_NONE; + } + + PyObject *resetProtectFields(PyObject *, PyObject *args) + { + PY_API_FUNC + return runSafe(tryResetProtectFields, args); + } + PyObject *TryPyAPI_isSingleton(PyObject *py_object) { assert((PyMemo_Check(py_object))); diff --git a/src/dbzero/bindings/python/PyAPI.hpp b/src/dbzero/bindings/python/PyAPI.hpp index 13261c9a..b4ff4865 100644 --- a/src/dbzero/bindings/python/PyAPI.hpp +++ b/src/dbzero/bindings/python/PyAPI.hpp @@ -112,6 +112,8 @@ namespace db0::python PyObject *getFieldAccess(PyObject *self, PyObject *args); + PyObject *resetProtectFields(PyObject *self, PyObject *args); + PyObject *PyAPI_isSingleton(PyObject *self, PyObject *args); PyObject *PyAPI_getRefCount(PyObject *self, PyObject *args); diff --git a/src/dbzero/bindings/python/dbzero.cpp b/src/dbzero/bindings/python/dbzero.cpp index 4a0d2053..abbeb243 100644 --- a/src/dbzero/bindings/python/dbzero.cpp +++ b/src/dbzero/bindings/python/dbzero.cpp @@ -69,6 +69,7 @@ static PyMethodDef dbzero_methods[] = {"_init_data_masking", (PyCFunction)&py::initDataMasking, METH_VARARGS | METH_KEYWORDS, "Initialize data masking for specific prefixes"}, {"set_field_access", (PyCFunction)&py::setFieldAccess, METH_VARARGS, "Set protected field access masks for a memo class"}, {"get_field_access", (PyCFunction)&py::getFieldAccess, METH_VARARGS, "Get protected field access masks for a memo class and account"}, + {"reset_protect_fields", (PyCFunction)&py::resetProtectFields, METH_VARARGS, "Clear the persisted protected-fields flag for a memo class"}, {"is_singleton", &py::PyAPI_isSingleton, METH_VARARGS, "Check if a specific instance is a dbzero singleton"}, {"getrefcount", &py::PyAPI_getRefCount, METH_VARARGS, "Get dbzero ref counts"}, {"no", (PyCFunction)&py::negTagSet, METH_FASTCALL, "Tag negation function"}, diff --git a/src/dbzero/core/memory/config.cpp b/src/dbzero/core/memory/config.cpp index 3c6c804b..29d50c54 100644 --- a/src/dbzero/core/memory/config.cpp +++ b/src/dbzero/core/memory/config.cpp @@ -19,6 +19,8 @@ namespace db0 std::function Settings::m_decode_error = []() { THROWF(db0::IOException) << "Data decoding error: corrupt data detected"; }; + + bool Settings::m_data_masking_enabled = false; void Settings::reset() { @@ -29,6 +31,7 @@ namespace db0 __write_poison = 0; __dram_io_flush_poison = 0; #endif + m_data_masking_enabled = false; } } diff --git a/src/dbzero/core/memory/config.hpp b/src/dbzero/core/memory/config.hpp index 121481c9..1a570837 100644 --- a/src/dbzero/core/memory/config.hpp +++ b/src/dbzero/core/memory/config.hpp @@ -37,6 +37,10 @@ namespace db0 // Function to throw the data decoding error (i.e. corrupt data detected) static std::function m_decode_error; + // Shortcut flag: true when data masking is enabled for at least one open prefix. + // Callers can use this to skip checking data masking rules when no open fixture can use them. + static bool m_data_masking_enabled; + // reset all settings to default values static void reset(); }; diff --git a/src/dbzero/object_model/class/Class.cpp b/src/dbzero/object_model/class/Class.cpp index 12ddc9ec..18c87b09 100644 --- a/src/dbzero/object_model/class/Class.cpp +++ b/src/dbzero/object_model/class/Class.cpp @@ -209,7 +209,7 @@ namespace db0::object_model return false; } - std::pair Class::findField(const char *name) const + MemberLoc Class::findField(const char *name) const { // NOTE: refresh is a lightweght operation if there were no changes (no detach) m_member_cache.fastRefresh(); @@ -406,6 +406,37 @@ namespace db0::object_model } } + std::optional Class::tryGetFieldAccess(std::uint64_t account_id, const Member &member) const + { + if (!isProtectFields()) { + return {}; + } + + auto &field_safe = getFieldSafe(); + auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(member.m_field_id); + if (!maybe_offset) { + return {}; + } + + auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id); + if (!field_mask) { + return {}; + } + + return field_mask->getAssignedMask(*maybe_offset); + } + + std::optional Class::tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &member_loc) const + { + const auto &member_id = member_loc.first; + if (!member_id) { + return {}; + } + + auto member = tryGetMember(member_id.primary().first); + return member ? tryGetFieldAccess(account_id, *member) : std::nullopt; + } + std::vector > Class::getFieldAccess(std::uint64_t account_id) const { if (!isProtectFields()) { diff --git a/src/dbzero/object_model/class/Class.hpp b/src/dbzero/object_model/class/Class.hpp index 4f52f581..04a09729 100644 --- a/src/dbzero/object_model/class/Class.hpp +++ b/src/dbzero/object_model/class/Class.hpp @@ -146,7 +146,7 @@ DB0_PACKED_END MemberID addField(const char *name, unsigned int fidelity); // @return member ID / init var flag assigned on initialization flag (see Schema Extensions) - std::pair findField(const char *name) const; + MemberLoc findField(const char *name) const; // Get the total number of unique members declared in this class std::size_t size() const { @@ -183,6 +183,8 @@ DB0_PACKED_END const FieldSafe &getFieldSafe() const; void setFieldAccess(const std::vector &account_ids, FieldMaskFlags mask, const std::vector &field_names); + std::optional tryGetFieldAccess(std::uint64_t account_id, const Member &) const; + std::optional tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &) const; std::vector > getFieldAccess(std::uint64_t account_id) const; std::uint32_t getFieldOffsetRange() const; @@ -333,7 +335,7 @@ DB0_PACKED_END // Field by-name index (cache) // values: member ID / assigned on initialization flag - mutable std::unordered_map > m_index; + mutable std::unordered_map m_index; // For fidelity = 0 this maps "index" to the unique field ID mutable std::vector m_unique_keys; // fields initialized on class creation (from static code analysis) diff --git a/src/dbzero/object_model/class/MemberID.hpp b/src/dbzero/object_model/class/MemberID.hpp index f8011cdd..63146370 100644 --- a/src/dbzero/object_model/class/MemberID.hpp +++ b/src/dbzero/object_model/class/MemberID.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "FieldID.hpp" namespace db0::object_model @@ -88,6 +89,8 @@ namespace db0::object_model std::pair m_primary = { FieldID(), std::numeric_limits::max() }; std::pair m_secondary = { FieldID(), std::numeric_limits::max() }; }; + + using MemberLoc = std::pair; } @@ -97,4 +100,4 @@ namespace std ostream &operator<<(std::ostream &os, const db0::object_model::MemberID &member_id); -} \ No newline at end of file +} diff --git a/src/dbzero/object_model/object/ObjectImplBase.cpp b/src/dbzero/object_model/object/ObjectImplBase.cpp index 8563e004..310c6bac 100644 --- a/src/dbzero/object_model/object/ObjectImplBase.cpp +++ b/src/dbzero/object_model/object/ObjectImplBase.cpp @@ -421,7 +421,7 @@ namespace db0::object_model } template - std::pair ObjectImplBase::findField(const char *name) const + MemberLoc ObjectImplBase::findField(const char *name) const { if (this->isDropped()) { // defunct objects should not be accessed @@ -440,11 +440,10 @@ namespace db0::object_model } template - FieldID ObjectImplBase::tryGetMember(const char *field_name, std::pair &member, - bool &is_init_var, bool *is_auto_generated) const + FieldID ObjectImplBase::tryGetMember(MemberLoc member_loc, std::pair &member, + bool *is_auto_generated) const { - MemberID member_id; - std::tie(member_id, is_init_var) = this->findField(field_name); + auto [member_id, is_init_var] = member_loc; bool exists, deleted = false; if (is_auto_generated) { *is_auto_generated = false; @@ -481,9 +480,9 @@ namespace db0::object_model } template - std::optional ObjectImplBase::tryGetX(const char *field_name) const + std::optional ObjectImplBase::tryGetX(MemberLoc member_loc) const { - auto [member_id, is_init_var] = this->findField(field_name); + auto [member_id, is_init_var] = member_loc; bool exists, deleted = false; if (member_id) { assert(member_id.primary().first); @@ -513,13 +512,19 @@ namespace db0::object_model return std::nullopt; } + template + std::optional ObjectImplBase::tryGetX(const char *field_name) const + { + return tryGetX(this->findField(field_name)); + } + template typename ObjectImplBase::ObjectSharedPtr - ObjectImplBase::tryGet(const char *field_name, bool *is_auto_generated) const + ObjectImplBase::tryGet(MemberLoc member_loc, bool *is_auto_generated) const { std::pair member; - bool is_init_var = false; - auto field_id = tryGetMember(field_name, member, is_init_var, is_auto_generated); + auto is_init_var = member_loc.second; + auto field_id = tryGetMember(member_loc, member, is_auto_generated); // NOTE: init vars are always reported as None if not explicitly set nor explicitly deleted if (field_id || (is_init_var && member.first != StorageClass::DELETED)) { auto fixture = this->getFixture(); @@ -533,14 +538,21 @@ namespace db0::object_model return nullptr; } + + template + typename ObjectImplBase::ObjectSharedPtr + ObjectImplBase::tryGet(const char *field_name, bool *is_auto_generated) const + { + return tryGet(this->findField(field_name), is_auto_generated); + } template typename ObjectImplBase::ObjectSharedPtr ObjectImplBase::tryGetAs( - const char *field_name, TypeObjectPtr lang_type) const + MemberLoc member_loc, TypeObjectPtr lang_type) const { std::pair member; - bool is_init_var = false; - auto field_id = tryGetMember(field_name, member, is_init_var); + auto is_init_var = member_loc.second; + auto field_id = tryGetMember(member_loc, member); if (field_id || (is_init_var && member.first != StorageClass::DELETED)) { // prevent accessing a deleted member assert(member.first != StorageClass::DELETED && member.first != StorageClass::UNDEFINED); @@ -558,6 +570,13 @@ namespace db0::object_model return nullptr; } + + template + typename ObjectImplBase::ObjectSharedPtr ObjectImplBase::tryGetAs( + const char *field_name, TypeObjectPtr lang_type) const + { + return tryGetAs(this->findField(field_name), lang_type); + } template typename ObjectImplBase::ObjectSharedPtr ObjectImplBase::get(const char *field_name) const @@ -1077,4 +1096,4 @@ namespace db0::object_model template class ObjectImplBase; template class ObjectImplBase; -} \ No newline at end of file +} diff --git a/src/dbzero/object_model/object/ObjectImplBase.hpp b/src/dbzero/object_model/object/ObjectImplBase.hpp index 0c31855b..149f79a6 100644 --- a/src/dbzero/object_model/object/ObjectImplBase.hpp +++ b/src/dbzero/object_model/object/ObjectImplBase.hpp @@ -96,7 +96,9 @@ namespace db0::object_model void setPreInit(const char *field_name, ObjectPtr lang_value) const; void removePreInit(const char *field_name) const; + ObjectSharedPtr tryGet(MemberLoc, bool *is_auto_generated = nullptr) const; ObjectSharedPtr tryGet(const char *field_name, bool *is_auto_generated = nullptr) const; + ObjectSharedPtr tryGetAs(MemberLoc, TypeObjectPtr) const; ObjectSharedPtr tryGetAs(const char *field_name, TypeObjectPtr) const; ObjectSharedPtr get(const char *field_name) const; @@ -127,7 +129,7 @@ namespace db0::object_model void commit() const; // FieldID, is_init_var, fidelity - std::pair findField(const char *name) const; + MemberLoc findField(const char *name) const; // NOTE: hasRefs is NOT available in ObjectAnyBase bacause // of the use of num_type_tags property @@ -156,8 +158,8 @@ namespace db0::object_model std::pair &find_result) const; std::pair tryGetMemberAt(std::pair, std::pair &) const; - FieldID tryGetMember(const char *field_name, std::pair &, - bool &is_init_var, bool *is_auto_generated = nullptr) const; + FieldID tryGetMember(MemberLoc, std::pair &, + bool *is_auto_generated = nullptr) const; // Try resolving field ID of an existing (or deleted) member and also its storage location // @param pos the member's position in the containing collection @@ -194,6 +196,7 @@ namespace db0::object_model bool hasValidClassRef() const; // try retrieving member as XValue + std::optional tryGetX(MemberLoc) const; std::optional tryGetX(const char *field_name) const; // Unreference value diff --git a/src/dbzero/workspace/Workspace.cpp b/src/dbzero/workspace/Workspace.cpp index 65541132..be3321b0 100644 --- a/src/dbzero/workspace/Workspace.cpp +++ b/src/dbzero/workspace/Workspace.cpp @@ -7,6 +7,8 @@ #include "Config.hpp" #include "WorkspaceView.hpp" #include "PrefixName.hpp" +#include +#include #include namespace db0 @@ -272,6 +274,7 @@ namespace db0 bool is_default = (it->second == m_default_fixture); it->second->close(false); m_fixtures.erase(it); + updateDataMaskingSettingsFlag(); if (is_default) { m_default_fixture = {}; @@ -315,6 +318,7 @@ namespace db0 it->second->close(as_defunct, timer.get()); it = m_fixtures.erase(it); } + updateDataMaskingSettingsFlag(); if (as_defunct) { m_lang_cache->clearDefunct(); @@ -376,6 +380,7 @@ namespace db0 } it = m_fixtures.emplace(fixture->getUUID(), fixture).first; + updateDataMaskingSettingsFlag(); m_fixture_catalog.add(prefix_name, *fixture); if (*access_type == AccessType::READ_ONLY) { // add read-only fixture to be monitored by the refresh thread (will be removed automatically when closed) @@ -565,6 +570,7 @@ namespace db0 for (auto &[uuid, fixture]: m_fixtures) { fixture->initMaskingState(m_data_masking_state); } + updateDataMaskingSettingsFlag(); } void Workspace::initDataMasking(const PrefixName &prefix_name, std::shared_ptr state) @@ -580,6 +586,7 @@ namespace db0 if (fixture) { fixture->initMaskingState(it->second); } + updateDataMaskingSettingsFlag(); } std::shared_ptr Workspace::getDataMaskingState() const @@ -852,6 +859,16 @@ namespace db0 std::size_t Workspace::size() const { return m_fixtures.size(); } + + void Workspace::updateDataMaskingSettingsFlag() const + { + Settings::m_data_masking_enabled = std::any_of( + m_fixtures.begin(), + m_fixtures.end(), + [](const auto &item) { + return static_cast(item.second->getMaskingState()); + }); + } std::optional Workspace::getLangCacheSize() const { diff --git a/src/dbzero/workspace/Workspace.hpp b/src/dbzero/workspace/Workspace.hpp index 2d419394..f2e1a9a6 100644 --- a/src/dbzero/workspace/Workspace.hpp +++ b/src/dbzero/workspace/Workspace.hpp @@ -350,7 +350,8 @@ namespace db0 void onFlushDirty(std::size_t limit) override; std::optional getLangCacheSize() const; - std::shared_ptr getWorkspaceHeadView() const; + std::shared_ptr getWorkspaceHeadView() const; + void updateDataMaskingSettingsFlag() const; }; } diff --git a/tests/unit_tests/WorkspaceTest.cpp b/tests/unit_tests/WorkspaceTest.cpp index 19722ca4..db1d1abd 100644 --- a/tests/unit_tests/WorkspaceTest.cpp +++ b/tests/unit_tests/WorkspaceTest.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -35,6 +36,7 @@ namespace tests void SetUp() override { + Settings::reset(); if (!Py_IsInitialized()) { Py_InitializeEx(0); } @@ -47,6 +49,7 @@ namespace tests { m_workspace.close(); m_no_gil = nullptr; + Settings::reset(); drop(file_name); } @@ -146,6 +149,32 @@ namespace tests ASSERT_EQ(snapshot_fixture->getMaskingState(), masking_state); } + + TEST_F( WorkspaceTest , testSettingsDataMaskingEnabledTracksWorkspaceScopeOpenFixtures ) + { + auto masking_state = makeTestMaskingState(3); + m_workspace.initDataMasking(masking_state); + ASSERT_FALSE(Settings::m_data_masking_enabled); + + auto fixture = m_workspace.getFixture(getPrefixName()); + ASSERT_TRUE(Settings::m_data_masking_enabled); + + m_workspace.close(fixture->getPrefix().getName()); + ASSERT_FALSE(Settings::m_data_masking_enabled); + } + + TEST_F( WorkspaceTest , testSettingsDataMaskingEnabledTracksPrefixScopeOpenFixtures ) + { + auto masking_state = makeTestMaskingState(4); + m_workspace.initDataMasking(getPrefixName(), masking_state); + ASSERT_FALSE(Settings::m_data_masking_enabled); + + auto fixture = m_workspace.getFixture(getPrefixName()); + ASSERT_TRUE(Settings::m_data_masking_enabled); + + m_workspace.close(fixture->getPrefix().getName()); + ASSERT_FALSE(Settings::m_data_masking_enabled); + } TEST_F( WorkspaceTest , testFreeCanBePerformedBetweenTransactions ) { From 1f85b7922e5571864de1126ab331a0b6bc46ebb9 Mon Sep 17 00:00:00 2001 From: Wojtek Date: Thu, 14 May 2026 15:29:28 +0200 Subject: [PATCH 2/2] post-review fixes --- python_tests/test_memo_protect_fields.py | 36 ++++++++++++++++++++++++ src/dbzero/bindings/python/Memo.cpp | 5 +++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/python_tests/test_memo_protect_fields.py b/python_tests/test_memo_protect_fields.py index 2e49c4cb..74611f76 100644 --- a/python_tests/test_memo_protect_fields.py +++ b/python_tests/test_memo_protect_fields.py @@ -35,6 +35,23 @@ class MemoProtectedComplexFieldsClass: pair: tuple +@db0.memo(protect_fields=True) +@dataclass +class MemoProtectedDefaultFieldsClass: + name: str + generated_value: str = "auto-generated" + + +@db0.memo(protect_fields=True) +class MemoProtectedAutoGeneratedInitVarClass: + generated_value = "auto-generated" + + def __init__(self, name, include_generated=False): + self.name = name + if include_generated: + self.generated_value = "persisted" + + @db0.memo @dataclass class MemoUnprotectedFieldsClass: @@ -390,6 +407,25 @@ def test_protected_field_getter_rejects_invalid_negative_account_id(db0_fixture) _ = obj.name +def test_protected_field_getter_enforces_read_access_for_auto_generated_fields(db0_fixture): + account_id = ContextVar("protected_auto_generated_account_id") + MemoProtectedAutoGeneratedInitVarClass("materializer", include_generated=True) + obj = MemoProtectedAutoGeneratedInitVarClass("alpha") + db0.set_field_access(MemoProtectedAutoGeneratedInitVarClass, 101, (FieldAccess.READ,), "name") + db0.set_field_access(MemoProtectedAutoGeneratedInitVarClass, 202, (FieldAccess.READ,), "generated_value") + db0._init_data_masking(account_id) + + account_id.set(101) + assert obj.name == "alpha" + with pytest.raises(PermissionError, match="read"): + _ = obj.generated_value + + account_id.set(202) + assert obj.generated_value == "auto-generated" + with pytest.raises(PermissionError, match="read"): + _ = obj.name + + def test_protected_field_getter_applies_complex_masks_for_mixed_field_types(db0_fixture): account_id = ContextVar("protected_complex_account_id") obj = MemoProtectedComplexFieldsClass(None, False, 7, "alpha", b"payload", ("x", 11)) diff --git a/src/dbzero/bindings/python/Memo.cpp b/src/dbzero/bindings/python/Memo.cpp index 6369df01..1ee36ca5 100644 --- a/src/dbzero/bindings/python/Memo.cpp +++ b/src/dbzero/bindings/python/Memo.cpp @@ -441,11 +441,14 @@ namespace db0::python auto member_loc = memo_obj->ext().findField(attr_name); member = memo_obj->ext().tryGet(member_loc, &is_auto_generated); - if (member.get() && !is_auto_generated) { + if (member.get()) { auto masked = checkProtectedFieldReadAccess(memo_obj, member_loc); if (masked || PyErr_Occurred()) { return masked; } + } + + if (member.get() && !is_auto_generated) { return member.steal(); } }