Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 195 additions & 14 deletions ml_dtypes/_src/custom_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ limitations under the License.
#include <Python.h>

#include "Eigen/Core"
#include "ml_dtypes/_src/common.h" // NOLINT
#include "ml_dtypes/_src/ufuncs.h" // NOLINT
#include "ml_dtypes/_src/common.h" // NOLINT
#include "ml_dtypes/_src/dtype_compat.h" // NOLINT
#include "ml_dtypes/_src/ufuncs.h" // NOLINT
#include "ml_dtypes/include/complex_types.h"

#undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build
Expand Down Expand Up @@ -69,6 +70,9 @@ struct CustomComplexType {
static PyArray_ArrFuncs arr_funcs;
static PyArray_DescrProto npy_descr_proto;
static PyArray_Descr* npy_descr;

// New-style DType metaclass object.
static PyArray_DTypeMeta dtype_meta;
};

template <typename T>
Expand All @@ -79,6 +83,8 @@ template <typename T>
PyArray_DescrProto CustomComplexType<T>::npy_descr_proto;
// Out-of-class definition of the descriptor pointer. It starts out null and
// is assigned in RegisterComplexDtype from PyArray_DescrFromType.
template <typename T>
PyArray_Descr* CustomComplexType<T>::npy_descr = nullptr;
// Out-of-class definition of the new-style DType metaclass object. It is
// value-initialized here; RegisterComplexDtype fills in its type slots and
// passes it to PyArrayInitDTypeMeta_FromSpec.
template <typename T>
PyArray_DTypeMeta CustomComplexType<T>::dtype_meta = {};

// Representation of a Python custom complex object.
template <typename T>
Expand Down Expand Up @@ -904,6 +910,123 @@ bool RegisterComplexUFuncs(PyObject* numpy) {
return ok;
}

// ---------------------------------------------------------------------------
// New-style DType slot functions for CustomComplex types
// ---------------------------------------------------------------------------

// New-style DType `getitem` slot (registered under NPY_DT_getitem).
//
// The descriptor argument is ignored; the call is forwarded to the legacy
// ArrFuncs-style NPyCustomComplex_GetItem<T> with a null array argument.
// Returns whatever that function returns for the element stored at `data`.
template <typename T>
static PyObject* NPyCustomComplex_NewStyleGetItem(PyArray_Descr* /*descr*/,
                                                  char* data) {
  PyObject* const boxed = NPyCustomComplex_GetItem<T>(data, /*arr=*/nullptr);
  return boxed;
}

// New-style DType `setitem` slot (registered under NPY_DT_setitem).
//
// The descriptor argument is ignored; the call is forwarded to the legacy
// ArrFuncs-style NPyCustomComplex_SetItem<T> with a null array argument.
// Returns that function's status code unchanged.
template <typename T>
static int NPyCustomComplex_NewStyleSetItem(PyArray_Descr* /*descr*/,
                                            PyObject* item, char* data) {
  const int status = NPyCustomComplex_SetItem<T>(item, data, /*arr=*/nullptr);
  return status;
}

// `ensure_canonical` slot (registered under NPY_DT_ensure_canonical).
//
// The descriptor is handed back as-is: the function takes an extra reference
// on `self` and returns it, so the caller owns the returned reference.
template <typename T>
static PyArray_Descr* NPyCustomComplex_EnsureCanonical(PyArray_Descr* self) {
  PyArray_Descr* const canonical = self;
  Py_INCREF(canonical);
  return canonical;
}

// `default_descr` slot (registered under NPY_DT_default_descr).
//
// Returns a new reference to the singleton descriptor stored on the DType
// class `cls`.
template <typename T>
static PyArray_Descr* NPyCustomComplex_DefaultDescr(PyArray_DTypeMeta* cls) {
  PyArray_Descr* const descr = cls->singleton;
  Py_INCREF(descr);
  return descr;
}

