Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions dbzero/dbzero/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,18 @@ def find_migrations(from_type):
if hasattr(attr, '_db0_migration'):
yield (attr, list(dis_assig(attr)))

def wrap(cls_):
# note that we use the __dyn_prefix mechanism only for singletons
def wrap(cls_):
is_singleton = kwargs.get("singleton", False)
return _wrap_memo_type(cls_, py_file = getfile(cls_), py_init_vars = list(dis_init_assig(cls_)), \
py_dyn_prefix = __dyn_prefix(cls_) if is_singleton else None, \
# note that we use the __dyn_prefix mechanism only for singletons
try:
dyn_prefix = __dyn_prefix(cls_) if is_singleton else None
init_vars = list(dis_init_assig(cls_))
except TypeError:
# unable to process the __init__ function (e.g. wrapper_descriptor)
dyn_prefix = None
init_vars = []

return _wrap_memo_type(cls_, py_file = getfile(cls_), py_init_vars = init_vars, py_dyn_prefix = dyn_prefix, \
py_migrations = list(find_migrations(cls_)) if is_singleton else None, **kwargs
)

Expand Down
17 changes: 13 additions & 4 deletions python_tests/test_memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def value(self):
def value(self, new_value):
self._value = new_value + 1


@db0.memo
class MemoDerivedClassNoInit(MemoTestClass):
pass


def test_memo_is_instance_operator(db0_fixture):
obj_1 = MemoTestClass(999)
Expand Down Expand Up @@ -210,12 +215,16 @@ def test_selective_assign_members(db0_fixture):
obj_5 = MemoAnyAttrs(f4 = False, f6 = 1, f9 = 11)
# too spread apart, only some fraction of slots to be allocated to pos-vt
assert len(db0.describe(obj_5)["field_layout"]["pos_vt"]) < 3



@pytest.mark.skip(reason="Missing feature: https://github.com/dbzero-software/dbzero/issues/682")
def test_memo_setattr(db0_fixture):
obj_1 = MemoTestClass(1)
obj_1 = MemoTestClass(1)
obj_1.__setattr__("value", 10)
assert obj_1.value == 10
obj_1.__setattr__("new_field", 20)
assert obj_1.new_field == 20
assert obj_1.new_field == 20


def test_memo_derived_no_init(db0_fixture):
obj_1 = MemoDerivedClassNoInit(123)
assert obj_1.value == 123
148 changes: 130 additions & 18 deletions src/dbzero/bindings/python/Memo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ namespace db0::python
using TypeObjectSharedPtr = PyTypes::TypeObjectSharedPtr;

// @return type name / full type name (tp_name)
std::pair<std::string, std::string> getMemoTypeName(shared_py_object<PyTypeObject*> py_class)
std::string getMemoTypeName(shared_py_object<PyTypeObject*> py_class)
{
std::stringstream str;
str << "Memo_" << (*py_class)->tp_name;
auto type_name = str.str();
auto full_type_name = std::string("dbzero.") + type_name;
return { type_name, full_type_name };
return str.str();
}

template <typename MemoImplT>
Expand Down Expand Up @@ -229,7 +227,11 @@ namespace db0::python
// They only contain dbzero native objects
return 0;
}


// use MRO to idenfity the base tp_init to be called
template <typename MemoImplT>
initproc MemoObject_getInitFunc(MemoImplT *self);

