Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions dbzero/dbzero/dbzero.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,9 @@ def rename_field(class_obj: type, from_name: str, to_name: str) -> None:
def set_field_access(class_obj: type, account_id: Union[int, Sequence[int]], mode: Tuple[EnumValue, ...], *fields: str) -> None:
"""Set protected-field access flags for one or more fields of a memo class.

The memo class must be declared with ``protect_fields=True``. Pass an empty
``mode`` tuple to clear all access flags for the specified account and fields.
The memo class must be declared with ``protect_fields=True`` or inherit it
from a protected memo base. Pass an empty ``mode`` tuple to clear all access
flags for the specified account and fields.
"""
...

Expand All @@ -622,8 +623,9 @@ def get_field_access(class_obj: type, account_id: int) -> Iterable[Tuple[str, Tu
def reset_protect_fields(class_obj: type) -> None:
"""Clear the persisted protect_fields flag for a memo class.

The memo type must no longer be decorated with ``protect_fields=True``.
Remove the argument or set it to ``False`` before calling this function.
The memo type must no longer be decorated with ``protect_fields=True`` and
must not inherit protected fields from a protected memo base. Remove the
argument or set it to ``False`` before calling this function.
"""
...

Expand Down
1 change: 1 addition & 0 deletions dbzero/dbzero/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def memo(cls: Optional[type] = None, **kwargs) -> type:
If True, the persistent class is marked for field protection. Once the class is
materialized, removing this argument from the Python definition does not clear
the persisted flag; use reset_protect_fields on the dbzero Class object instead.
Derived memo classes inherit field protection and cannot disable it.

Returns
-------
Expand Down
223 changes: 223 additions & 0 deletions python_tests/test_memo_protect_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def __init__(self, count):
setattr(self, f"init_field_{i}", values[i % len(values)](i))


@db0.memo
@dataclass
class MemoUnprotectedBaseFieldsClass:
base_value: str


@db0.memo(protect_fields=True)
@dataclass
class MemoProtectedDerivedFieldsClass(MemoUnprotectedBaseFieldsClass):
name: str
value: int


@db0.memo
@dataclass
class MemoImplicitlyProtectedDerivedFieldsClass(MemoProtectedDerivedFieldsClass):
derived_value: float


def get_memo_class_object(obj):
return db0.get_memo_class(obj).get_class()

Expand Down Expand Up @@ -242,6 +261,60 @@ def test_set_field_access_requires_protected_class(db0_fixture):
assert False, "set_field_access should fail for unprotected classes"


def test_protect_fields_is_inherited_by_derived_memo_classes(db0_fixture):
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)

assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
db0.set_field_access(MemoImplicitlyProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "derived_value")


def test_protect_fields_cannot_be_disabled_in_derived_memo_class(db0_fixture):
with pytest.raises(RuntimeError, match="protect_fields.*base"):
@db0.memo(protect_fields=False)
@dataclass
class ExplicitlyUnprotectedDerived(MemoProtectedDerivedFieldsClass):
extra_value: str


def test_set_field_access_still_rejects_unprotected_base_of_protected_class(db0_fixture):
MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)

with pytest.raises(RuntimeError, match="protected fields"):
db0.set_field_access(MemoUnprotectedBaseFieldsClass, 123, (FieldAccess.READ,), "base_value")


def test_reset_protect_fields_rejects_inherited_protection(db0_fixture):
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
memo_class = get_memo_class_object(obj)

with pytest.raises(RuntimeError, match="inherits.*protect_fields"):
db0.reset_protect_fields(MemoImplicitlyProtectedDerivedFieldsClass)

assert memo_class.get_type_flags()["protect_fields"] is True


def test_class_reset_protect_fields_rejects_inherited_protection(db0_fixture):
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
memo_class = get_memo_class_object(obj)

with pytest.raises(RuntimeError, match="inherits.*protect_fields"):
memo_class.reset_protect_fields()

assert memo_class.get_type_flags()["protect_fields"] is True


def test_explicit_true_is_allowed_on_class_derived_from_protected_base(db0_fixture):
@db0.memo(protect_fields=True)
@dataclass
class ExplicitlyProtectedDerived(MemoProtectedDerivedFieldsClass):
extra_value: str

