diff --git a/python_tests/test_memo_protect_fields.py b/python_tests/test_memo_protect_fields.py index c0bef517..16205629 100644 --- a/python_tests/test_memo_protect_fields.py +++ b/python_tests/test_memo_protect_fields.py @@ -577,7 +577,7 @@ def test_protected_field_getter_requires_read_access(db0_fixture): _ = obj.value -def test_protected_field_getter_does_not_inherit_base_class_masks(db0_fixture): +def test_protected_field_getter_inherits_base_class_masks(db0_fixture): account_id = ContextVar("protected_inheritance_account_id") obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5) db0.set_field_access(MemoProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "base_value", "name") @@ -585,15 +585,64 @@ def test_protected_field_getter_does_not_inherit_base_class_masks(db0_fixture): db0._init_data_masking(account_id) account_id.set(123) + assert obj.base_value == "base" + assert obj.name == "alpha" assert obj.derived_value == 1.5 - with pytest.raises(PermissionError, match="read"): - _ = obj.base_value - with pytest.raises(PermissionError, match="read"): - _ = obj.name with pytest.raises(PermissionError, match="read"): _ = obj.value +def test_protected_field_create_inherits_base_class_masks(db0_fixture): + account_id = ContextVar("protected_inheritance_create_account_id") + obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5) + db0.set_field_access(MemoProtectedDerivedFieldsClass, 101, (FieldAccess.READ,), "extra") + db0.set_field_access(MemoProtectedDerivedFieldsClass, 202, (FieldAccess.CREATE, FieldAccess.READ), "extra") + db0._init_data_masking(account_id) + + account_id.set(101) + with pytest.raises(PermissionError, match="create"): + obj.extra = "denied" + + account_id.set(202) + obj.extra = "allowed" + assert obj.extra == "allowed" + + +def test_protected_field_update_inherits_base_class_masks(db0_fixture): + account_id = ContextVar("protected_inheritance_update_account_id") + obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5) + db0.set_field_access(MemoProtectedDerivedFieldsClass, 101, (FieldAccess.READ,), "name") + db0.set_field_access(MemoProtectedDerivedFieldsClass, 202, (FieldAccess.READ, FieldAccess.UPDATE), "name") + db0._init_data_masking(account_id) + + account_id.set(101) + with pytest.raises(PermissionError, match="update"): + obj.name = "denied" + assert obj.name == "alpha" + + account_id.set(202) + obj.name = "allowed" + assert obj.name == "allowed" + + +def test_protected_field_delete_inherits_base_class_masks(db0_fixture): + account_id = ContextVar("protected_inheritance_delete_account_id") + obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5) + db0.set_field_access(MemoProtectedDerivedFieldsClass, 101, (FieldAccess.READ,), "name") + db0.set_field_access(MemoProtectedDerivedFieldsClass, 202, (FieldAccess.DELETE,), "name") + db0._init_data_masking(account_id) + + account_id.set(101) + with pytest.raises(PermissionError, match="delete"): + del obj.name + assert obj.name == "alpha" + + account_id.set(202) + del obj.name + with pytest.raises(AttributeError): + _ = obj.name + + def test_protected_field_getter_uses_derived_class_own_masks(db0_fixture): account_id = ContextVar("protected_inheritance_allow_override_account_id") obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5) @@ -686,7 +735,53 @@ class DerivedAfter(BaseAfter): obj = db0.fetch(DerivedAfter, obj_id) assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True + with pytest.raises(RuntimeError, match="data masking"): + _ = obj.value db0.set_field_access(DerivedAfter, 123, (FieldAccess.READ,), "value") + account_id = ContextVar("protected_base_enabled_later_account_id") + db0._init_data_masking(account_id) + account_id.set(123) + assert obj.value == 1 + + +def test_existing_derived_instance_inherits_base_field_access_enabled_after_materialization(db0_fixture): + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-base") + @dataclass + class BaseBefore: + base_value: str + + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-derived") + @dataclass + class DerivedBefore(BaseBefore): + value: int + + obj = DerivedBefore("base", 1) + obj_id = db0.uuid(obj) + db0.commit() + + db0.close() + db0.init(DB0_DIR) + db0.open("my-test-prefix") + + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-base", protect_fields=True) + @dataclass + class BaseAfter: + base_value: str + + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-derived") + @dataclass + class DerivedAfter(BaseAfter): + value: int + + obj = db0.fetch(DerivedAfter, obj_id) + db0.set_field_access(BaseAfter, 123, (FieldAccess.READ,), "base_value") + account_id = ContextVar("protected_base_access_enabled_later_account_id") + db0._init_data_masking(account_id) + account_id.set(123) + + assert obj.base_value == "base" + with pytest.raises(PermissionError, match="read"): + _ = obj.value def test_derived_class_stops_inheriting_protect_fields_after_base_reset(db0_fixture): @@ -722,11 +817,49 @@ class DerivedAfter(BaseAfter): obj = db0.fetch(DerivedAfter, obj_id) db0.reset_protect_fields(BaseAfter) + assert obj.value == 1 assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is False with pytest.raises(RuntimeError, match="protected fields"): db0.set_field_access(DerivedAfter, 123, (FieldAccess.READ,), "value") +def test_existing_derived_instance_stops_using_base_field_access_after_base_reset(db0_fixture): + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-base", protect_fields=True) + @dataclass + class BaseBefore: + base_value: str + + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-derived") + @dataclass + class DerivedBefore(BaseBefore): + value: int + + obj = DerivedBefore("base", 1) + obj_id = db0.uuid(obj) + db0.commit() + + db0.close() + db0.init(DB0_DIR) + db0.open("my-test-prefix") + + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-base", protect_fields=False) + @dataclass + class BaseAfter: + base_value: str + + @db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-derived") + @dataclass + class DerivedAfter(BaseAfter): + value: int + + obj = db0.fetch(DerivedAfter, obj_id) + db0.reset_protect_fields(BaseAfter) + + assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is False + assert obj.base_value == "base" + assert obj.value == 1 + + def test_protected_field_getter_returns_missing_value_placeholder(db0_fixture): account_id = ContextVar("protected_placeholder_account_id") missing_value = object() diff --git a/src/dbzero/bindings/python/Memo.cpp b/src/dbzero/bindings/python/Memo.cpp index 9f98b9af..93501c62 100644 --- a/src/dbzero/bindings/python/Memo.cpp +++ b/src/dbzero/bindings/python/Memo.cpp @@ -434,9 +434,9 @@ namespace db0::python return false; } - auto mask = type->tryGetFieldAccess(static_cast(account_id), member_loc); + auto mask = type->tryGetFieldAccessByMemberLoc(static_cast(account_id), member_loc); if (!mask && field_name) { - mask = type->tryGetFieldAccess(static_cast(account_id), field_name); + mask = type->tryGetFieldAccessByName(static_cast(account_id), field_name); } return mask && (*mask)[access_option]; } diff --git a/src/dbzero/object_model/class/Class.cpp b/src/dbzero/object_model/class/Class.cpp index dbb83a32..6e2cd832 100644 --- a/src/dbzero/object_model/class/Class.cpp +++ b/src/dbzero/object_model/class/Class.cpp @@ -433,27 +433,32 @@ namespace db0::object_model } } - std::optional Class::tryGetFieldAccess(std::uint64_t account_id, const Member &member) const + std::optional Class::tryGetFieldAccessByMember(std::uint64_t account_id, const Member &member) const { if (!isProtectFields()) { return {}; } - if (!hasFieldSafe()) { - return {}; + if (hasFieldSafe()) { + auto &field_safe = getFieldSafe(); + auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id); + auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(member.m_field_id); + if (maybe_offset && field_mask) { + auto mask = field_mask->getAssignedMask(*maybe_offset); + if (mask && !mask->none()) { + return mask; + } + } } - auto &field_safe = getFieldSafe(); - auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id); - auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(member.m_field_id); - if (maybe_offset && field_mask) { - return field_mask->getAssignedMask(*maybe_offset); + if (m_base_class_ptr && m_base_class_ptr->isProtectFields()) { + return m_base_class_ptr->tryGetFieldAccess(account_id, member.m_name.c_str()); } return {}; } - std::optional Class::tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &member_loc) const + std::optional Class::tryGetFieldAccessByMemberLoc(std::uint64_t account_id, const MemberLoc &member_loc) const { const auto &member_id = member_loc.first; if (!member_id) { @@ -462,29 +467,47 @@ namespace db0::object_model auto member = tryGetMember(member_id.primary().first); if (member) { - return tryGetFieldAccess(account_id, *member); + return tryGetFieldAccessByMember(account_id, *member); } return {}; } std::optional Class::tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const + { + auto member = tryGetMember(field_name); + if (member) { + return tryGetFieldAccessByMember(account_id, *member); + } + + return tryGetFieldAccessByName(account_id, field_name); + } + + std::optional Class::tryGetFieldAccessByName(std::uint64_t account_id, const char *field_name) const { if (!isProtectFields()) { return {}; } - auto &field_safe = getFieldSafe(); - auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(field_name); - if (!maybe_offset) { - return {}; + if (hasFieldSafe()) { + auto &field_safe = getFieldSafe(); + auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id); + if (field_mask) { + auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(field_name); + if (maybe_offset) { + auto mask = field_mask->getAssignedMask(*maybe_offset); + if (mask && !mask->none()) { + return mask; + } + } + } + } - auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id); - if (!field_mask) { - return {}; + if (m_base_class_ptr && m_base_class_ptr->isProtectFields()) { + return m_base_class_ptr->tryGetFieldAccessByName(account_id, field_name); } - return field_mask->getAssignedMask(*maybe_offset); + return {}; } std::vector > Class::getFieldAccess(std::uint64_t account_id) const diff --git a/src/dbzero/object_model/class/Class.hpp b/src/dbzero/object_model/class/Class.hpp index 27bc5212..696620db 100644 --- a/src/dbzero/object_model/class/Class.hpp +++ b/src/dbzero/object_model/class/Class.hpp @@ -184,9 +184,10 @@ DB0_PACKED_END const FieldSafe &getFieldSafe() const; void setFieldAccess(const std::vector &account_ids, FieldMaskFlags mask, const std::vector &field_names); - std::optional tryGetFieldAccess(std::uint64_t account_id, const Member &) const; - std::optional tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &) const; + std::optional tryGetFieldAccessByMember(std::uint64_t account_id, const Member &) const; + std::optional tryGetFieldAccessByMemberLoc(std::uint64_t account_id, const MemberLoc &) const; std::optional tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const; + std::optional tryGetFieldAccessByName(std::uint64_t account_id, const char *field_name) const; std::vector > getFieldAccess(std::uint64_t account_id) const; std::uint32_t getFieldOffsetRange() const;