template <typename MemoImplT>
int PyAPI_MemoObject_init(MemoImplT *self, PyObject* args, PyObject* kwds)
{
Expand All @@ -240,11 +242,9 @@ namespace db0::python
PY_API_FUNC
// the instance may already exist (e.g. if this is a singleton)
if (!self->ext().hasInstance()) {
auto py_type = Py_TYPE(self);
auto base_type = py_type->tp_base;

auto init_func = MemoObject_getInitFunc<MemoImplT>(self);
// invoke tp_init from base type (wrapped pyhon class)
if (base_type->tp_init((PyObject*)self, args, kwds) < 0) {
if (init_func && init_func((PyObject*)self, args, kwds) < 0) {
// mark object as defunct
self->ext().setDefunct();
PyObject *ptype, *pvalue, *ptraceback;
Expand Down Expand Up @@ -297,6 +297,17 @@ namespace db0::python
return 0;
}

template <typename MemoImplT>
initproc MemoObject_getInitFunc(MemoImplT *self)
{
auto py_type = Py_TYPE(self);
if (!py_type) {
return nullptr;
}
// NOTE: for memo types just use the immediate base tp_init
return py_type->tp_base->tp_init;
}

void MemoObject_drop(MemoObject* memo_obj)
{
// since objects are destroyed by GC0 drop is only responsible for marking
Expand Down Expand Up @@ -418,6 +429,11 @@ namespace db0::python
return -1;
}

// Avoid infinite recursion if base class tp_setattro is the same as ours
if (py_type->tp_base->tp_setattro == (setattrofunc)PyAPI_MemoObject_setattro<MemoObject>) {
return PyObject_GenericSetAttr((PyObject*)self, attr, value);
}

// Forward to base class setattro
return py_type->tp_base->tp_setattro((PyObject*)self, attr, value);
}
Expand Down Expand Up @@ -466,10 +482,49 @@ namespace db0::python
PyObject * obj_memo = reinterpret_cast<PyObject*>(memo_obj);
// if richcompare is overriden by the python class, call the python class implementation
if (obj_memo->ob_type->tp_base->tp_richcompare != PyType_Type.tp_richcompare) {
// if the base class richcompare is the same as the memo richcompare don't call the base class richcompare
// to avoid infinite recursion
if (obj_memo->ob_type->tp_base->tp_richcompare != (richcmpfunc)PyAPI_MemoObject_rq<MemoImplT>) {
return obj_memo->ob_type->tp_base->tp_richcompare(reinterpret_cast<PyObject*>(memo_obj), other, op);
// Map the integer op to the corresponding string method name
const char* op_name = nullptr;
switch (op) {
case Py_EQ: op_name = "__eq__"; break;
case Py_NE: op_name = "__ne__"; break;
case Py_LT: op_name = "__lt__"; break;
case Py_LE: op_name = "__le__"; break;
case Py_GT: op_name = "__gt__"; break;
case Py_GE: op_name = "__ge__"; break;
}

if (op_name) {
// 1. Look up the method on the BASE type directly.
// This avoids looking at 'obj_memo' and finding the wrapper again.
PyObject* base_type = (PyObject*)obj_memo->ob_type->tp_base;
PyObject* base_impl = PyObject_GetAttrString(base_type, op_name);

// 2. Check if we found a valid method
if (base_impl) {
// Important: Ensure we didn't just grab object.__eq__ wrapper which would loop again.
// We only want to call it if the base class actually overrode it.
// (If base_impl is a wrapper_descriptor, it might trigger the slot logic again)

// However, for a generic Python class, this is usually a 'function'.
// We call it: BaseClass.__eq__(self, other)
PyObject *result = PyObject_CallFunctionObjArgs(base_impl, obj_memo, other, NULL);
Py_DECREF(base_impl);

if (result == NULL) {
return NULL; // Propagate errors raised by the python method
}

// If base class returned NotImplemented, we fall through to our C logic
if (result != Py_NotImplemented) {
return result;
}
Py_DECREF(result);
} else {
// Method doesn't exist on base (e.g. __lt__ not defined), clear error and proceed
PyErr_Clear();
}
}
}
}

Expand Down Expand Up @@ -513,6 +568,54 @@ namespace db0::python
return Py_NEW(PyDict_GetItem(*modules_dict, module_name));
}

// Merge two python dicts, take all keys from first and specific keys from second
PyObject *mergeDicts(PyObject *first, PyObject *second,
std::unordered_set<std::string> secondary_keys = {})
{
auto py_result_dict = Py_OWN(PyDict_New());
PyObject *key, *value;
Py_ssize_t pos = 0;

// copy all items (except secondary keys) from the first dict
while (PyDict_Next(first, &pos, &key, &value)) {
if (secondary_keys.find(PyUnicode_AsUTF8(key)) != secondary_keys.end()) {
continue;
}
PySafeDict_SetItem(*py_result_dict, Py_BORROW(key), Py_BORROW(value));
}

// copy selected items from the second dict
pos = 0;
while (PyDict_Next(second, &pos, &key, &value)) {
if (!secondary_keys.empty()) {
if (secondary_keys.find(PyUnicode_AsUTF8(key)) == secondary_keys.end()) {
continue;
}
}
PySafeDict_SetItem(*py_result_dict, Py_BORROW(key), Py_BORROW(value));
}

return py_result_dict.steal();
}

std::ostream &showKeys(std::ostream &out, PyObject *dict)
{
PyObject *key, *value;
Py_ssize_t pos = 0;

out << "{";
bool first = true;
while (PyDict_Next(dict, &pos, &key, &value)) {
if (!first) {
out << ", ";
}
out << PyUnicode_AsUTF8(key);
first = false;
}
out << "}";
return out;
}

// Copy a python dict
PyObject *copyDict(PyObject *dict, std::unordered_set<std::string> exclude_keys = {})
{
Expand Down Expand Up @@ -624,8 +727,11 @@ namespace db0::python
};

auto bases = Py_OWN(PySafeTuple_Pack(Py_BORROW(base_class)));
auto tp_result = Py_OWN((PyTypeObject*)PyType_FromSpecWithBases(&type_spec, *bases));
(*tp_result)->tp_dict = copyDict(base_class->tp_dict);
auto tp_result = Py_OWN((PyTypeObject*)PyType_FromSpecWithBases(&type_spec, *bases));
auto old_dict = Py_OWN((*tp_result)->tp_dict);
// NOTE: take only the auto-generated __setattr__ and __delattr__
(*tp_result)->tp_dict = mergeDicts(base_class->tp_dict, *old_dict, { "__setattr__", "__delattr__" });

// disable weak-refs (important for Python 3.11.x)
(*tp_result)->tp_weaklistoffset = 0;
#if PY_VERSION_HEX < 0x030B0000 // Python < 3.11
Expand Down Expand Up @@ -667,13 +773,13 @@ namespace db0::python
THROWF(db0::InternalException) << "Variable-length types not supported: " << (*py_class)->tp_name;
}

