Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dbzero/dbzero/dbzero.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,14 @@ def get_field_access(class_obj: type, account_id: int) -> Iterable[Tuple[str, Tu
"""Return protected-field access flags for a memo class and account."""
...

def reset_protect_fields(class_obj: type) -> None:
    """Clear the persisted protect_fields flag for a memo class.

    The memo type must no longer be decorated with ``protect_fields=True``.
    Remove the argument or set it to ``False`` before calling this function.

    Args:
        class_obj: The memo class whose persisted flag should be cleared.

    Raises:
        RuntimeError: If ``class_obj`` is still decorated with
            ``protect_fields=True``.
    """
    ...

def _init_data_masking(
context_var: Any,
prefix: Union[str, Any, Sequence[Any], None] = None,
Expand Down
260 changes: 257 additions & 3 deletions python_tests/test_memo_protect_fields.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# SPDX-License-Identifier: LGPL-2.1-or-later
# Copyright (c) 2025 DBZero Software sp. z o.o.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from contextvars import ContextVar

import dbzero as db0
import pytest

from .conftest import DB0_DIR

Expand All @@ -20,6 +24,34 @@ class MemoProtectedFieldsClass:
value: int


# Protected memo dataclass whose fields span several value kinds; the
# mixed-type masking tests populate it with None, bool, int, str, bytes and
# tuple values, in that declaration order.
@db0.memo(protect_fields=True)
@dataclass
class MemoProtectedComplexFieldsClass:
    none_value: object
    flag: bool
    count: int
    name: str
    payload: bytes
    pair: tuple


# Protected memo dataclass with one required field and one field that
# carries a dataclass default value.
@db0.memo(protect_fields=True)
@dataclass
class MemoProtectedDefaultFieldsClass:
    name: str
    generated_value: str = "auto-generated"


# Protected memo class (not a dataclass) with a class-level default for
# ``generated_value`` that is only overridden on the instance when the
# constructor is asked to do so.
@db0.memo(protect_fields=True)
class MemoProtectedAutoGeneratedInitVarClass:
    # Class-level fallback; instances built without ``include_generated``
    # never assign their own ``generated_value``.
    generated_value = "auto-generated"

    def __init__(self, name, include_generated=False):
        self.name = name
        # Only write an instance-level value when explicitly requested.
        if include_generated:
            self.generated_value = "persisted"


@db0.memo
@dataclass
class MemoUnprotectedFieldsClass:
Expand Down Expand Up @@ -145,12 +177,38 @@ class ProtectedAfter:
assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True


def test_reset_protect_fields_clears_persisted_flag(db0_fixture):
obj = MemoProtectedFieldsClass("alpha", 1)
def test_reset_protect_fields_rejects_type_still_decorated_as_protected(db0_fixture):
    """Resetting must fail while the class is still declared with ``protect_fields=True``."""
    # An instance is created before the reset is attempted.
    MemoProtectedFieldsClass("alpha", 1)

    expected_failure = pytest.raises(RuntimeError, match="decorated.*protect_fields")
    with expected_failure:
        db0.reset_protect_fields(MemoProtectedFieldsClass)


def test_reset_protect_fields_clears_persisted_flag_after_decorator_removed(db0_fixture):
    """Check that the persisted flag can be cleared once the decorator no longer requests protection.

    The same memo id is declared twice: first with ``protect_fields=True``,
    then, after closing and reopening the store, with ``protect_fields=False``.
    """
    # Phase 1: declare the type as protected, create an object and commit it.
    @db0.memo(id="dbzero-software/dbzero/tests/protected-reset", protect_fields=True)
    @dataclass
    class ProtectedBefore:
        name: str

    obj = ProtectedBefore("alpha")
    obj_id = db0.uuid(obj)
    assert get_memo_class_object(obj).get_type_flags()["protect_fields"] is True
    db0.commit()

    # Close and reopen before redeclaring the type under the same id.
    db0.close()
    db0.init(DB0_DIR)
    db0.open("my-test-prefix")

    # Phase 2: same memo id, now declared without field protection.
    @db0.memo(id="dbzero-software/dbzero/tests/protected-reset", protect_fields=False)
    @dataclass
    class ProtectedAfter:
        name: str

    obj = db0.fetch(ProtectedAfter, obj_id)
    memo_class = get_memo_class_object(obj)
    # The flag reported for the fetched object is still True at this point.
    assert memo_class.get_type_flags()["protect_fields"] is True

    # NOTE(review): this direct call on the memo class object runs before the
    # module-level call below; it looks like the pre-change line left behind by
    # the diff view — confirm whether both calls are intended.
    memo_class.reset_protect_fields()
    assert db0.reset_protect_fields(ProtectedAfter) is None
    assert memo_class.get_type_flags()["protect_fields"] is False


Expand Down Expand Up @@ -296,3 +354,199 @@ def test_describe_field_offset_range_counts_predeclared_access_fields(db0_fixtur
materialized_range = db0.describe(obj)["field_offset_range"]
assert materialized_range >= predeclared_range
assert materialized_range < 120 * 2 + 128


def test_protected_field_getter_requires_initialized_data_masking(db0_fixture):
    """Reading a protected field without data masking set up raises RuntimeError."""
    record = MemoProtectedFieldsClass("alpha", 1)

    with pytest.raises(RuntimeError, match="data masking"):
        getattr(record, "name")


def test_protected_field_getter_requires_read_access(db0_fixture):
    """An account can read a field it was granted READ on and is refused on the others."""
    current_account = ContextVar("protected_read_account_id")
    record = MemoProtectedFieldsClass("alpha", 1)
    db0.set_field_access(MemoProtectedFieldsClass, 123, (FieldAccess.READ,), "name")
    db0._init_data_masking(current_account)
    current_account.set(123)

    # "name" was granted to account 123.
    granted_value = record.name
    assert granted_value == "alpha"

    # "value" was never granted, so the getter must refuse.
    with pytest.raises(PermissionError, match="read"):
        getattr(record, "value")


def test_protected_field_getter_returns_missing_value_placeholder(db0_fixture):
    """With a placeholder configured, an unreadable field yields that exact object."""
    current_account = ContextVar("protected_placeholder_account_id")
    placeholder = object()
    record = MemoProtectedFieldsClass("alpha", 1)
    db0.set_field_access(MemoProtectedFieldsClass, 123, (FieldAccess.READ,), "name")
    db0._init_data_masking(current_account, missing_value_placeholder=placeholder)
    current_account.set(123)

    # Account 123 has no READ grant on "value"; identity (not equality) is checked.
    observed = record.value
    assert observed is placeholder


def test_protected_field_getter_allows_debug_super_account_read(db0_fixture):
    """In DEBUG mode account id -1 reads every field without explicit grants."""
    current_account = ContextVar("protected_super_account_id")
    record = MemoProtectedFieldsClass("alpha", 1)
    db0._init_data_masking(current_account, mode="DEBUG")
    current_account.set(-1)

    # No set_field_access call was made for this account.
    observed = (record.name, record.value)
    assert observed == ("alpha", 1)


def test_protected_field_getter_rejects_invalid_negative_account_id(db0_fixture):
    """A negative account id of -3 is rejected with RuntimeError, even in DEBUG mode."""
    current_account = ContextVar("protected_invalid_account_id")
    record = MemoProtectedFieldsClass("alpha", 1)
    db0._init_data_masking(current_account, mode="DEBUG")
    current_account.set(-3)

    invalid_account = pytest.raises(RuntimeError, match="invalid.*account")
    with invalid_account:
        getattr(record, "name")


def test_protected_field_getter_enforces_read_access_for_auto_generated_fields(db0_fixture):
    """READ grants are enforced for a field the instance never assigned itself."""
    current_account = ContextVar("protected_auto_generated_account_id")
    # Presumably this first instance makes ``generated_value`` a persisted
    # field of the type (it is built with include_generated=True) — confirm.
    MemoProtectedAutoGeneratedInitVarClass("materializer", include_generated=True)
    record = MemoProtectedAutoGeneratedInitVarClass("alpha")

    # Each account may read exactly one of the two fields.
    for account, readable_field in ((101, "name"), (202, "generated_value")):
        db0.set_field_access(
            MemoProtectedAutoGeneratedInitVarClass,
            account,
            (FieldAccess.READ,),
            readable_field,
        )
    db0._init_data_masking(current_account)

    def assert_read_denied(field_name):
        with pytest.raises(PermissionError, match="read"):
            getattr(record, field_name)

    current_account.set(101)
    assert record.name == "alpha"
    assert_read_denied("generated_value")

    current_account.set(202)
    # The instance was built without include_generated, so the class-level
    # default is what comes back.
    assert record.generated_value == "auto-generated"
    assert_read_denied("name")


def test_protected_field_getter_applies_complex_masks_for_mixed_field_types(db0_fixture):
    """Per-account READ grants are enforced independently across mixed field types."""
    current_account = ContextVar("protected_complex_account_id")
    record = MemoProtectedComplexFieldsClass(None, False, 7, "alpha", b"payload", ("x", 11))

    # account id -> fields it may read; dict order is the registration order.
    read_grants = {
        101: ("none_value", "flag", "name"),
        202: ("count", "payload"),
        303: ("none_value", "flag", "count", "name", "payload", "pair"),
    }
    for account, readable_fields in read_grants.items():
        db0.set_field_access(
            MemoProtectedComplexFieldsClass,
            account,
            (FieldAccess.READ,),
            *readable_fields,
        )
    db0._init_data_masking(current_account)

    def assert_read_denied(*field_names):
        for denied in field_names:
            with pytest.raises(PermissionError, match="read"):
                getattr(record, denied)

    # Account 101: none_value / flag / name only.
    current_account.set(101)
    assert record.none_value is None
    assert record.flag is False
    assert record.name == "alpha"
    assert_read_denied("count", "payload")

    # Account 202: count / payload only.
    current_account.set(202)
    assert record.count == 7
    assert record.payload == b"payload"
    assert_read_denied("none_value", "flag")

    # Account 303: every field.
    current_account.set(303)
    assert record.none_value is None
    assert record.flag is False
    assert record.count == 7
    assert record.name == "alpha"
    assert record.payload == b"payload"
    assert record.pair == ("x", 11)

    # Account 404 has no grants at all.
    current_account.set(404)
    assert_read_denied("none_value", "flag", "count", "name", "payload", "pair")


@pytest.mark.asyncio
async def test_protected_field_getter_uses_contextvar_per_async_task(db0_fixture):
    """Two interleaved coroutines each see the access rights of their own account id."""
    current_account = ContextVar("protected_async_account_id")
    record = MemoProtectedFieldsClass("alpha", 1)
    for account, readable_field in ((101, "name"), (202, "value")):
        db0.set_field_access(
            MemoProtectedFieldsClass, account, (FieldAccess.READ,), readable_field
        )
    db0._init_data_masking(current_account)

    async def read_as(account, allowed_field, denied_field):
        current_account.set(account)
        # Yield between steps so the two coroutines interleave.
        await asyncio.sleep(0)
        granted = getattr(record, allowed_field)
        await asyncio.sleep(0)
        with pytest.raises(PermissionError, match="read"):
            getattr(record, denied_field)
        return granted

    results = await asyncio.gather(
        read_as(101, "name", "value"),
        read_as(202, "value", "name"),
    )

    assert results == ["alpha", 1]


def test_protected_field_getter_uses_contextvar_per_thread(db0_fixture):
    """Worker threads each see the access rights of the account id they set themselves."""
    current_account = ContextVar("protected_thread_account_id")
    record = MemoProtectedComplexFieldsClass(None, True, 7, "alpha", b"payload", ("x", 11))
    for account, readable_fields in (
        (101, ("none_value", "flag")),
        (202, ("count", "name")),
    ):
        db0.set_field_access(
            MemoProtectedComplexFieldsClass,
            account,
            (FieldAccess.READ,),
            *readable_fields,
        )
    db0._init_data_masking(current_account)

    def read_as(account):
        current_account.set(account)
        # Account 101 is checked against "count"; any other account against "flag".
        if account == 101:
            denied_field, allowed_fields = "count", ("none_value", "flag")
        else:
            denied_field, allowed_fields = "flag", ("count", "name")
        with pytest.raises(PermissionError, match="read"):
            getattr(record, denied_field)
        return tuple(getattr(record, field_name) for field_name in allowed_fields)

    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {account: pool.submit(read_as, account) for account in (101, 202)}

    assert pending[101].result() == (None, True)
    assert pending[202].result() == (7, "alpha")
48 changes: 48 additions & 0 deletions src/dbzero/bindings/python/DataMasking.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2025 DBZero Software sp. z o.o.

#pragma once

#include <Python.h>

namespace db0

{

// Operating mode of data masking. RELEASE is the default used by
// DataMaskingState.
// NOTE(review): presumably DEBUG corresponds to the Python-side mode="DEBUG"
// argument of _init_data_masking — confirm against the binding code.
enum class DataMaskingMode
{
    RELEASE,
    DEBUG
};

// Data-masking configuration: the Python object stored as the context
// variable, an optional placeholder object, a flag saying whether a
// placeholder was supplied, and the masking mode.
//
// NOTE(review): the constructor acquires strong references to the Python
// objects, but nothing in this header releases them — confirm where (or
// whether) the matching Py_DECREF happens.
struct DataMaskingState
{
    PyObject *contextVar = nullptr;
    PyObject *missingValuePlaceholder = nullptr;
    bool hasMissingValuePlaceholder = false;
    DataMaskingMode mode = DataMaskingMode::RELEASE;

    // Stores the arguments and takes a strong reference to the context
    // variable and, when non-null, to the placeholder.
    // contextVarObj must not be null (it is passed to Py_INCREF).
    DataMaskingState(PyObject *contextVarObj, PyObject *placeholderObj,
        bool hasPlaceholder, DataMaskingMode maskingMode)
        : contextVar(contextVarObj)
        , missingValuePlaceholder(placeholderObj)
        , hasMissingValuePlaceholder(hasPlaceholder)
        , mode(maskingMode)
    {
        Py_INCREF(contextVar);
        // Py_XINCREF is the null-tolerant form of Py_INCREF.
        Py_XINCREF(missingValuePlaceholder);
    }

    // Returns true when every stored setting equals the given one.
    // Python objects are compared by pointer identity, not by value.
    bool matches(PyObject *otherContextVar, PyObject *otherMissingValuePlaceholder,
        bool otherHasMissingValuePlaceholder, DataMaskingMode otherMode) const
    {
        if (contextVar != otherContextVar) {
            return false;
        }
        if (missingValuePlaceholder != otherMissingValuePlaceholder) {
            return false;
        }
        if (hasMissingValuePlaceholder != otherHasMissingValuePlaceholder) {
            return false;
        }
        return mode == otherMode;
    }
};

}
Loading
Loading