diff --git a/python_tests/test_memo_protect_fields.py b/python_tests/test_memo_protect_fields.py index 16205629..5eff15ca 100644 --- a/python_tests/test_memo_protect_fields.py +++ b/python_tests/test_memo_protect_fields.py @@ -100,6 +100,60 @@ class MemoImplicitlyProtectedDerivedFieldsClass(MemoProtectedDerivedFieldsClass) derived_value: float +@db0.memo(immutable=True, protect_fields=True) +@dataclass +class MemoProtectedImmutableFieldsClass: + name: str + value: int + + +@db0.memo(immutable=True, protect_fields=True) +@dataclass +class MemoProtectedImmutableBaseFieldsClass: + base_name: str + + +@db0.memo(immutable=True) +@dataclass +class MemoProtectedImmutableDerivedFieldsClass(MemoProtectedImmutableBaseFieldsClass): + name: str + value: int + + +@db0.memo(immutable=True, protect_fields=True) +class MemoProtectedEmbeddedPayload: + def __init__(self, name, value): + self.name = name + self.value = value + + +@db0.memo(immutable=True) +class MemoProtectedEmbeddedHolder: + def __init__(self, name, value, label="holder"): + self.child = MemoProtectedEmbeddedPayload(name, value) + self.label = label + + +@db0.memo(immutable=True, protect_fields=True) +class MemoProtectedEmbeddedBase: + def __init__(self, base_name): + self.base_name = base_name + + +@db0.memo(immutable=True) +class MemoProtectedEmbeddedDerived(MemoProtectedEmbeddedBase): + def __init__(self, base_name, value, secret="hidden"): + super().__init__(base_name) + self.value = value + self.secret = secret + + +@db0.memo(immutable=True) +class MemoProtectedEmbeddedDerivedHolder: + def __init__(self, base_name, value, secret="hidden"): + self.child = MemoProtectedEmbeddedDerived(base_name, value, secret) + + def get_memo_class_object(obj): return db0.get_memo_class(obj).get_class() @@ -1034,3 +1088,120 @@ def read_as(account): assert account_101.result() == (None, True) assert account_202.result() == (7, "alpha") + + +def test_protected_immutable_field_getter_enforces_read_access(db0_fixture): + account_id = ContextVar("protected_immutable_account_id") + obj = MemoProtectedImmutableFieldsClass("alpha", 1) + db0.set_field_access(MemoProtectedImmutableFieldsClass, 123, (FieldAccess.READ,), "name") + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.name == "alpha" + with pytest.raises(PermissionError, match="read"): + _ = obj.value + + +def test_protected_immutable_field_getter_inherits_base_class_masks(db0_fixture): + account_id = ContextVar("protected_immutable_inherited_account_id") + obj = MemoProtectedImmutableDerivedFieldsClass("base", "alpha", 1) + db0.set_field_access( + MemoProtectedImmutableBaseFieldsClass, + 123, + (FieldAccess.READ,), + "base_name", + ) + db0.set_field_access( + MemoProtectedImmutableDerivedFieldsClass, + 123, + (FieldAccess.READ,), + "name", + ) + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.base_name == "base" + assert obj.name == "alpha" + with pytest.raises(PermissionError, match="read"): + _ = obj.value + + +def test_protected_embedded_field_getter_enforces_read_access(db0_fixture): + account_id = ContextVar("protected_embedded_account_id") + obj = db0.materialized(MemoProtectedEmbeddedHolder("alpha", 1)) + db0.set_field_access(MemoProtectedEmbeddedPayload, 123, (FieldAccess.READ,), "name") + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.child.name == "alpha" + with pytest.raises(PermissionError, match="read"): + _ = obj.child.value + + +def test_protected_embedded_field_getter_returns_missing_value_placeholder(db0_fixture): + account_id = ContextVar("protected_embedded_placeholder_account_id") + missing_value = object() + obj = db0.materialized(MemoProtectedEmbeddedHolder("alpha", 1)) + db0.set_field_access(MemoProtectedEmbeddedPayload, 123, (FieldAccess.READ,), "name") + db0._init_data_masking(account_id, missing_value_placeholder=missing_value) + account_id.set(123) + + assert obj.child.value is missing_value + + +def test_protected_embedded_field_getter_uses_debug_account_rules(db0_fixture): + account_id = ContextVar("protected_embedded_debug_account_id") + obj = db0.materialized(MemoProtectedEmbeddedHolder("alpha", 1)) + db0._init_data_masking(account_id, mode="DEBUG") + + account_id.set(-1) + assert obj.child.name == "alpha" + assert obj.child.value == 1 + + account_id.set(-3) + with pytest.raises(RuntimeError, match="invalid.*account"): + _ = obj.child.name + + +def test_protected_embedded_field_getter_inherits_base_class_masks(db0_fixture): + account_id = ContextVar("protected_embedded_inherited_account_id") + obj = db0.materialized(MemoProtectedEmbeddedDerivedHolder("base", 7, "secret")) + db0.set_field_access(MemoProtectedEmbeddedBase, 123, (FieldAccess.READ,), "base_name") + db0.set_field_access(MemoProtectedEmbeddedDerived, 123, (FieldAccess.READ,), "value") + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.child.base_name == "base" + assert obj.child.value == 7 + with pytest.raises(PermissionError, match="read"): + _ = obj.child.secret + + +def test_protected_embedded_field_getter_survives_reopen(db0_fixture): + account_id = ContextVar("protected_embedded_reopen_account_id") + obj = db0.materialized(MemoProtectedEmbeddedHolder("alpha", 1)) + db0.tags(obj).add("keep-protected-embedded") + obj_id = db0.uuid(obj) + db0.set_field_access(MemoProtectedEmbeddedPayload, 123, (FieldAccess.READ,), "name") + db0.commit() + + db0.close() + db0.init(DB0_DIR) + db0.open("my-test-prefix") + db0._init_data_masking(account_id) + account_id.set(123) + + reopened = db0.fetch(obj_id) + assert reopened.child.name == "alpha" + with pytest.raises(PermissionError, match="read"): + _ = reopened.child.value + + +def test_protected_embedded_dict_omits_denied_fields(db0_fixture): + account_id = ContextVar("protected_embedded_dict_account_id") + obj = db0.materialized(MemoProtectedEmbeddedHolder("alpha", 1)) + db0.set_field_access(MemoProtectedEmbeddedPayload, 123, (FieldAccess.READ,), "name") + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.child.__dict__ == {"name": "alpha"} diff --git a/src/dbzero/bindings/python/Memo.cpp b/src/dbzero/bindings/python/Memo.cpp index 8c91f38d..09de6f32 100644 --- a/src/dbzero/bindings/python/Memo.cpp +++ b/src/dbzero/bindings/python/Memo.cpp @@ -9,11 +9,11 @@ #include #include "PySnapshot.hpp" #include "PyInternalAPI.hpp" +#include "ProtectedFieldAccess.hpp" #include "Utils.hpp" #include "Types.hpp" #include "Migration.hpp" #include "PyHash.hpp" -#include "DataMasking.hpp" #include #include #include @@ -387,137 +387,6 @@ namespace db0::python return !(attr_name[0] == '_' && attr_name[1] == 'X' && attr_name[2] == '_' && attr_name[3] == '_'); } - template - bool getProtectedFieldAccessContext(MemoImplT *memo_obj, const db0::object_model::Class *&type, - std::shared_ptr &masking_state, long long &account_id) - { - type = &memo_obj->ext().getType(); - if (!type->isProtectFields()) { - return false; - } - - if (!Settings::m_data_masking_enabled) { - PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); - return false; - } - - masking_state = memo_obj->ext().getFixture()->getMaskingState(); - if (!masking_state) { - PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); - return false; - } - - PyObject *pyAccountId = nullptr; - if (PyContextVar_Get(masking_state->contextVar, NULL, &pyAccountId) < 0) { - PyErr_SetString(PyExc_RuntimeError, "unable to read data masking account context"); - return false; - } - if (!pyAccountId) { - PyErr_SetString(PyExc_RuntimeError, "data masking account context is not set"); - return false; - } - - account_id = PyLong_AsLongLong(pyAccountId); - Py_DECREF(pyAccountId); - if (PyErr_Occurred()) { - PyErr_SetString(PyExc_TypeError, "data masking account context must be an int"); - return false; - } - - if (account_id < -2) { - PyErr_SetString(PyExc_RuntimeError, "invalid data masking account id"); - return false; - } - - return true; - } - - template - bool checkProtectedFieldAccess(MemoImplT *memo_obj, db0::object_model::FieldMaskOptions access_option, - const db0::object_model::MemberLoc &member_loc, const char *field_name = nullptr) - { - auto &memo_type = memo_obj->ext().getType(); - if (!memo_type.isProtectFields()) { - return true; - } - - const db0::object_model::Class *type = nullptr; - std::shared_ptr masking_state; - long long account_id = 0; - if (!getProtectedFieldAccessContext(memo_obj, type, masking_state, account_id)) { - return !PyErr_Occurred(); - } - - if (masking_state->mode == DataMaskingMode::DEBUG) { - if (account_id == -2) { - return true; - } - if (account_id == -1) { - return access_option == db0::object_model::FieldMaskOptions::READ; - } - } - - if (account_id < 0) { - return false; - } - - auto mask = type->tryGetFieldAccessByMemberLoc(static_cast(account_id), member_loc); - if (!mask && field_name) { - mask = type->tryGetFieldAccessByName(static_cast(account_id), field_name); - } - return mask && (*mask)[access_option]; - } - - void setProtectedFieldPermissionError(db0::object_model::FieldMaskOptions access_option) - { - std::stringstream message; - message << "data masking denies " << access_option - << " access to protected field"; - PyErr_SetString(PyExc_PermissionError, message.str().c_str()); - } - - template - PyObject *checkProtectedFieldReadAccess(MemoImplT *memo_obj, const db0::object_model::MemberLoc &member_loc) - { - auto accessOption = db0::object_model::FieldMaskOptions::READ; - if (checkProtectedFieldAccess(memo_obj, accessOption, member_loc)) { - return nullptr; - } - if (PyErr_Occurred()) { - return nullptr; - } - - auto masking_state = memo_obj->ext().getFixture()->getMaskingState(); - if (masking_state->hasMissingValuePlaceholder) { - Py_INCREF(masking_state->missingValuePlaceholder); - return masking_state->missingValuePlaceholder; - } - - setProtectedFieldPermissionError(accessOption); - return nullptr; - } - - template - bool checkProtectedFieldMutateAccess(MemoImplT *memo_obj, db0::object_model::FieldMaskOptions accessOption, - const char *fieldName) - { - auto &memo_type = memo_obj->ext().getType(); - if (!memo_type.isProtectFields()) { - return true; - } - - auto memberLoc = memo_type.findField(fieldName); - if (checkProtectedFieldAccess(memo_obj, accessOption, memberLoc, fieldName)) { - return true; - } - if (PyErr_Occurred()) { - return false; - } - - setProtectedFieldPermissionError(accessOption); - return false; - } - template PyObject *tryMemoObject_getattro(MemoImplT *memo_obj, PyObject *attr) { @@ -541,7 +410,9 @@ namespace db0::python member = memo_obj->ext().tryGet(member_loc, &is_auto_generated); if (member.get()) { - auto masked = checkProtectedFieldReadAccess(memo_obj, member_loc); + auto masked = checkProtectedFieldReadAccess( + memo_obj->ext().getType(), memo_obj->ext().getFixture(), member_loc + ); if (masked || PyErr_Occurred()) { return masked; } @@ -586,7 +457,8 @@ namespace db0::python if (isPersistentAttrName(attr_name)) { try { if (!value && !checkProtectedFieldMutateAccess( - self, db0::object_model::FieldMaskOptions::DELETE, attr_name + self->ext().getType(), self->ext().getFixture(), + db0::object_model::FieldMaskOptions::DELETE, attr_name )) { return -1; } @@ -609,7 +481,7 @@ namespace db0::python if ( (memberExists || Settings::m_data_masking_enabled) && !checkProtectedFieldMutateAccess( - self, accessOption, attr_name + self->ext().getType(), self->ext().getFixture(), accessOption, attr_name ) ) { return -1; @@ -622,7 +494,8 @@ namespace db0::python value && Settings::m_data_masking_enabled && !checkProtectedFieldMutateAccess( - self, db0::object_model::FieldMaskOptions::CREATE, attr_name + self->ext().getType(), self->ext().getFixture(), + db0::object_model::FieldMaskOptions::CREATE, attr_name ) ) { return -1; @@ -1431,7 +1304,8 @@ namespace db0::python if (memo_type.isProtectFields()) { auto member_loc = memo_obj->ext().findField(key.c_str()); if (!checkProtectedFieldAccess( - memo_obj, db0::object_model::FieldMaskOptions::READ, member_loc, key.c_str() + memo_type, memo_obj->ext().getFixture(), db0::object_model::FieldMaskOptions::READ, + member_loc, key.c_str() )) { if (PyErr_Occurred()) { has_error = true; diff --git a/src/dbzero/bindings/python/Memo.hpp b/src/dbzero/bindings/python/Memo.hpp index e53797d0..21322489 100644 --- a/src/dbzero/bindings/python/Memo.hpp +++ b/src/dbzero/bindings/python/Memo.hpp @@ -37,6 +37,8 @@ namespace db0::python void MemoType_get_info(PyTypeObject *type, PyObject *dict); void MemoType_close(PyTypeObject *type); + + bool isPersistentAttrName(const char *attr_name); template PyObject *MemoObject_set_prefix(MemoImplT *, const char *prefix_name); diff --git a/src/dbzero/bindings/python/ProtectedFieldAccess.cpp b/src/dbzero/bindings/python/ProtectedFieldAccess.cpp new file mode 100644 index 00000000..e4c6bc56 --- /dev/null +++ b/src/dbzero/bindings/python/ProtectedFieldAccess.cpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: LGPL-2.1-or-later +// Copyright (c) 2025 DBZero Software sp. z o.o. + +#include + +#include +#include +#include +#include + +#include +#include + +namespace db0::python +{ + using namespace db0::object_model; + + namespace + { + bool getProtectedFieldAccessContext( + const Class &type, const db0::swine_ptr &fixture, + std::shared_ptr &maskingState, long long &accountId + ) + { + if (!type.isProtectFields()) { + return false; + } + + if (!Settings::m_data_masking_enabled) { + PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); + return false; + } + + maskingState = fixture->getMaskingState(); + if (!maskingState) { + PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); + return false; + } + + PyObject *pyAccountId = nullptr; + if (PyContextVar_Get(maskingState->contextVar, NULL, &pyAccountId) < 0) { + PyErr_SetString(PyExc_RuntimeError, "unable to read data masking account context"); + return false; + } + if (!pyAccountId) { + PyErr_SetString(PyExc_RuntimeError, "data masking account context is not set"); + return false; + } + + accountId = PyLong_AsLongLong(pyAccountId); + Py_DECREF(pyAccountId); + if (PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "data masking account context must be an int"); + return false; + } + + if (accountId < -2) { + PyErr_SetString(PyExc_RuntimeError, "invalid data masking account id"); + return false; + } + + return true; + } + + void setProtectedFieldPermissionError(FieldMaskOptions accessOption) + { + std::stringstream message; + message << "data masking denies " << accessOption + << " access to protected field"; + PyErr_SetString(PyExc_PermissionError, message.str().c_str()); + } + } + + bool checkProtectedFieldAccess( + const Class &type, const db0::swine_ptr &fixture, FieldMaskOptions accessOption, + const MemberLoc &memberLoc, const char *fieldName + ) + { + if (!type.isProtectFields()) { + return true; + } + + std::shared_ptr maskingState; + long long accountId = 0; + if (!getProtectedFieldAccessContext(type, fixture, maskingState, accountId)) { + return !PyErr_Occurred(); + } + + if (maskingState->mode == DataMaskingMode::DEBUG) { + if (accountId == -2) { + return true; + } + if (accountId == -1) { + return accessOption == FieldMaskOptions::READ; + } + } + + if (accountId < 0) { + return false; + } + + auto mask = type.tryGetFieldAccessByMemberLoc(static_cast(accountId), memberLoc); + if (!mask && fieldName) { + mask = type.tryGetFieldAccessByName(static_cast(accountId), fieldName); + } + return mask && (*mask)[accessOption]; + } + + PyObject *checkProtectedFieldReadAccess( + const Class &type, const db0::swine_ptr &fixture, const MemberLoc &memberLoc, const char *fieldName + ) + { + auto accessOption = FieldMaskOptions::READ; + if (checkProtectedFieldAccess(type, fixture, accessOption, memberLoc, fieldName)) { + return nullptr; + } + if (PyErr_Occurred()) { + return nullptr; + } + + auto maskingState = fixture->getMaskingState(); + if (maskingState->hasMissingValuePlaceholder) { + Py_INCREF(maskingState->missingValuePlaceholder); + return maskingState->missingValuePlaceholder; + } + + setProtectedFieldPermissionError(accessOption); + return nullptr; + } + + bool checkProtectedFieldMutateAccess( + const Class &type, const db0::swine_ptr &fixture, FieldMaskOptions accessOption, const char *fieldName + ) + { + if (!type.isProtectFields()) { + return true; + } + + auto memberLoc = type.findField(fieldName); + if (checkProtectedFieldAccess(type, fixture, accessOption, memberLoc, fieldName)) { + return true; + } + if (PyErr_Occurred()) { + return false; + } + + setProtectedFieldPermissionError(accessOption); + return false; + } +} diff --git a/src/dbzero/bindings/python/ProtectedFieldAccess.hpp b/src/dbzero/bindings/python/ProtectedFieldAccess.hpp new file mode 100644 index 00000000..c2d738e1 --- /dev/null +++ b/src/dbzero/bindings/python/ProtectedFieldAccess.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: LGPL-2.1-or-later +// Copyright (c) 2025 DBZero Software sp. z o.o. + +#pragma once + +#include + +#include +#include +#include + +namespace db0 +{ + class Fixture; +} + +namespace db0::object_model +{ + class Class; +} + +namespace db0::python +{ + bool checkProtectedFieldAccess( + const db0::object_model::Class &type, const db0::swine_ptr &fixture, + db0::object_model::FieldMaskOptions accessOption, + const db0::object_model::MemberLoc &memberLoc, const char *fieldName = nullptr + ); + + PyObject *checkProtectedFieldReadAccess( + const db0::object_model::Class &type, const db0::swine_ptr &fixture, + const db0::object_model::MemberLoc &memberLoc, const char *fieldName = nullptr + ); + + bool checkProtectedFieldMutateAccess( + const db0::object_model::Class &type, const db0::swine_ptr &fixture, + db0::object_model::FieldMaskOptions accessOption, const char *fieldName + ); +} diff --git a/src/dbzero/bindings/python/embedded/EmbeddedObject.cpp b/src/dbzero/bindings/python/embedded/EmbeddedObject.cpp index 8825f378..3fda8fbe 100644 --- a/src/dbzero/bindings/python/embedded/EmbeddedObject.cpp +++ b/src/dbzero/bindings/python/embedded/EmbeddedObject.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -119,9 +120,8 @@ namespace db0::python ); } - ObjectSharedPtr tryGetMember(EmbeddedObjectRef &embeddedRef, const char *attrName) + ObjectSharedPtr tryGetMember(EmbeddedObjectRef &embeddedRef, const MemberLoc &memberLoc) { - auto memberLoc = embeddedRef.type().findField(attrName); if (!memberLoc.first) { return {}; } @@ -173,7 +173,7 @@ namespace db0::python return result; } - PyObject *tryEmbeddedObjectGetAttr(EmbeddedObject *self, PyObject *attr) + PyObject *tryEmbeddedRefGetAttr(PyObject *self, EmbeddedObjectRef &embeddedRef, PyObject *attr) { const char *attrName = PyUnicode_AsUTF8(attr); if (!attrName) { @@ -181,10 +181,19 @@ namespace db0::python return nullptr; } - if (!(attrName[0] == '_' && attrName[1] == 'X' && attrName[2] == '_' && attrName[3] == '_')) { - auto fixture = getRootFixture(self->ext().rootObject()); + if (isPersistentAttrName(attrName)) { + auto fixture = getRootFixture(embeddedRef.rootObject()); fixture->refreshIfUpdated(); - auto member = tryGetMember(self->modifyExt(), attrName); + auto memberLoc = embeddedRef.type().findField(attrName); + if (memberLoc.first) { + auto masked = checkProtectedFieldReadAccess( + embeddedRef.type(), fixture, memberLoc, attrName + ); + if (masked || PyErr_Occurred()) { + return masked; + } + } + auto member = tryGetMember(embeddedRef, memberLoc); if (member.get()) { return member.steal(); } @@ -196,28 +205,12 @@ namespace db0::python PyObject *PyAPI_EmbeddedObject_getattro(EmbeddedObject *self, PyObject *attr) { PY_API_FUNC - return runSafe(tryEmbeddedObjectGetAttr, self, attr); + return runSafe(tryEmbeddedRefGetAttr, reinterpret_cast(self), self->modifyExt(), attr); } PyObject *tryEmbeddedMemoGetAttr(MemoImmutableObject *self, PyObject *attr) { - const char *attrName = PyUnicode_AsUTF8(attr); - if (!attrName) { - PyErr_SetString(PyExc_AttributeError, "Invalid attribute name"); - return nullptr; - } - - if (!(attrName[0] == '_' && attrName[1] == 'X' && attrName[2] == '_' && attrName[3] == '_')) { - auto &embeddedRef = embeddedMemoRef(self); - auto fixture = getRootFixture(embeddedRef.rootObject()); - fixture->refreshIfUpdated(); - auto member = tryGetMember(embeddedRef, attrName); - if (member.get()) { - return member.steal(); - } - } - - return PyObject_GenericGetAttr(reinterpret_cast(self), attr); + return tryEmbeddedRefGetAttr(reinterpret_cast(self), embeddedMemoRef(self), attr); } PyObject *PyAPI_EmbeddedMemo_getattro(MemoImmutableObject *self, PyObject *attr) @@ -332,8 +325,16 @@ namespace db0::python } auto &type = embeddedMemoRef(self).type(); + auto fixture = getRootFixture(embeddedMemoRef(self).rootObject()); for (const auto &name: getEmbeddedMemberNames(embeddedMemoRef(self).embeddedObject(), type)) { - auto value = tryGetMember(embeddedMemoRef(self), name.c_str()); + auto memberLoc = type.findField(name.c_str()); + if (!checkProtectedFieldAccess(type, fixture, FieldMaskOptions::READ, memberLoc, name.c_str())) { + if (PyErr_Occurred()) { + return nullptr; + } + continue; + } + auto value = tryGetMember(embeddedMemoRef(self), memberLoc); if (!value.get()) { continue; }