diff --git a/python_tests/test_memo_protect_fields.py b/python_tests/test_memo_protect_fields.py index cc847d64..d4d2a4a7 100644 --- a/python_tests/test_memo_protect_fields.py +++ b/python_tests/test_memo_protect_fields.py @@ -436,6 +436,71 @@ def test_protected_field_getter_requires_initialized_data_masking(db0_fixture): _ = obj.name +def test_protected_field_create_from_init_requires_create_access(db0_fixture): + account_id = ContextVar("protected_create_init_account_id") + MemoProtectedFieldsClass("materializer", 0) + db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.CREATE,), "name") + db0.set_field_access( + MemoProtectedFieldsClass, + 202, + (FieldAccess.CREATE,), + "name", + "value", + ) + db0._init_data_masking(account_id) + + account_id.set(101) + with pytest.raises(PermissionError, match="create"): + MemoProtectedFieldsClass("alpha", 1) + + account_id.set(202) + MemoProtectedFieldsClass("beta", 2) + + +def test_protected_field_create_from_init_uses_predeclared_name_masks(db0_fixture): + account_id = ContextVar("protected_create_predeclared_account_id") + + @db0.memo(protect_fields=True) + class InitFutureField: + def __init__(self, include_fields=False): + if include_fields: + self.future = "allowed" + self.denied = "blocked" + + InitFutureField() + db0.set_field_access(InitFutureField, 101, (FieldAccess.CREATE,), "future") + db0._init_data_masking(account_id) + + account_id.set(101) + with pytest.raises(PermissionError, match="create"): + InitFutureField(include_fields=True) + + +def test_protected_field_append_requires_create_access(db0_fixture): + account_id = ContextVar("protected_create_append_account_id") + obj = MemoProtectedFieldsClass("alpha", 1) + db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.UPDATE,), "extra") + db0.set_field_access(MemoProtectedFieldsClass, 202, (FieldAccess.CREATE,), "extra") + db0._init_data_masking(account_id) + + account_id.set(101) + with pytest.raises(PermissionError, match="create"): + obj.extra = "denied" + + account_id.set(202) + obj.extra = "allowed" + + +def test_protected_field_update_does_not_require_create_access(db0_fixture): + account_id = ContextVar("protected_update_existing_account_id") + obj = MemoProtectedFieldsClass("alpha", 1) + db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.UPDATE,), "name") + db0._init_data_masking(account_id) + + account_id.set(101) + obj.name = "beta" + + def test_protected_field_getter_requires_read_access(db0_fixture): account_id = ContextVar("protected_read_account_id") obj = MemoProtectedFieldsClass("alpha", 1) diff --git a/src/dbzero/bindings/python/Memo.cpp b/src/dbzero/bindings/python/Memo.cpp index 3d573133..ebed082d 100644 --- a/src/dbzero/bindings/python/Memo.cpp +++ b/src/dbzero/bindings/python/Memo.cpp @@ -361,56 +361,97 @@ namespace db0::python } template - PyObject *checkProtectedFieldReadAccess(MemoImplT *memo_obj, const db0::object_model::MemberLoc &member_loc) + bool getProtectedFieldAccessContext(MemoImplT *memo_obj, const db0::object_model::Class *&type, + std::shared_ptr &masking_state, long long &account_id) { - auto &type = memo_obj->ext().getType(); - if (!type.isProtectFields()) { - return nullptr; + type = &memo_obj->ext().getType(); + if (!type->isProtectFields()) { + return false; } if (!Settings::m_data_masking_enabled) { PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); - return nullptr; + return false; } - auto masking_state = memo_obj->ext().getFixture()->getMaskingState(); + masking_state = memo_obj->ext().getFixture()->getMaskingState(); if (!masking_state) { PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields"); - return nullptr; + return false; } PyObject *pyAccountId = nullptr; if (PyContextVar_Get(masking_state->contextVar, NULL, &pyAccountId) < 0) { PyErr_SetString(PyExc_RuntimeError, "unable to read data masking account context"); - return nullptr; + return false; } if (!pyAccountId) { PyErr_SetString(PyExc_RuntimeError, "data masking account context is not set"); - return nullptr; + return false; } - auto accountId = PyLong_AsLongLong(pyAccountId); + account_id = PyLong_AsLongLong(pyAccountId); Py_DECREF(pyAccountId); if (PyErr_Occurred()) { PyErr_SetString(PyExc_TypeError, "data masking account context must be an int"); - return nullptr; + return false; } - bool canRead = false; - if (accountId < -2) { + if (account_id < -2) { PyErr_SetString(PyExc_RuntimeError, "invalid data masking account id"); - return nullptr; - } else if (accountId == -1 || accountId == -2) { - canRead = masking_state->mode == DataMaskingMode::DEBUG; - } else if (accountId >= 0) { - auto mask = type.tryGetFieldAccess(static_cast(accountId), member_loc); - canRead = mask && (*mask)[db0::object_model::FieldMaskOptions::READ]; + return false; + } + + return true; + } + + template + bool checkProtectedFieldAccess(MemoImplT *memo_obj, db0::object_model::FieldMaskOptions access_option, + const db0::object_model::MemberLoc &member_loc, const char *field_name = nullptr) + { + auto &memo_type = memo_obj->ext().getType(); + if (!memo_type.isProtectFields()) { + return true; + } + + const db0::object_model::Class *type = nullptr; + std::shared_ptr masking_state; + long long account_id = 0; + if (!getProtectedFieldAccessContext(memo_obj, type, masking_state, account_id)) { + return !PyErr_Occurred(); + } + + if (masking_state->mode == DataMaskingMode::DEBUG) { + if (account_id == -2) { + return true; + } + if (account_id == -1) { + return access_option == db0::object_model::FieldMaskOptions::READ; + } } - if (canRead) { + if (account_id < 0) { + return false; + } + + auto mask = type->tryGetFieldAccess(static_cast(account_id), member_loc); + if (!mask && field_name) { + mask = type->tryGetFieldAccess(static_cast(account_id), field_name); + } + return mask && (*mask)[access_option]; + } + + template + PyObject *checkProtectedFieldReadAccess(MemoImplT *memo_obj, const db0::object_model::MemberLoc &member_loc) + { + if (checkProtectedFieldAccess(memo_obj, db0::object_model::FieldMaskOptions::READ, member_loc)) { + return nullptr; + } + if (PyErr_Occurred()) { return nullptr; } + auto masking_state = memo_obj->ext().getFixture()->getMaskingState(); if (masking_state->hasMissingValuePlaceholder) { Py_INCREF(masking_state->missingValuePlaceholder); return masking_state->missingValuePlaceholder; @@ -420,6 +461,28 @@ namespace db0::python return nullptr; } + template + bool checkProtectedFieldCreateAccess(MemoImplT *memo_obj, const char *field_name) + { + auto &memo_type = memo_obj->ext().getType(); + if (!memo_type.isProtectFields() || !Settings::m_data_masking_enabled) { + return true; + } + + auto member_loc = memo_type.findField(field_name); + if (checkProtectedFieldAccess( + memo_obj, db0::object_model::FieldMaskOptions::CREATE, member_loc, field_name + )) { + return true; + } + if (PyErr_Occurred()) { + return false; + } + + PyErr_SetString(PyExc_PermissionError, "data masking denies create access to protected field"); + return false; + } + template PyObject *tryMemoObject_getattro(MemoImplT *memo_obj, PyObject *attr) { @@ -496,9 +559,16 @@ namespace db0::python auto maybe_type_id = PyToolkit::getTypeManager().tryGetTypeId(value); if (maybe_type_id) { if (self->ext().hasInstance()) { + auto member_loc = self->ext().findField(attr_name); + if (!self->ext().tryGet(member_loc) && !checkProtectedFieldCreateAccess(self, attr_name)) { + return -1; + } db0::FixtureLock lock(self->ext().getFixture()); self->modifyExt().set(lock, attr_name, *maybe_type_id, value); } else { + if (!checkProtectedFieldCreateAccess(self, attr_name)) { + return -1; + } // considered as a non-mutating operation self->ext().setPreInit(attr_name, *maybe_type_id, value); } diff --git a/src/dbzero/object_model/class/Class.cpp b/src/dbzero/object_model/class/Class.cpp index 768f8e01..dbb83a32 100644 --- a/src/dbzero/object_model/class/Class.cpp +++ b/src/dbzero/object_model/class/Class.cpp @@ -467,6 +467,26 @@ namespace db0::object_model return {}; } + std::optional Class::tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const + { + if (!isProtectFields()) { + return {}; + } + + auto &field_safe = getFieldSafe(); + auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(field_name); + if (!maybe_offset) { + return {}; + } + + auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id); + if (!field_mask) { + return {}; + } + + return field_mask->getAssignedMask(*maybe_offset); + } + std::vector > Class::getFieldAccess(std::uint64_t account_id) const { if (!isProtectFields()) { diff --git a/src/dbzero/object_model/class/Class.hpp b/src/dbzero/object_model/class/Class.hpp index 99d0033c..27bc5212 100644 --- a/src/dbzero/object_model/class/Class.hpp +++ b/src/dbzero/object_model/class/Class.hpp @@ -186,6 +186,7 @@ DB0_PACKED_END const std::vector &field_names); std::optional tryGetFieldAccess(std::uint64_t account_id, const Member &) const; std::optional tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &) const; + std::optional tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const; std::vector > getFieldAccess(std::uint64_t account_id) const; std::uint32_t getFieldOffsetRange() const;