// `common_dtype` slot (registered under NPY_DT_common_dtype): chooses the
// DType class that `cls` (one of our custom complex DTypes) and `other`
// promote to.
//
// Always returns a new reference, to one of:
//   * `cls`         - `other` is `cls` itself or a Python abstract scalar
//                     DType (PyLong / PyFloat / PyComplex);
//   * cfloat        - `other` is bool, a built-in integer, half or float, one
//                     of the custom float / intN DTypes, or a custom complex
//                     DType whose type number is not larger than `cls`'s;
//   * cdouble       - `other` is double;
//   * clongdouble   - `other` is long double;
//   * `other`       - `other` is a built-in complex DType;
//   * Py_NotImplemented (cast to PyArray_DTypeMeta*) - any other DType, and
//                     a custom complex DType with a larger type number than
//                     `cls` (presumably so the other operand's slot makes the
//                     decision - confirm against NumPy's promotion protocol).
template <typename T>
static PyArray_DTypeMeta* NPyCustomComplex_CommonDType(
    PyArray_DTypeMeta* cls, PyArray_DTypeMeta* other) {
  // Takes a reference on `dtype` and hands it back, so every exit below is a
  // single expression.
  const auto new_ref = [](PyArray_DTypeMeta* dtype) -> PyArray_DTypeMeta* {
    Py_INCREF(reinterpret_cast<PyObject*>(dtype));
    return dtype;
  };
  // New reference to Py_NotImplemented, typed as the slot's return type.
  const auto not_implemented = [&new_ref]() -> PyArray_DTypeMeta* {
    return new_ref(reinterpret_cast<PyArray_DTypeMeta*>(Py_NotImplemented));
  };

  // Same class, or a Python abstract scalar: the concrete type (`cls`) wins.
  if (cls == other || other == &PyArray_PyLongDType ||
      other == &PyArray_PyFloatDType || other == &PyArray_PyComplexDType) {
    return new_ref(cls);
  }

  // ---- Built-in NumPy DTypes, keyed by type number ----
  switch (other->type_num) {
    // NOTE(review): the 32/64-bit integer cases also map to cfloat, although
    // float32 components cannot represent every such integer exactly -
    // confirm this promotion is intended.
    case NPY_BOOL:
    case NPY_BYTE:
    case NPY_SHORT:
    case NPY_INT:
    case NPY_LONG:
    case NPY_LONGLONG:
    case NPY_UBYTE:
    case NPY_USHORT:
    case NPY_UINT:
    case NPY_ULONG:
    case NPY_ULONGLONG:
    case NPY_HALF:
    case NPY_FLOAT:
      return new_ref(&PyArray_CFloatDType);
    case NPY_DOUBLE:
      return new_ref(&PyArray_CDoubleDType);
    case NPY_LONGDOUBLE:
      return new_ref(&PyArray_CLongDoubleDType);
    // Built-in complex types are treated as the wider side: return `other`.
    case NPY_CFLOAT:
    case NPY_CDOUBLE:
    case NPY_CLONGDOUBLE:
      return new_ref(other);
    default:
      break;
  }

  // ---- Our own real-valued custom DTypes (floats and small ints) ----
  // Any of these combined with a custom complex promotes to cfloat.
  const PyArray_DTypeMeta* const real_custom_dtypes[] = {
      &CustomFloatType<bfloat16>::dtype_meta,
      &CustomFloatType<float8_e3m4>::dtype_meta,
      &CustomFloatType<float8_e4m3>::dtype_meta,
      &CustomFloatType<float8_e4m3b11fnuz>::dtype_meta,
      &CustomFloatType<float8_e4m3fn>::dtype_meta,
      &CustomFloatType<float8_e4m3fnuz>::dtype_meta,
      &CustomFloatType<float8_e5m2>::dtype_meta,
      &CustomFloatType<float8_e5m2fnuz>::dtype_meta,
      &CustomFloatType<float6_e2m3fn>::dtype_meta,
      &CustomFloatType<float6_e3m2fn>::dtype_meta,
      &CustomFloatType<float4_e2m1fn>::dtype_meta,
      &CustomFloatType<float8_e8m0fnu>::dtype_meta,
      &IntNTypeDescriptor<int1>::dtype_meta,
      &IntNTypeDescriptor<uint1>::dtype_meta,
      &IntNTypeDescriptor<int2>::dtype_meta,
      &IntNTypeDescriptor<uint2>::dtype_meta,
      &IntNTypeDescriptor<int4>::dtype_meta,
      &IntNTypeDescriptor<uint4>::dtype_meta,
  };
  for (const PyArray_DTypeMeta* const candidate : real_custom_dtypes) {
    if (other == candidate) {
      return new_ref(&PyArray_CFloatDType);
    }
  }

  // ---- The custom complex DTypes ----
  const bool other_is_custom_complex =
      other == &CustomComplexType<bcomplex32>::dtype_meta ||
      other == &CustomComplexType<complex32>::dtype_meta;
  if (!other_is_custom_complex) {
    // Unknown user type: nothing we know how to promote with.
    return not_implemented();
  }
  // Only the side with the larger type number answers with cfloat; the side
  // with the smaller one reports NotImplemented.
  return cls->type_num < other->type_num ? not_implemented()
                                         : new_ref(&PyArray_CFloatDType);
}

