Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 138 additions & 5 deletions python_tests/test_memo_protect_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,23 +577,72 @@ def test_protected_field_getter_requires_read_access(db0_fixture):
_ = obj.value


def test_protected_field_getter_does_not_inherit_base_class_masks(db0_fixture):
def test_protected_field_getter_inherits_base_class_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
db0.set_field_access(MemoProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "base_value", "name")
db0.set_field_access(MemoImplicitlyProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "derived_value")
db0._init_data_masking(account_id)
account_id.set(123)

assert obj.base_value == "base"
assert obj.name == "alpha"
assert obj.derived_value == 1.5
with pytest.raises(PermissionError, match="read"):
_ = obj.base_value
with pytest.raises(PermissionError, match="read"):
_ = obj.name
with pytest.raises(PermissionError, match="read"):
_ = obj.value


def test_protected_field_create_inherits_base_class_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_create_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
db0.set_field_access(MemoProtectedDerivedFieldsClass, 101, (FieldAccess.READ,), "extra")
db0.set_field_access(MemoProtectedDerivedFieldsClass, 202, (FieldAccess.CREATE, FieldAccess.READ), "extra")
db0._init_data_masking(account_id)

account_id.set(101)
with pytest.raises(PermissionError, match="create"):
obj.extra = "denied"

account_id.set(202)
obj.extra = "allowed"
assert obj.extra == "allowed"


def test_protected_field_update_inherits_base_class_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_update_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
db0.set_field_access(MemoProtectedDerivedFieldsClass, 101, (FieldAccess.READ,), "name")
db0.set_field_access(MemoProtectedDerivedFieldsClass, 202, (FieldAccess.READ, FieldAccess.UPDATE), "name")
db0._init_data_masking(account_id)

account_id.set(101)
with pytest.raises(PermissionError, match="update"):
obj.name = "denied"
assert obj.name == "alpha"

account_id.set(202)
obj.name = "allowed"
assert obj.name == "allowed"


def test_protected_field_delete_inherits_base_class_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_delete_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
db0.set_field_access(MemoProtectedDerivedFieldsClass, 101, (FieldAccess.READ,), "name")
db0.set_field_access(MemoProtectedDerivedFieldsClass, 202, (FieldAccess.DELETE,), "name")
db0._init_data_masking(account_id)

account_id.set(101)
with pytest.raises(PermissionError, match="delete"):
del obj.name
assert obj.name == "alpha"

account_id.set(202)
del obj.name
with pytest.raises(AttributeError):
_ = obj.name


def test_protected_field_getter_uses_derived_class_own_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_allow_override_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
Expand Down Expand Up @@ -686,7 +735,53 @@ class DerivedAfter(BaseAfter):

obj = db0.fetch(DerivedAfter, obj_id)
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
with pytest.raises(RuntimeError, match="data masking"):
_ = obj.value
db0.set_field_access(DerivedAfter, 123, (FieldAccess.READ,), "value")
account_id = ContextVar("protected_base_enabled_later_account_id")
db0._init_data_masking(account_id)
account_id.set(123)
assert obj.value == 1


def test_existing_derived_instance_inherits_base_field_access_enabled_after_materialization(db0_fixture):
@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-base")
@dataclass
class BaseBefore:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-derived")
@dataclass
class DerivedBefore(BaseBefore):
value: int

obj = DerivedBefore("base", 1)
obj_id = db0.uuid(obj)
db0.commit()

db0.close()
db0.init(DB0_DIR)
db0.open("my-test-prefix")

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-base", protect_fields=True)
@dataclass
class BaseAfter:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-enabled-later-derived")
@dataclass
class DerivedAfter(BaseAfter):
value: int

obj = db0.fetch(DerivedAfter, obj_id)
db0.set_field_access(BaseAfter, 123, (FieldAccess.READ,), "base_value")
account_id = ContextVar("protected_base_access_enabled_later_account_id")
db0._init_data_masking(account_id)
account_id.set(123)

assert obj.base_value == "base"
with pytest.raises(PermissionError, match="read"):
_ = obj.value


def test_derived_class_stops_inheriting_protect_fields_after_base_reset(db0_fixture):
Expand Down Expand Up @@ -722,11 +817,49 @@ class DerivedAfter(BaseAfter):
obj = db0.fetch(DerivedAfter, obj_id)
db0.reset_protect_fields(BaseAfter)

assert obj.value == 1
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is False
with pytest.raises(RuntimeError, match="protected fields"):
db0.set_field_access(DerivedAfter, 123, (FieldAccess.READ,), "value")


def test_existing_derived_instance_stops_using_base_field_access_after_base_reset(db0_fixture):
@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-base", protect_fields=True)
@dataclass
class BaseBefore:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-derived")
@dataclass
class DerivedBefore(BaseBefore):
value: int

obj = DerivedBefore("base", 1)
obj_id = db0.uuid(obj)
db0.commit()

db0.close()
db0.init(DB0_DIR)
db0.open("my-test-prefix")

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-base", protect_fields=False)
@dataclass
class BaseAfter:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-access-reset-derived")
@dataclass
class DerivedAfter(BaseAfter):
value: int

