Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions python_tests/test_memo_protect_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,71 @@ def test_protected_field_getter_requires_initialized_data_masking(db0_fixture):
_ = obj.name


def test_protected_field_create_from_init_requires_create_access(db0_fixture):
account_id = ContextVar("protected_create_init_account_id")
MemoProtectedFieldsClass("materializer", 0)
db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.CREATE,), "name")
db0.set_field_access(
MemoProtectedFieldsClass,
202,
(FieldAccess.CREATE,),
"name",
"value",
)
db0._init_data_masking(account_id)

account_id.set(101)
with pytest.raises(PermissionError, match="create"):
MemoProtectedFieldsClass("alpha", 1)

account_id.set(202)
MemoProtectedFieldsClass("beta", 2)


def test_protected_field_create_from_init_uses_predeclared_name_masks(db0_fixture):
account_id = ContextVar("protected_create_predeclared_account_id")

@db0.memo(protect_fields=True)
class InitFutureField:
def __init__(self, include_fields=False):
if include_fields:
self.future = "allowed"
self.denied = "blocked"

InitFutureField()
db0.set_field_access(InitFutureField, 101, (FieldAccess.CREATE,), "future")
db0._init_data_masking(account_id)

account_id.set(101)
with pytest.raises(PermissionError, match="create"):
InitFutureField(include_fields=True)


def test_protected_field_append_requires_create_access(db0_fixture):
account_id = ContextVar("protected_create_append_account_id")
obj = MemoProtectedFieldsClass("alpha", 1)
db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.UPDATE,), "extra")
db0.set_field_access(MemoProtectedFieldsClass, 202, (FieldAccess.CREATE,), "extra")
db0._init_data_masking(account_id)

account_id.set(101)
with pytest.raises(PermissionError, match="create"):
obj.extra = "denied"

account_id.set(202)
obj.extra = "allowed"


def test_protected_field_update_does_not_require_create_access(db0_fixture):
account_id = ContextVar("protected_update_existing_account_id")
obj = MemoProtectedFieldsClass("alpha", 1)
db0.set_field_access(MemoProtectedFieldsClass, 101, (FieldAccess.UPDATE,), "name")
db0._init_data_masking(account_id)

account_id.set(101)
obj.name = "beta"


def test_protected_field_getter_requires_read_access(db0_fixture):
account_id = ContextVar("protected_read_account_id")
obj = MemoProtectedFieldsClass("alpha", 1)
Expand Down
110 changes: 90 additions & 20 deletions src/dbzero/bindings/python/Memo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,56 +361,97 @@ namespace db0::python
}