auto [type_name, full_type_name] = getMemoTypeName(py_class);
auto type_name = getMemoTypeName(py_class);
TypeObjectSharedPtr new_type = nullptr;

// For Python 3.10 compatibility: ensure tp_name string persists beyond this scope
// by using pooled string to avoid segfault due to tp_name pointer being copied literally
auto &type_manager = PyToolkit::getTypeManager();
const char* safe_tp_name = type_manager.getPooledString(full_type_name);
const char* safe_tp_name = type_manager.getPooledString(type_name);

// NOTE: MemoObject and MemoImmutableObject have different implementations
if (immutable) {
Expand All @@ -684,6 +790,12 @@ namespace db0::python
if (!new_type) {
return nullptr;
}

// FIXME:log
// print tp_dict["__setattr__"] of the new type
// PyObject *setattr_method = PyDict_GetItemString((*new_type)->tp_dict, "__setattr__");
// std::cout << "New type __setattr__: " << PyUnicode_AsUTF8(PyObject_Repr(setattr_method)) << std::endl;

MemoFlags type_flags = no_default_tags ? MemoFlags { MemoOptions::NO_DEFAULT_TAGS } : MemoFlags();
if (no_cache) {
type_flags.set(MemoOptions::NO_CACHE);
Expand All @@ -701,7 +813,7 @@ namespace db0::python
py_dyn_prefix_callable,
std::move(migrations)
);

// add to memo type registry
PyToolkit::getTypeManager().addMemoType(*new_type, type_id, std::move(type_info));
// register new type with the module where the original type was located
Expand All @@ -713,7 +825,7 @@ namespace db0::python
PyErr_SetString(PyExc_RuntimeError, "Failed to set __fields__");
return nullptr;
}

return (PyObject*)new_type.steal();
}

Expand Down