// tp_repr for the DType type object of `T`: produces the text
// "dtype(<name>)", where <name> is TypeDescriptor<T>::kTypeName. The object
// being formatted is not inspected.
template <typename T>
static PyObject* NPyCustomComplex_DTypeRepr(PyObject* /*self*/) {
  const char* const type_name = TypeDescriptor<T>::kTypeName;
  return PyUnicode_FromFormat("dtype(%s)", type_name);
}
// tp_str for the DType type object of `T`: produces the bare type name,
// TypeDescriptor<T>::kTypeName. The object being formatted is not inspected.
template <typename T>
static PyObject* NPyCustomComplex_DTypeStr(PyObject* /*self*/) {
  const char* const type_name = TypeDescriptor<T>::kTypeName;
  return PyUnicode_FromString(type_name);
}

template <typename T>
bool RegisterComplexDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
Expand All @@ -928,7 +1051,7 @@ bool RegisterComplexDtype(PyObject* numpy) {
return false;
}

// Initializes the NumPy descriptor.
// Initializes the NumPy ArrFuncs (used by legacy code paths after the swap).
PyArray_ArrFuncs& arr_funcs = CustomComplexType<T>::arr_funcs;
PyArray_InitArrFuncs(&arr_funcs);
arr_funcs.getitem = NPyCustomComplex_GetItem<T>;
Expand All @@ -937,29 +1060,87 @@ bool RegisterComplexDtype(PyObject* numpy) {
arr_funcs.copyswapn = NPyCustomComplex_CopySwapN<T>;
arr_funcs.copyswap = NPyCustomComplex_CopySwap<T>;
arr_funcs.nonzero = NPyCustomComplex_NonZero<T>;
arr_funcs.fill = nullptr; // NPyCustomComplex_Fill<T>;
arr_funcs.fill = nullptr;
arr_funcs.dotfunc = NPyCustomComplex_DotFunc<T>;
arr_funcs.compare = NPyCustomComplex_CompareFunc<T>;
arr_funcs.argmax = nullptr; // NumPy defines them, but it's shaky
arr_funcs.argmax = nullptr;
arr_funcs.argmin = nullptr;

// This is messy, but that's because the NumPy 2.0 API transition is messy.
// Before 2.0, NumPy assumes we'll keep the descriptor passed in to
// RegisterDataType alive, because it stores its pointer.
// After 2.0, the proto and descriptor types diverge, and NumPy allocates
// and manages the lifetime of the descriptor itself.
// Prepare the legacy proto.
PyArray_DescrProto& descr_proto = CustomComplexType<T>::npy_descr_proto;
descr_proto = GetCustomComplexDescrProto<T>();
Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type);
descr_proto.typeobj = reinterpret_cast<PyTypeObject*>(type);
descr_proto.f = &arr_funcs;