obj = ExplicitlyProtectedDerived("base", "alpha", 1, "extra")

assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
db0.set_field_access(ExplicitlyProtectedDerived, 123, (FieldAccess.READ,), "extra_value")


def test_set_field_access_accepts_unknown_field_names(db0_fixture):
MemoProtectedFieldsClass("alpha", 1)

Expand Down Expand Up @@ -376,6 +449,156 @@ def test_protected_field_getter_requires_read_access(db0_fixture):
_ = obj.value


def test_protected_field_getter_does_not_inherit_base_class_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
db0.set_field_access(MemoProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "base_value", "name")
db0.set_field_access(MemoImplicitlyProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "derived_value")
db0._init_data_masking(account_id)
account_id.set(123)

assert obj.derived_value == 1.5
with pytest.raises(PermissionError, match="read"):
_ = obj.base_value
with pytest.raises(PermissionError, match="read"):
_ = obj.name
with pytest.raises(PermissionError, match="read"):
_ = obj.value


def test_protected_field_getter_uses_derived_class_own_masks(db0_fixture):
account_id = ContextVar("protected_inheritance_allow_override_account_id")
obj = MemoImplicitlyProtectedDerivedFieldsClass("base", "alpha", 1, 1.5)
db0.set_field_access(MemoProtectedDerivedFieldsClass, 123, (), "name")
db0.set_field_access(MemoImplicitlyProtectedDerivedFieldsClass, 123, (FieldAccess.READ,), "name")
db0._init_data_masking(account_id)
account_id.set(123)

assert obj.name == "alpha"
with pytest.raises(PermissionError, match="read"):
_ = obj.value


def test_inherited_protect_fields_survives_reopen_without_parameter(db0_fixture):
@db0.memo(id="dbzero-software/dbzero/tests/protected-inherited-reopen-base")
@dataclass
class ReopenBase:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-inherited-reopen-mid", protect_fields=True)
@dataclass
class ReopenMid(ReopenBase):
name: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-inherited-reopen-derived")
@dataclass
class ReopenDerived(ReopenMid):
value: int

obj = ReopenDerived("base", "alpha", 1)
obj_id = db0.uuid(obj)
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
db0.commit()

db0.close()
db0.init(DB0_DIR)
db0.open("my-test-prefix")

@db0.memo(id="dbzero-software/dbzero/tests/protected-inherited-reopen-base")
@dataclass
class ReopenBaseAfter:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-inherited-reopen-mid")
@dataclass
class ReopenMidAfter(ReopenBaseAfter):
name: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-inherited-reopen-derived")
@dataclass
class ReopenDerivedAfter(ReopenMidAfter):
value: int

obj = db0.fetch(ReopenDerivedAfter, obj_id)
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
db0.set_field_access(ReopenDerivedAfter, 123, (FieldAccess.READ,), "value")


def test_derived_class_inherits_protect_fields_enabled_on_base_after_materialization(db0_fixture):
@db0.memo(id="dbzero-software/dbzero/tests/protected-base-enabled-later-base")
@dataclass
class BaseBefore:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-enabled-later-derived")
@dataclass
class DerivedBefore(BaseBefore):
value: int

obj = DerivedBefore("base", 1)
obj_id = db0.uuid(obj)
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is False
with pytest.raises(RuntimeError, match="protected fields"):
db0.set_field_access(DerivedBefore, 123, (FieldAccess.READ,), "value")
db0.commit()

db0.close()
db0.init(DB0_DIR)
db0.open("my-test-prefix")

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-enabled-later-base", protect_fields=True)
@dataclass
class BaseAfter:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-enabled-later-derived")
@dataclass
class DerivedAfter(BaseAfter):
value: int

obj = db0.fetch(DerivedAfter, obj_id)
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
db0.set_field_access(DerivedAfter, 123, (FieldAccess.READ,), "value")


def test_derived_class_stops_inheriting_protect_fields_after_base_reset(db0_fixture):
@db0.memo(id="dbzero-software/dbzero/tests/protected-base-reset-base", protect_fields=True)
@dataclass
class BaseBefore:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-reset-derived")
@dataclass
class DerivedBefore(BaseBefore):
value: int