template <typename MemoImplT>
PyObject *checkProtectedFieldReadAccess(MemoImplT *memo_obj, const db0::object_model::MemberLoc &member_loc)
bool getProtectedFieldAccessContext(MemoImplT *memo_obj, const db0::object_model::Class *&type,
std::shared_ptr<DataMaskingState> &masking_state, long long &account_id)
{
auto &type = memo_obj->ext().getType();
if (!type.isProtectFields()) {
return nullptr;
type = &memo_obj->ext().getType();
if (!type->isProtectFields()) {
return false;
}

if (!Settings::m_data_masking_enabled) {
PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields");
return nullptr;
return false;
}

auto masking_state = memo_obj->ext().getFixture()->getMaskingState();
masking_state = memo_obj->ext().getFixture()->getMaskingState();
if (!masking_state) {
PyErr_SetString(PyExc_RuntimeError, "data masking is not initialized for protected fields");
return nullptr;
return false;
}

PyObject *pyAccountId = nullptr;
if (PyContextVar_Get(masking_state->contextVar, NULL, &pyAccountId) < 0) {
PyErr_SetString(PyExc_RuntimeError, "unable to read data masking account context");
return nullptr;
return false;
}
if (!pyAccountId) {
PyErr_SetString(PyExc_RuntimeError, "data masking account context is not set");
return nullptr;
return false;
}

auto accountId = PyLong_AsLongLong(pyAccountId);
account_id = PyLong_AsLongLong(pyAccountId);
Py_DECREF(pyAccountId);
if (PyErr_Occurred()) {
PyErr_SetString(PyExc_TypeError, "data masking account context must be an int");
return nullptr;
return false;
}

bool canRead = false;
if (accountId < -2) {
if (account_id < -2) {
PyErr_SetString(PyExc_RuntimeError, "invalid data masking account id");
return nullptr;
} else if (accountId == -1 || accountId == -2) {
canRead = masking_state->mode == DataMaskingMode::DEBUG;
} else if (accountId >= 0) {
auto mask = type.tryGetFieldAccess(static_cast<std::uint64_t>(accountId), member_loc);
canRead = mask && (*mask)[db0::object_model::FieldMaskOptions::READ];
return false;
}

return true;
}

template <typename MemoImplT>
bool checkProtectedFieldAccess(MemoImplT *memo_obj, db0::object_model::FieldMaskOptions access_option,
const db0::object_model::MemberLoc &member_loc, const char *field_name = nullptr)
{
auto &memo_type = memo_obj->ext().getType();
if (!memo_type.isProtectFields()) {
return true;
}

const db0::object_model::Class *type = nullptr;
std::shared_ptr<DataMaskingState> masking_state;
long long account_id = 0;
if (!getProtectedFieldAccessContext(memo_obj, type, masking_state, account_id)) {
return !PyErr_Occurred();
}

if (masking_state->mode == DataMaskingMode::DEBUG) {
if (account_id == -2) {
return true;
}
if (account_id == -1) {
return access_option == db0::object_model::FieldMaskOptions::READ;
}
}

if (canRead) {
if (account_id < 0) {
return false;
}

auto mask = type->tryGetFieldAccess(static_cast<std::uint64_t>(account_id), member_loc);
Comment thread
wskozlowski marked this conversation as resolved.
if (!mask && field_name) {
mask = type->tryGetFieldAccess(static_cast<std::uint64_t>(account_id), field_name);
}
return mask && (*mask)[access_option];
}

template <typename MemoImplT>
PyObject *checkProtectedFieldReadAccess(MemoImplT *memo_obj, const db0::object_model::MemberLoc &member_loc)
{
if (checkProtectedFieldAccess(memo_obj, db0::object_model::FieldMaskOptions::READ, member_loc)) {
return nullptr;
}
if (PyErr_Occurred()) {
return nullptr;
}

auto masking_state = memo_obj->ext().getFixture()->getMaskingState();
if (masking_state->hasMissingValuePlaceholder) {
Py_INCREF(masking_state->missingValuePlaceholder);
return masking_state->missingValuePlaceholder;
Expand All @@ -420,6 +461,28 @@ namespace db0::python
return nullptr;
}

template <typename MemoImplT>
bool checkProtectedFieldCreateAccess(MemoImplT *memo_obj, const char *field_name)
Comment thread
adrian-zawadzki marked this conversation as resolved.
{
auto &memo_type = memo_obj->ext().getType();
if (!memo_type.isProtectFields() || !Settings::m_data_masking_enabled) {
return true;
}

auto member_loc = memo_type.findField(field_name);
if (checkProtectedFieldAccess(
memo_obj, db0::object_model::FieldMaskOptions::CREATE, member_loc, field_name
)) {
return true;
}
if (PyErr_Occurred()) {
return false;
}

PyErr_SetString(PyExc_PermissionError, "data masking denies create access to protected field");
return false;
}

template <typename MemoImplT>
PyObject *tryMemoObject_getattro(MemoImplT *memo_obj, PyObject *attr)
{
Expand Down Expand Up @@ -496,9 +559,16 @@ namespace db0::python
auto maybe_type_id = PyToolkit::getTypeManager().tryGetTypeId(value);
if (maybe_type_id) {
if (self->ext().hasInstance()) {
auto member_loc = self->ext().findField(attr_name);
if (!self->ext().tryGet(member_loc) && !checkProtectedFieldCreateAccess(self, attr_name)) {
return -1;
}
db0::FixtureLock lock(self->ext().getFixture());
self->modifyExt().set(lock, attr_name, *maybe_type_id, value);
} else {
if (!checkProtectedFieldCreateAccess(self, attr_name)) {
return -1;
}
// considered as a non-mutating operation
self->ext().setPreInit(attr_name, *maybe_type_id, value);
}
Expand Down
20 changes: 20 additions & 0 deletions src/dbzero/object_model/class/Class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,26 @@ namespace db0::object_model
return {};
}

std::optional<FieldMaskFlags> Class::tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const
{
if (!isProtectFields()) {
return {};
}

auto &field_safe = getFieldSafe();
auto maybe_offset = field_safe.getFieldIDMapper().tryGetAssignedFieldOffset(field_name);
if (!maybe_offset) {
return {};
}

auto field_mask = field_safe.getFieldMaskManager().tryGetFieldMask(account_id);
if (!field_mask) {
return {};
}

return field_mask->getAssignedMask(*maybe_offset);
}

std::vector<std::pair<std::string, FieldMaskFlags> > Class::getFieldAccess(std::uint64_t account_id) const
{
if (!isProtectFields()) {
Expand Down
1 change: 1 addition & 0 deletions src/dbzero/object_model/class/Class.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ DB0_PACKED_END
const std::vector<std::string> &field_names);
std::optional<FieldMaskFlags> tryGetFieldAccess(std::uint64_t account_id, const Member &) const;
std::optional<FieldMaskFlags> tryGetFieldAccess(std::uint64_t account_id, const MemberLoc &) const;
std::optional<FieldMaskFlags> tryGetFieldAccess(std::uint64_t account_id, const char *field_name) const;
std::vector<std::pair<std::string, FieldMaskFlags> > getFieldAccess(std::uint64_t account_id) const;
std::uint32_t getFieldOffsetRange() const;

Expand Down
Loading