PyArray_DTypeMeta& dm = CustomComplexType<T>::dtype_meta;
Py_SET_REFCNT(&dm, 1);
auto* tp = reinterpret_cast<PyTypeObject*>(&dm);
tp->tp_name = TypeDescriptor<T>::kTypeName;
tp->tp_base = &PyArrayDescr_Type;
tp->tp_flags = Py_TPFLAGS_DEFAULT;
tp->tp_repr = NPyCustomComplex_DTypeRepr<T>;
tp->tp_str = NPyCustomComplex_DTypeStr<T>;
if (PyType_Ready(tp) < 0) {
return false;
}

TypeDescriptor<T>::npy_type = PyArray_RegisterDataType(&descr_proto);
if (TypeDescriptor<T>::npy_type < 0) {
// Build the within-dtype self-cast spec.
PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr};
PyType_Slot self_cast_slots[] = {
{NPY_METH_strided_loop,
reinterpret_cast<void*>(TrivialStridedCopyLoop)},
{NPY_METH_unaligned_strided_loop,
reinterpret_cast<void*>(TrivialStridedCopyLoop)},
{0, nullptr}};
PyArrayMethod_Spec self_cast_spec;
self_cast_spec.name = "copy";
self_cast_spec.nin = 1;
self_cast_spec.nout = 1;
self_cast_spec.casting = NPY_NO_CASTING;
self_cast_spec.flags = static_cast<NPY_ARRAYMETHOD_FLAGS>(
NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS);
self_cast_spec.dtypes = self_cast_dtypes;
self_cast_spec.slots = self_cast_slots;
PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr};

// Build the new-style DType spec.
PyType_Slot dtype_slots[] = {
{NPY_DT_legacy_descriptor_proto,
reinterpret_cast<void*>(&descr_proto)},
{NPY_DT_getitem,
reinterpret_cast<void*>(NPyCustomComplex_NewStyleGetItem<T>)},
{NPY_DT_setitem,
reinterpret_cast<void*>(NPyCustomComplex_NewStyleSetItem<T>)},
{NPY_DT_ensure_canonical,
reinterpret_cast<void*>(NPyCustomComplex_EnsureCanonical<T>)},
{NPY_DT_default_descr,
reinterpret_cast<void*>(NPyCustomComplex_DefaultDescr<T>)},
{NPY_DT_common_dtype,
reinterpret_cast<void*>(NPyCustomComplex_CommonDType<T>)},
{NPY_DT_PyArray_ArrFuncs_getitem,
reinterpret_cast<void*>(NPyCustomComplex_GetItem<T>)},
{NPY_DT_PyArray_ArrFuncs_setitem,
reinterpret_cast<void*>(NPyCustomComplex_SetItem<T>)},
{NPY_DT_PyArray_ArrFuncs_nonzero,
reinterpret_cast<void*>(NPyCustomComplex_NonZero<T>)},
{NPY_DT_PyArray_ArrFuncs_dotfunc,
reinterpret_cast<void*>(NPyCustomComplex_DotFunc<T>)},
{NPY_DT_PyArray_ArrFuncs_compare,
reinterpret_cast<void*>(NPyCustomComplex_CompareFunc<T>)},
{0, nullptr}};
PyArrayDTypeMeta_Spec dtype_spec;
dtype_spec.typeobj = reinterpret_cast<PyTypeObject*>(type);
dtype_spec.flags = 0;
dtype_spec.casts = casts;
dtype_spec.slots = dtype_slots;
dtype_spec.baseclass = nullptr;

if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) {
return false;
}
TypeDescriptor<T>::npy_type = dm.type_num;

// TODO(phawkins): We intentionally leak the pointer to the descriptor.
// Implement a better module destructor to handle this.
CustomComplexType<T>::npy_descr =
PyArray_DescrFromType(TypeDescriptor<T>::npy_type);

Expand Down
Loading