obj = DerivedBefore("base", 1)
obj_id = db0.uuid(obj)
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
db0.commit()

db0.close()
db0.init(DB0_DIR)
db0.open("my-test-prefix")

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-reset-base", protect_fields=False)
@dataclass
class BaseAfter:
base_value: str

@db0.memo(id="dbzero-software/dbzero/tests/protected-base-reset-derived")
@dataclass
class DerivedAfter(BaseAfter):
value: int

obj = db0.fetch(DerivedAfter, obj_id)
db0.reset_protect_fields(BaseAfter)

assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is False
with pytest.raises(RuntimeError, match="protected fields"):
db0.set_field_access(DerivedAfter, 123, (FieldAccess.READ,), "value")


def test_protected_field_getter_returns_missing_value_placeholder(db0_fixture):
account_id = ContextVar("protected_placeholder_account_id")
missing_value = object()
Expand Down
19 changes: 16 additions & 3 deletions src/dbzero/bindings/python/Memo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Migration.hpp"
#include "PyHash.hpp"
#include "DataMasking.hpp"
#include <optional>
#include <dbzero/object_model/object.hpp>
#include <dbzero/object_model/class.hpp>
#include <dbzero/object_model/class/FieldMask.hpp>
Expand Down Expand Up @@ -902,7 +903,7 @@ namespace db0::python

PyObject *wrapPyType(PyTypeObject *base_class, bool is_singleton, bool no_default_tags, const char *prefix_name,
const char *type_id, const char *file_name, std::vector<std::string> &&init_vars, PyObject *py_dyn_prefix_callable,
std::vector<Migration> &&migrations, bool no_cache, bool immutable, bool protect_fields)
std::vector<Migration> &&migrations, bool no_cache, bool immutable, std::optional<bool> protect_fields_option)
{
auto py_class = Py_BORROW(base_class);
auto py_module = Py_OWN(findModule(*Py_OWN(PyObject_GetAttrString((PyObject*)*py_class, "__module__"))));
Expand Down Expand Up @@ -936,6 +937,15 @@ namespace db0::python
return nullptr;
}

auto base_memo_type = PyToolkit::getBaseMemoType(*new_type);
bool inherited_protect_fields = base_memo_type
&& MemoTypeDecoration::get(base_memo_type).getFlags()[MemoOptions::PROTECT_FIELDS];
if (inherited_protect_fields && protect_fields_option == false) {
THROWF(db0::InputException)
<< "Cannot set protect_fields=False on a class derived from a protect_fields base class";
}
bool protect_fields = protect_fields_option.value_or(false);

MemoFlags type_flags = no_default_tags ? MemoFlags { MemoOptions::NO_DEFAULT_TAGS } : MemoFlags();
if (no_cache) {
type_flags.set(MemoOptions::NO_CACHE);
Expand Down Expand Up @@ -1000,7 +1010,10 @@ namespace db0::python
bool no_default_tags = py_no_default_tags && PyObject_IsTrue(py_no_default_tags);
bool no_cache = py_no_cache && PyObject_IsTrue(py_no_cache);
bool immutable = py_immutable && PyObject_IsTrue(py_immutable);
bool protect_fields = py_protect_fields && PyObject_IsTrue(py_protect_fields);
std::optional<bool> protect_fields_option;
if (py_protect_fields) {
protect_fields_option = PyObject_IsTrue(py_protect_fields);
}
const char *prefix_name = (py_prefix_name && py_prefix_name != Py_None) ? PyUnicode_AsUTF8(py_prefix_name) : nullptr;
const char *type_id = py_type_id ? PyUnicode_AsUTF8(py_type_id) : nullptr;
const char *file_name = (py_file_name && py_file_name != Py_None) ? PyUnicode_AsUTF8(py_file_name) : nullptr;
Expand Down Expand Up @@ -1035,7 +1048,7 @@ namespace db0::python

auto migrations = extractMigrations(py_migrations);
return wrapPyType(castToType(class_obj), is_singleton, no_default_tags, prefix_name, type_id, file_name,
std::move(init_vars), py_dyn_prefix, std::move(migrations), no_cache, immutable, protect_fields
std::move(init_vars), py_dyn_prefix, std::move(migrations), no_cache, immutable, protect_fields_option
);
}