obj = db0.fetch(DerivedAfter, obj_id)
db0.reset_protect_fields(BaseAfter)

assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is False
assert obj.base_value == "base"
assert obj.value == 1


def test_protected_field_getter_returns_missing_value_placeholder(db0_fixture):
account_id = ContextVar("protected_placeholder_account_id")
missing_value = object()
Expand Down
4 changes: 2 additions & 2 deletions src/dbzero/bindings/python/Memo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ namespace db0::python
return false;
}

auto mask = type->tryGetFieldAccess(static_cast<std::uint64_t>(account_id), member_loc);
auto mask = type->tryGetFieldAccessByMemberLoc(static_cast<std::uint64_t>(account_id), member_loc);
if (!mask && field_name) {
mask = type->tryGetFieldAccess(static_cast<std::uint64_t>(account_id), field_name);
mask = type->tryGetFieldAccessByName(static_cast<std::uint64_t>(account_id), field_name);
}
return mask && (*mask)[access_option];
}
Expand Down
59 changes: 41 additions & 18 deletions src/dbzero/object_model/class/Class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,27 +433,32 @@ namespace db0::object_model
}
}

std::optional<FieldMaskFlags> Class::tryGetFieldAccess(std::uint64_t account_id, const Member &member) const
std::optional<FieldMaskFlags> Class::tryGetFieldAccessByMember(std::uint64_t account_id, const Member &member) const
{
if (!isProtectFields()) {
return {};
}

if (!hasFieldSafe()) {
return {};
if (hasFieldSafe()) {
auto &field_safe = getFieldSafe();
auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id);
auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(member.m_field_id);
if (maybe_offset && field_mask) {
auto mask = field_mask->getAssignedMask(*maybe_offset);
if (mask && !mask->none()) {
return mask;
}
}
}

auto &field_safe = getFieldSafe();
auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id);
auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(member.m_field_id);
if (maybe_offset && field_mask) {
return field_mask->getAssignedMask(*maybe_offset);
if (m_base_class_ptr && m_base_class_ptr->isProtectFields()) {
return m_base_class_ptr->tryGetFieldAccess(account_id, member.m_name.c_str());
}

return {};
}

std::optional<FieldMaskFlags> Class::tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &member_loc) const
std::optional<FieldMaskFlags> Class::tryGetFieldAccessByMemberLoc(std::uint64_t account_id, const MemberLoc &member_loc) const
{
const auto &member_id = member_loc.first;
if (!member_id) {
Expand All @@ -462,29 +467,47 @@ namespace db0::object_model

auto member = tryGetMember(member_id.primary().first);
if (member) {
return tryGetFieldAccess(account_id, *member);
return tryGetFieldAccessByMember(account_id, *member);
}
return {};
}

std::optional<FieldMaskFlags> Class::tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const
{
auto member = tryGetMember(field_name);
if (member) {
return tryGetFieldAccessByMember(account_id, *member);
}

return tryGetFieldAccessByName(account_id, field_name);
}

std::optional<FieldMaskFlags> Class::tryGetFieldAccessByName(std::uint64_t account_id, const char *field_name) const
{
if (!isProtectFields()) {
return {};
}

auto &field_safe = getFieldSafe();
auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(field_name);
if (!maybe_offset) {
return {};
if (hasFieldSafe()) {
auto &field_safe = getFieldSafe();
auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id);
if (field_mask) {
auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(field_name);
if (maybe_offset) {
auto mask = field_mask->getAssignedMask(*maybe_offset);
if (mask && !mask->none()) {
return mask;
}
}
}

}

auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id);
if (!field_mask) {
return {};
if (m_base_class_ptr && m_base_class_ptr->isProtectFields()) {
return m_base_class_ptr->tryGetFieldAccessByName(account_id, field_name);
}

return field_mask->getAssignedMask(*maybe_offset);
return {};
}

std::vector<std::pair<std::string, FieldMaskFlags> > Class::getFieldAccess(std::uint64_t account_id) const
Expand Down
5 changes: 3 additions & 2 deletions src/dbzero/object_model/class/Class.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ DB0_PACKED_END
const FieldSafe &getFieldSafe() const;
void setFieldAccess(const std::vector<std::uint64_t> &account_ids, FieldMaskFlags mask,
const std::vector<std::string> &field_names);
std::optional<FieldMaskFlags> tryGetFieldAccess(std::uint64_t account_id, const Member &) const;
std::optional<FieldMaskFlags> tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &) const;
std::optional<FieldMaskFlags> tryGetFieldAccessByMember(std::uint64_t account_id, const Member &) const;
std::optional<FieldMaskFlags> tryGetFieldAccessByMemberLoc(std::uint64_t account_id, const MemberLoc &) const;
std::optional<FieldMaskFlags> tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const;
std::optional<FieldMaskFlags> tryGetFieldAccessByName(std::uint64_t account_id, const char *field_name) const;
std::vector<std::pair<std::string, FieldMaskFlags> > getFieldAccess(std::uint64_t account_id) const;
std::uint32_t getFieldOffsetRange() const;

Expand Down
Loading