Expand Down
22 changes: 14 additions & 8 deletions src/dbzero/bindings/python/PyAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ namespace db0::python
auto fixture_uuid = MemoTypeDecoration::get(reinterpret_cast<PyTypeObject*>(py_type)).getFixtureUUID();
auto fixture = PyToolkit::getPyWorkspace().getWorkspace().getFixture(fixture_uuid, AccessType::READ_WRITE);
auto &class_factory = fixture->get<ClassFactory>();
auto type = class_factory.getExistingType(reinterpret_cast<PyTypeObject*>(py_type));
auto type = class_factory.getOrCreateType(reinterpret_cast<PyTypeObject*>(py_type));
if (!type->isProtectFields()) {
THROWF(db0::InputException) << "Class " << type->getName() << " does not have protected fields enabled";
}
Expand Down Expand Up @@ -979,7 +979,7 @@ namespace db0::python
auto fixture_uuid = MemoTypeDecoration::get(reinterpret_cast<PyTypeObject*>(py_type)).getFixtureUUID();
auto fixture = PyToolkit::getPyWorkspace().getWorkspace().getFixture(fixture_uuid, AccessType::READ_WRITE);
auto &class_factory = fixture->get<ClassFactory>();
auto type = class_factory.getExistingType(reinterpret_cast<PyTypeObject*>(py_type));
auto type = class_factory.getOrCreateType(reinterpret_cast<PyTypeObject*>(py_type));
if (!type->isProtectFields()) {
THROWF(db0::InputException) << "Class " << type->getName() << " does not have protected fields enabled";
}
Expand Down Expand Up @@ -1019,16 +1019,22 @@ namespace db0::python

auto memo_type = reinterpret_cast<PyTypeObject*>(py_type);
auto &decor = MemoTypeDecoration::get(memo_type);
if (decor.getFlags()[MemoOptions::PROTECT_FIELDS]) {
THROWF(db0::InputException)
<< "Type is still decorated with protect_fields=True; remove it or set protect_fields=False first";
}

using ClassFactory = db0::object_model::ClassFactory;
auto fixture_uuid = decor.getFixtureUUID(AccessType::READ_WRITE);
auto fixture = PyToolkit::getPyWorkspace().getWorkspace().getFixture(fixture_uuid, AccessType::READ_WRITE);
auto &class_factory = fixture->get<ClassFactory>();
auto type = class_factory.getExistingType(memo_type);
auto type = class_factory.getOrCreateType(memo_type);

if (type->getBaseClassPtr() && type->getBaseClassPtr()->isProtectFields()) {
THROWF(db0::InputException)
<< "Cannot disable protected fields on class " << type->getName()
<< " because it inherits from a protect_fields base class";
}

if (decor.getFlags()[MemoOptions::PROTECT_FIELDS]) {
THROWF(db0::InputException)
<< "Type is still decorated with protect_fields=True; remove it or set protect_fields=False first";
}

db0::FixtureLock lock(fixture);
type->resetProtectFields();
Expand Down
4 changes: 2 additions & 2 deletions src/dbzero/bindings/python/PyInternalAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ namespace db0::python
// validate type if requested (no validation for MemoBase)
if (py_expected_type && !PyToolkit::getTypeManager().isMemoBase(py_expected_type)) {
// in other cases the type must match the actual object type
auto expected_class = class_factory.getExistingType(py_expected_type);
auto expected_class = class_factory.getOrCreateType(py_expected_type);
// honor class-specific access flags (e.g. type-level no_cache)
auto result = PyToolkit::unloadObject(fixture, addr, class_factory, nullptr, addr.getInstanceId(),
expected_class->getInstanceFlags()
Expand Down Expand Up @@ -1100,4 +1100,4 @@ namespace db0::python
template PyObject *getMaterializedMemoObject(MemoObject *);
template PyObject *getMaterializedMemoObject(MemoImmutableObject *);

}
}
Loading
Loading