From c5bbff76a48de2790f183ebf97888c0320b62b5d Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Fri, 17 Apr 2026 19:09:52 +0200 Subject: [PATCH 1/6] WIP: Try, mostly based on claude --- ml_dtypes/_src/custom_complex.h | 123 ++++++++++++++++++++++++---- ml_dtypes/_src/custom_float.h | 138 +++++++++++++++++++++++++++++--- ml_dtypes/_src/intn_numpy.h | 129 +++++++++++++++++++++++++---- ml_dtypes/_src/numpy.h | 2 + 4 files changed, 352 insertions(+), 40 deletions(-) diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index c640aee7..f106f245 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -35,8 +35,9 @@ limitations under the License. #include #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_compat.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/complex_types.h" #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build @@ -69,6 +70,9 @@ struct CustomComplexType { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. + static PyArray_DTypeMeta dtype_meta; }; template @@ -79,6 +83,8 @@ template PyArray_DescrProto CustomComplexType::npy_descr_proto; template PyArray_Descr* CustomComplexType::npy_descr = nullptr; +template +PyArray_DTypeMeta CustomComplexType::dtype_meta = {}; // Representation of a Python custom float object. 
template @@ -904,6 +910,39 @@ bool RegisterComplexUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for CustomComplex types +// --------------------------------------------------------------------------- + +template +static PyObject* NPyCustomComplex_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} + +template +static PyObject* NPyCustomComplex_NewStyleGetItem(PyArray_Descr* /*descr*/, + char* data) { + return NPyCustomComplex_GetItem(data, /*arr=*/nullptr); +} + +template +static int NPyCustomComplex_NewStyleSetItem(PyArray_Descr* /*descr*/, + PyObject* item, char* data) { + return NPyCustomComplex_SetItem(item, data, /*arr=*/nullptr); +} + +template +static PyArray_Descr* NPyCustomComplex_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +template +static PyArray_Descr* NPyCustomComplex_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -928,7 +967,7 @@ bool RegisterComplexDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). 
PyArray_ArrFuncs& arr_funcs = CustomComplexType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomComplex_GetItem; @@ -937,29 +976,85 @@ bool RegisterComplexDtype(PyObject* numpy) { arr_funcs.copyswapn = NPyCustomComplex_CopySwapN; arr_funcs.copyswap = NPyCustomComplex_CopySwap; arr_funcs.nonzero = NPyCustomComplex_NonZero; - arr_funcs.fill = nullptr; // NPyCustomComplex_Fill; + arr_funcs.fill = nullptr; arr_funcs.dotfunc = NPyCustomComplex_DotFunc; arr_funcs.compare = NPyCustomComplex_CompareFunc; - arr_funcs.argmax = nullptr; // NumPy defines them, but it's shaky + arr_funcs.argmax = nullptr; arr_funcs.argmin = nullptr; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy proto. PyArray_DescrProto& descr_proto = CustomComplexType::npy_descr_proto; descr_proto = GetCustomComplexDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + // Set up the DTypeMeta. + PyArray_DTypeMeta& dm = CustomComplexType::dtype_meta; + Py_SET_REFCNT(&dm, 1); + auto* tp = reinterpret_cast(&dm); + tp->tp_name = TypeDescriptor::kTypeName; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyCustomComplex_DTypeRepr; + tp->tp_str = NPyCustomComplex_DTypeRepr; + if (PyType_Ready(tp) < 0) { + return false; + } - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + // Build the within-dtype self-cast spec. 
+ PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {NPY_METH_unaligned_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "copy"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + // Build the new-style DType spec. + PyType_Slot dtype_slots[] = { + {NPY_DT_getitem, + reinterpret_cast(NPyCustomComplex_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyCustomComplex_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyCustomComplex_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyCustomComplex_DefaultDescr)}, + {NPY_DT_PyArray_ArrFuncs_getitem, + reinterpret_cast(NPyCustomComplex_GetItem)}, + {NPY_DT_PyArray_ArrFuncs_setitem, + reinterpret_cast(NPyCustomComplex_SetItem)}, + {NPY_DT_PyArray_ArrFuncs_nonzero, + reinterpret_cast(NPyCustomComplex_NonZero)}, + {NPY_DT_PyArray_ArrFuncs_dotfunc, + reinterpret_cast(NPyCustomComplex_DotFunc)}, + {NPY_DT_PyArray_ArrFuncs_compare, + reinterpret_cast(NPyCustomComplex_CompareFunc)}, + {0, nullptr}}; + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = reinterpret_cast(type); + dtype_spec.flags = 0; + dtype_spec.casts = casts; + dtype_spec.slots = dtype_slots; + dtype_spec.baseclass = nullptr; + + if (PyArrayInitDTypeMeta_FromSpec_WithLegacy(&dm, &dtype_spec, + &descr_proto) < 0) { return false; } + TypeDescriptor::npy_type = dm.type_num; - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. 
CustomComplexType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index bf2568a7..e90dc2b8 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -35,8 +35,9 @@ limitations under the License. #include #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_compat.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build // Possible this has to do with numpy.h being included before @@ -66,6 +67,10 @@ struct CustomFloatType { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. Zero-initialized; fields are filled in + // at registration time before PyType_Ready is called. + static PyArray_DTypeMeta dtype_meta; }; template @@ -76,6 +81,8 @@ template PyArray_DescrProto CustomFloatType::npy_descr_proto; template PyArray_Descr* CustomFloatType::npy_descr = nullptr; +template +PyArray_DTypeMeta CustomFloatType::dtype_meta = {}; // Representation of a Python custom float object. template @@ -841,6 +848,47 @@ bool RegisterFloatUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for CustomFloat types +// --------------------------------------------------------------------------- + +// tp_repr / tp_str for the DTypeMeta itself (required by +// PyArrayInitDTypeMeta_FromSpec; must differ from PyArrayDescr_Type's). 
+template +static PyObject* NPyCustomFloat_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} + +// New-style getitem: (PyArray_Descr*, char*) -> PyObject* +template +static PyObject* NPyCustomFloat_NewStyleGetItem(PyArray_Descr* /*descr*/, + char* data) { + return NPyCustomFloat_GetItem(data, /*arr=*/nullptr); +} + +// New-style setitem: (PyArray_Descr*, PyObject*, char*) -> int +template +static int NPyCustomFloat_NewStyleSetItem(PyArray_Descr* /*descr*/, + PyObject* item, char* data) { + return NPyCustomFloat_SetItem(item, data, /*arr=*/nullptr); +} + +// ensure_canonical: for a non-parametric dtype just return self. +template +static PyArray_Descr* NPyCustomFloat_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +// default_descr: return the singleton. +// This avoids use_new_as_default calling dm() -> arraydescr_new, which fails +// for legacy-flagged DTypes because the legacy-check branch errors out. +template +static PyArray_Descr* NPyCustomFloat_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + template bool RegisterFloatDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -865,7 +913,7 @@ bool RegisterFloatDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = CustomFloatType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomFloat_GetItem; @@ -880,23 +928,89 @@ bool RegisterFloatDtype(PyObject* numpy) { arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc; arr_funcs.argmin = NPyCustomFloat_ArgMinFunc; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. 
- // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy proto (for PyArrayInitDTypeMeta_FromSpec_WithLegacy). PyArray_DescrProto& descr_proto = CustomFloatType::npy_descr_proto; descr_proto = GetCustomFloatDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + // Set up the DTypeMeta. It must subclass PyArrayDescr_Type and have a + // metaclass of PyArrayDTypeMeta_Type (inherited via PyType_Ready from + // PyArrayDescr_Type.ob_type). + PyArray_DTypeMeta& dm = CustomFloatType::dtype_meta; + Py_SET_REFCNT(&dm, 1); + auto* tp = reinterpret_cast(&dm); + tp->tp_name = TypeDescriptor::kTypeName; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyCustomFloat_DTypeRepr; + tp->tp_str = NPyCustomFloat_DTypeRepr; + if (PyType_Ready(tp) < 0) { + return false; + } - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + // Build the within-dtype self-cast spec. + PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {NPY_METH_unaligned_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "copy"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + // Build the new-style DType spec. 
+ PyType_Slot dtype_slots[] = { + {NPY_DT_getitem, + reinterpret_cast(NPyCustomFloat_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyCustomFloat_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyCustomFloat_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyCustomFloat_DefaultDescr)}, + {NPY_DT_PyArray_ArrFuncs_getitem, + reinterpret_cast(NPyCustomFloat_GetItem)}, + {NPY_DT_PyArray_ArrFuncs_setitem, + reinterpret_cast(NPyCustomFloat_SetItem)}, + {NPY_DT_PyArray_ArrFuncs_nonzero, + reinterpret_cast(NPyCustomFloat_NonZero)}, + {NPY_DT_PyArray_ArrFuncs_fill, + reinterpret_cast(NPyCustomFloat_Fill)}, + {NPY_DT_PyArray_ArrFuncs_dotfunc, + reinterpret_cast(NPyCustomFloat_DotFunc)}, + {NPY_DT_PyArray_ArrFuncs_compare, + reinterpret_cast(NPyCustomFloat_CompareFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmax, + reinterpret_cast(NPyCustomFloat_ArgMaxFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmin, + reinterpret_cast(NPyCustomFloat_ArgMinFunc)}, + {0, nullptr}}; + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = reinterpret_cast(type); + dtype_spec.flags = 0; + dtype_spec.casts = casts; + dtype_spec.slots = dtype_slots; + dtype_spec.baseclass = nullptr; + + if (PyArrayInitDTypeMeta_FromSpec_WithLegacy(&dm, &dtype_spec, + &descr_proto) < 0) { return false; } + TypeDescriptor::npy_type = dm.type_num; - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. + // The singleton descriptor is owned by dm. PyArray_DescrFromType returns a + // new reference, which npy_descr holds and intentionally never releases. CustomFloatType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index 8e32a63c..41e673dc 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -25,8 +25,9 @@ limitations under the License. 
// clang-format on #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_compat.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/intn.h" #if NPY_ABI_VERSION < 0x02000000 @@ -56,6 +57,9 @@ struct IntNTypeDescriptor { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. + static PyArray_DTypeMeta dtype_meta; }; template @@ -66,6 +70,8 @@ template PyArray_DescrProto IntNTypeDescriptor::npy_descr_proto; template PyArray_Descr* IntNTypeDescriptor::npy_descr = nullptr; +template +PyArray_DTypeMeta IntNTypeDescriptor::dtype_meta = {}; // Representation of a Python custom integer object. template @@ -774,6 +780,38 @@ bool RegisterIntNUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for IntN types +// --------------------------------------------------------------------------- + +template +static PyObject* NPyIntN_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} + +template +static PyObject* NPyIntN_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { + return NPyIntN_GetItem(data, /*arr=*/nullptr); +} + +template +static int NPyIntN_NewStyleSetItem(PyArray_Descr* /*descr*/, PyObject* item, + char* data) { + return NPyIntN_SetItem(item, data, /*arr=*/nullptr); +} + +template +static PyArray_Descr* NPyIntN_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +template +static PyArray_Descr* NPyIntN_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + template bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. 
Change to just pass @@ -799,7 +837,7 @@ bool RegisterIntNDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = IntNTypeDescriptor::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyIntN_GetItem; @@ -814,22 +852,85 @@ bool RegisterIntNDtype(PyObject* numpy) { arr_funcs.argmax = NPyIntN_ArgMaxFunc; arr_funcs.argmin = NPyIntN_ArgMinFunc; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy proto. PyArray_DescrProto& descr_proto = IntNTypeDescriptor::npy_descr_proto; descr_proto = GetIntNDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + // Set up the DTypeMeta. + PyArray_DTypeMeta& dm = IntNTypeDescriptor::dtype_meta; + Py_SET_REFCNT(&dm, 1); + auto* tp = reinterpret_cast(&dm); + tp->tp_name = TypeDescriptor::kTypeName; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyIntN_DTypeRepr; + tp->tp_str = NPyIntN_DTypeRepr; + if (PyType_Ready(tp) < 0) { + return false; + } + + // Build the within-dtype self-cast spec. 
+ PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {NPY_METH_unaligned_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "copy"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + // Build the new-style DType spec. + PyType_Slot dtype_slots[] = { + {NPY_DT_getitem, + reinterpret_cast(NPyIntN_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyIntN_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyIntN_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyIntN_DefaultDescr)}, + {NPY_DT_PyArray_ArrFuncs_getitem, + reinterpret_cast(NPyIntN_GetItem)}, + {NPY_DT_PyArray_ArrFuncs_setitem, + reinterpret_cast(NPyIntN_SetItem)}, + {NPY_DT_PyArray_ArrFuncs_nonzero, + reinterpret_cast(NPyIntN_NonZero)}, + {NPY_DT_PyArray_ArrFuncs_fill, + reinterpret_cast(NPyIntN_Fill)}, + {NPY_DT_PyArray_ArrFuncs_dotfunc, + reinterpret_cast(NPyIntN_DotFunc)}, + {NPY_DT_PyArray_ArrFuncs_compare, + reinterpret_cast(NPyIntN_CompareFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmax, + reinterpret_cast(NPyIntN_ArgMaxFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmin, + reinterpret_cast(NPyIntN_ArgMinFunc)}, + {0, nullptr}}; + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = reinterpret_cast(type); + dtype_spec.flags = 0; + dtype_spec.casts = casts; + dtype_spec.slots = dtype_slots; + dtype_spec.baseclass = nullptr; + + if (PyArrayInitDTypeMeta_FromSpec_WithLegacy(&dm, &dtype_spec, + &descr_proto) < 0) { + return false; + } + TypeDescriptor::npy_type = dm.type_num; - 
TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { - return false; - } - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. IntNTypeDescriptor::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/numpy.h b/ml_dtypes/_src/numpy.h index 8b55e4d9..f9946d42 100644 --- a/ml_dtypes/_src/numpy.h +++ b/ml_dtypes/_src/numpy.h @@ -22,6 +22,8 @@ limitations under the License. // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#define NPY_TARGET_VERSION NPY_2_0_API_VERSION + // We import_array in the ml_dtypes init function only. #define PY_ARRAY_UNIQUE_SYMBOL _ml_dtypes_numpy_api From 8de28f30579733bc22e6b58118178a87699a989e Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 20 Apr 2026 13:25:16 +0200 Subject: [PATCH 2/6] Add hack to use new DType path but still remain largely "old". 
--- ml_dtypes/_src/custom_complex.h | 87 +++++++++- ml_dtypes/_src/custom_float.h | 109 +++++++++++- ml_dtypes/_src/dtype_compat.h | 177 +++++++++++++++++++ ml_dtypes/_src/intn_numpy.h | 63 ++++++- ml_dtypes/tests/result_type_test.py | 253 ++++++++++++++++++++++++++++ 5 files changed, 686 insertions(+), 3 deletions(-) create mode 100644 ml_dtypes/_src/dtype_compat.h create mode 100644 ml_dtypes/tests/result_type_test.py diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index f106f245..11775a9b 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -919,6 +919,11 @@ static PyObject* NPyCustomComplex_DTypeRepr(PyObject* /*self*/) { return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); } +template +static PyObject* NPyCustomComplex_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template static PyObject* NPyCustomComplex_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { @@ -943,6 +948,84 @@ static PyArray_Descr* NPyCustomComplex_DefaultDescr(PyArray_DTypeMeta* cls) { return cls->singleton; } +template +static PyArray_DTypeMeta* NPyCustomComplex_CommonDType( + PyArray_DTypeMeta* cls, PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType || + other == &PyArray_PyComplexDType) { + Py_INCREF(cls); + return cls; + } + + switch (other->type_num) { + // bool, ints, half, float: wrap in the smallest complex that holds both. + // Our custom complex types all fit in cfloat. 
+ case NPY_BOOL: + case NPY_BYTE: case NPY_SHORT: case NPY_INT: + case NPY_LONG: case NPY_LONGLONG: + case NPY_UBYTE: case NPY_USHORT: case NPY_UINT: + case NPY_ULONG: case NPY_ULONGLONG: + case NPY_HALF: case NPY_FLOAT: + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + case NPY_DOUBLE: case NPY_LONGDOUBLE: + Py_INCREF(reinterpret_cast(&PyArray_CDoubleDType)); + return &PyArray_CDoubleDType; + + // Built-in complex: our types are smaller, return other. + case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: + Py_INCREF(other); + return other; + + default: + break; + } + + // ---- Our own custom DTypes ---- + // Custom float or custom int: all fit in cfloat alongside our complex. + if (other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta) { + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + } + + // Another custom complex: both fit in cfloat. 
+ if (other == &CustomComplexType::dtype_meta || + other == &CustomComplexType::dtype_meta) { + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); + } + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + } + + // Unknown user type: return NotImplemented. + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -997,7 +1080,7 @@ bool RegisterComplexDtype(PyObject* numpy) { tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; tp->tp_repr = NPyCustomComplex_DTypeRepr; - tp->tp_str = NPyCustomComplex_DTypeRepr; + tp->tp_str = NPyCustomComplex_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -1031,6 +1114,8 @@ bool RegisterComplexDtype(PyObject* numpy) { reinterpret_cast(NPyCustomComplex_EnsureCanonical)}, {NPY_DT_default_descr, reinterpret_cast(NPyCustomComplex_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyCustomComplex_CommonDType)}, {NPY_DT_PyArray_ArrFuncs_getitem, reinterpret_cast(NPyCustomComplex_GetItem)}, {NPY_DT_PyArray_ArrFuncs_setitem, diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index e90dc2b8..182d4e04 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -859,6 +859,11 @@ static PyObject* NPyCustomFloat_DTypeRepr(PyObject* /*self*/) { return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); } +template +static PyObject* NPyCustomFloat_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + // New-style getitem: (PyArray_Descr*, char*) -> PyObject* template static PyObject* NPyCustomFloat_NewStyleGetItem(PyArray_Descr* /*descr*/, @@ -889,6 +894,106 @@ static PyArray_Descr* NPyCustomFloat_DefaultDescr(PyArray_DTypeMeta* cls) { return cls->singleton; } +// True if 
every value of Src is exactly representable in Dst: Dst must have +// at least as many mantissa bits (precision) and at least as much exponent +// range as Src. +template +static constexpr bool CustomFloatSafeTo() { + return std::numeric_limits::digits >= std::numeric_limits::digits && + std::numeric_limits::max_exponent >= + std::numeric_limits::max_exponent; +} + +template +static PyArray_DTypeMeta* NPyCustomFloat_CommonDType(PyArray_DTypeMeta* cls, + PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. TODO: PyArray_PyComplexDType is not handled here yet. + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) { + Py_INCREF(cls); + return cls; + } + + constexpr bool is_bfloat16 = std::is_same_v; + + if (PyTypeNum_ISINTEGER(other->type_num)) { + if (is_bfloat16 && (other->type_num == NPY_BYTE || other->type_num == NPY_UBYTE)) { + Py_INCREF(cls); + return cls; + } + /* Our precision is irrelevant, the integer one is higher. */ + return PyArray_CommonDType(&PyArray_PyFloatDType, other); + } + + switch (other->type_num) { + case NPY_BOOL: + Py_INCREF(cls); + return cls; + case NPY_HALF: + if (is_bfloat16) { + Py_INCREF(reinterpret_cast(&PyArray_FloatDType)); + return &PyArray_FloatDType; + } + [[fallthrough]]; + case NPY_FLOAT: case NPY_DOUBLE: case NPY_LONGDOUBLE: + [[fallthrough]]; + case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: + Py_INCREF(other); + return other; + default: + break; + } + + // ---- Our own custom DTypes ---- + // Another custom float: use compile-time safe-cast predicate to pick the + // wider type; fall back to float32 when neither contains the other. + // T is known at compile time so all CustomFloatSafeTo calls fold away. 
+#define TRY_CUSTOM_FLOAT(OtherT) \ + if (other == &CustomFloatType::dtype_meta) { \ + if constexpr (CustomFloatSafeTo()) { \ + Py_INCREF(other); return other; \ + } else if constexpr (CustomFloatSafeTo()) { \ + Py_INCREF(cls); return cls; \ + } else { \ + Py_INCREF(reinterpret_cast(&PyArray_FloatDType)); \ + return &PyArray_FloatDType; \ + } \ + } + TRY_CUSTOM_FLOAT(bfloat16) + TRY_CUSTOM_FLOAT(float8_e3m4) + TRY_CUSTOM_FLOAT(float8_e4m3) + TRY_CUSTOM_FLOAT(float8_e4m3b11fnuz) + TRY_CUSTOM_FLOAT(float8_e4m3fn) + TRY_CUSTOM_FLOAT(float8_e4m3fnuz) + TRY_CUSTOM_FLOAT(float8_e5m2) + TRY_CUSTOM_FLOAT(float8_e5m2fnuz) + TRY_CUSTOM_FLOAT(float6_e2m3fn) + TRY_CUSTOM_FLOAT(float6_e3m2fn) + TRY_CUSTOM_FLOAT(float4_e2m1fn) + TRY_CUSTOM_FLOAT(float8_e8m0fnu) +#undef TRY_CUSTOM_FLOAT + + // Custom int: float dominates. NPyIntN_CommonDType returns NotImplemented + // for user types it can't see, so we handle this side explicitly. + if (other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta) { + Py_INCREF(cls); + return cls; + } + + // Custom complex or unknown user type: swapping will work (NPyCustomComplex + // handles complex+float and returns the appropriate complex result). + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + template bool RegisterFloatDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. 
Change to just pass @@ -945,7 +1050,7 @@ bool RegisterFloatDtype(PyObject* numpy) { tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; tp->tp_repr = NPyCustomFloat_DTypeRepr; - tp->tp_str = NPyCustomFloat_DTypeRepr; + tp->tp_str = NPyCustomFloat_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -979,6 +1084,8 @@ bool RegisterFloatDtype(PyObject* numpy) { reinterpret_cast(NPyCustomFloat_EnsureCanonical)}, {NPY_DT_default_descr, reinterpret_cast(NPyCustomFloat_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyCustomFloat_CommonDType)}, {NPY_DT_PyArray_ArrFuncs_getitem, reinterpret_cast(NPyCustomFloat_GetItem)}, {NPY_DT_PyArray_ArrFuncs_setitem, diff --git a/ml_dtypes/_src/dtype_compat.h b/ml_dtypes/_src/dtype_compat.h new file mode 100644 index 00000000..7a492940 --- /dev/null +++ b/ml_dtypes/_src/dtype_compat.h @@ -0,0 +1,177 @@ +/* Copyright 2025 The ml_dtypes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + * Compatibility helper for registering DTypes that work with both the new-style + * NumPy DType API (PyArrayInitDTypeMeta_FromSpec) and the legacy + * PyArray_RegisterDataType path. + * + * The PyArrayInitDTypeMeta_FromSpec_WithLegacy function is a backport hack for + * NumPy 2.0. Starting with NumPy 2.5/2.6 (TBD) there will be native NumPy + * support and this can be replaced. + * + * NOTE: This file requires NumPy >= 2.0. 
+ */ + +#ifndef ML_DTYPES_DTYPE_COMPAT_H_ +#define ML_DTYPES_DTYPE_COMPAT_H_ + +// clang-format off +#include "ml_dtypes/_src/numpy.h" // NOLINT (must be first) +// clang-format on + +#include +#include +#include "numpy/arrayobject.h" +#include "numpy/dtype_api.h" + +#if NPY_ABI_VERSION < 0x02000000 +#error "ml_dtypes dtype_compat.h requires NumPy >= 2.0" +#endif + +namespace ml_dtypes { + +/* + * Within-dtype copy strided loop: a plain memcpy per element. + * Registered as the within_dtype_castingimpl for every fixed-size dtype. + * MUST carry NPY_METH_SUPPORTS_UNALIGNED (NumPy enforces this for self-casts). + */ +static inline int TrivialStridedCopyLoop(PyArrayMethod_Context *context, + char *const data[], + npy_intp const dimensions[], + npy_intp const strides[], + NpyAuxData * /*auxdata*/) { + const npy_intp N = dimensions[0]; + const npy_intp elsize = context->descriptors[0]->elsize; + const char *in = data[0]; + char *out = data[1]; + for (npy_intp i = 0; i < N; ++i) { + std::memcpy(out, in, elsize); + in += strides[0]; + out += strides[1]; + } + return 0; +} + +/* + * PyArrayInitDTypeMeta_FromSpec_WithLegacy + * + * Initialises a new-style user DType (via PyArrayInitDTypeMeta_FromSpec) while + * also plumbing in legacy compatibility so that NumPy < 2.5 assigns a + * type_num, singleton, and legacy flag. + * + * Algorithm (for NumPy 2.0 – 2.4): + * + * Step 1 – Legacy registration (for type_num + singleton allocation): + * Temporarily replace proto->typeobj with &PyBaseObject_Type so that + * _PyArray_MapPyTypeToDType sees a non-generic type and hits the + * NPY_DT_is_legacy bail-out, meaning the auto-DTypeMeta created by + * PyArray_RegisterDataType is NOT inserted into the pytype-to-DType dict. + * + * Step 2 – New-style init: + * PyArrayInitDTypeMeta_FromSpec sets up slots, casts, and the + * pytype-to-DType mapping for the real scalar type with no conflict. 
+ * + * Step 3 – Swap: + * Steal type_num + singleton from the old legacy registration and point + * the singleton at the user's new DType. Set the legacy flag so NumPy + * uses legacy descriptor code paths where needed. + * + * If proto is NULL the function just forwards to PyArrayInitDTypeMeta_FromSpec. + */ +static inline int PyArrayInitDTypeMeta_FromSpec_WithLegacy( + PyArray_DTypeMeta *DType, PyArrayDTypeMeta_Spec *spec, + PyArray_DescrProto *proto) { + if (proto == nullptr) { + return PyArrayInitDTypeMeta_FromSpec(DType, spec); + } + + /* + * Step 1: Register old-style with a garbage typeobj so that + * _PyArray_MapPyTypeToDType does NOT add the auto-DTypeMeta to the + * pytype-to-DType dict (it bails out on NPY_DT_is_legacy for non-generic + * types), regardless of whether the real scalar subclasses np.generic. + */ + PyTypeObject *real_typeobj = proto->typeobj; + proto->typeobj = &PyBaseObject_Type; + + int typenum = PyArray_RegisterDataType(proto); + + proto->typeobj = real_typeobj; + if (typenum < 0) { + return -1; + } + + /* + * Step 2: Initialise the user's DType with new-style slots and casts. + * type_num stays at -1 / 0 for now; we fix it in step 3. + */ + if (PyArrayInitDTypeMeta_FromSpec(DType, spec) < 0) { + return -1; + } + + /* + * Step 3: Steal the singleton descriptor and type_num from the legacy + * registration. Point the descriptor's Python type at the user's DType + * and fix up its typeobj field (which we temporarily set to PyBaseObject_Type + * in step 1). + */ + PyArray_Descr *descr = PyArray_DescrFromType(typenum); + if (descr == nullptr) { + return -1; + } + + /* Save the auto-DTypeMeta so we can decref it after the swap. */ + PyObject *old_meta = reinterpret_cast(Py_TYPE(descr)); + + DType->type_num = typenum; + /* PyArray_DescrFromType returns a new reference; transfer ownership. */ + DType->singleton = descr; + /* Set the legacy flag (bit 0 == _NPY_DT_LEGACY_FLAG) so NumPy uses legacy + * code paths (copyswap, ArrFuncs, etc.) 
where the new-style API doesn't + * cover them yet. */ + DType->flags |= 1; + + /* Re-type the descriptor so it belongs to the user's DType class. */ + Py_INCREF(reinterpret_cast(DType)); + Py_SET_TYPE(descr, reinterpret_cast(DType)); + Py_DECREF(old_meta); + + /* Fix the descriptor's scalar-type field (it was set to PyBaseObject_Type + * in step 1 by PyArray_RegisterDataType copying proto->typeobj). */ + Py_INCREF(real_typeobj); + Py_XDECREF(descr->typeobj); + descr->typeobj = real_typeobj; + + /* + * Patch copyswap/copyswapn into the new DType's legacy f-slots. + * + * copyswap and copyswapn are disabled as public NPY_DT_PyArray_ArrFuncs_* + * spec slots (commented out in dtype_api.h), so + * dtypemeta_initialize_struct_from_spec leaves them as default_funcs stubs. + * PyArray_Scalar and other legacy paths call copyswap directly; fill it in. + */ + if (proto->f != nullptr) { + PyArray_ArrFuncs *f = _PyDataType_GetArrFuncs(descr); + f->copyswap = proto->f->copyswap; + f->copyswapn = proto->f->copyswapn; + } + + return 0; +} + +} // namespace ml_dtypes + +#endif // ML_DTYPES_DTYPE_COMPAT_H_ diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index 41e673dc..c915aced 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -789,6 +789,11 @@ static PyObject* NPyIntN_DTypeRepr(PyObject* /*self*/) { return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); } +template +static PyObject* NPyIntN_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template static PyObject* NPyIntN_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { return NPyIntN_GetItem(data, /*arr=*/nullptr); @@ -812,6 +817,60 @@ static PyArray_Descr* NPyIntN_DefaultDescr(PyArray_DTypeMeta* cls) { return cls->singleton; } +template +static PyArray_DTypeMeta* NPyIntN_CommonDType(PyArray_DTypeMeta* cls, + PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + 
// Python int/float defer to cls. NOTE(review): PyFloat else-if is dead code. + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) { + Py_INCREF(cls); + return cls; + } + else if (other == &PyArray_PyFloatDType) { + Py_INCREF(&PyArray_DoubleDType); + return &PyArray_DoubleDType; + } + else if (other == &PyArray_PyComplexDType) { + Py_INCREF(&PyArray_CDoubleDType); + return &PyArray_CDoubleDType; + } + + // Our intN types are smaller than every NumPy built-in except bool. + if (other->type_num == NPY_BOOL) { + Py_INCREF(cls); + return cls; + } + if (!PyTypeNum_ISUSERDEF(other->type_num)) { + Py_INCREF(other); + return other; + } + + // ---- Our own custom DTypes ---- + // Another custom int: lower type_num defers (swapping will work). + if (other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta) { + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); + } + // No cross-custom-int safe casts are registered; int16 contains all. + Py_INCREF(reinterpret_cast(&PyArray_Int16DType)); + return &PyArray_Int16DType; + } + + // Custom float or custom complex: swapping will work (NPyCustomFloat handles + // float+int, NPyCustomComplex handles complex+int). + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + template bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier.
Change to just pass @@ -867,7 +926,7 @@ bool RegisterIntNDtype(PyObject* numpy) { tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; tp->tp_repr = NPyIntN_DTypeRepr; - tp->tp_str = NPyIntN_DTypeRepr; + tp->tp_str = NPyIntN_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -901,6 +960,8 @@ bool RegisterIntNDtype(PyObject* numpy) { reinterpret_cast(NPyIntN_EnsureCanonical)}, {NPY_DT_default_descr, reinterpret_cast(NPyIntN_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyIntN_CommonDType)}, {NPY_DT_PyArray_ArrFuncs_getitem, reinterpret_cast(NPyIntN_GetItem)}, {NPY_DT_PyArray_ArrFuncs_setitem, diff --git a/ml_dtypes/tests/result_type_test.py b/ml_dtypes/tests/result_type_test.py new file mode 100644 index 00000000..f4deee79 --- /dev/null +++ b/ml_dtypes/tests/result_type_test.py @@ -0,0 +1,253 @@ +# Copyright 2026 The ml_dtypes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for np.result_type() across ml_dtypes custom DTypes.""" + +import ml_dtypes +import numpy as np +import pytest + + +bf16 = ml_dtypes.bfloat16 +f4 = ml_dtypes.float4_e2m1fn +f8_e4m3 = ml_dtypes.float8_e4m3fn +f8_e5m2 = ml_dtypes.float8_e5m2 +bc32 = ml_dtypes.bcomplex32 +c32 = ml_dtypes.complex32 +i4 = ml_dtypes.int4 +ui4 = ml_dtypes.uint4 +i2 = ml_dtypes.int2 + + +def rt(a, b): + return np.result_type(a, b) + + +# --------------------------------------------------------------------------- +# Custom float + NumPy built-in +# --------------------------------------------------------------------------- + +class TestCustomFloatVsNumpy: + + def test_float8_plus_float16_gives_float16(self): + assert rt(f8_e4m3, np.float16) == np.dtype(np.float16) + + def test_bfloat16_plus_float16_gives_float32(self): + # Neither fits the other: bfloat16 has 7 mantissa bits, float16 has wider + # exponent; float32 contains both. + assert rt(bf16, np.float16) == np.dtype(np.float32) + + def test_float8_plus_float32_gives_float32(self): + assert rt(f8_e4m3, np.float32) == np.dtype(np.float32) + + def test_float8_plus_float64_gives_float64(self): + assert rt(f8_e4m3, np.float64) == np.dtype(np.float64) + + def test_float8_plus_bool_gives_float8(self): + assert rt(f8_e4m3, np.bool_) == np.dtype(f8_e4m3) + + def test_bfloat16_plus_int8_gives_bfloat16(self): + # bfloat16 has enough precision to represent all int8 values. + assert rt(bf16, np.int8) == np.dtype(bf16) + + def test_float8_plus_int8_gives_float64(self): + # float8 cannot represent all int8 values; PyArray_CommonDType defers + # to the integer's required precision → float64. 
+ assert rt(f8_e4m3, np.int8) == np.dtype(np.float64) + + def test_float8_plus_complex64_gives_complex64(self): + assert rt(f8_e4m3, np.complex64) == np.dtype(np.complex64) + + +# --------------------------------------------------------------------------- +# Custom float + custom float +# --------------------------------------------------------------------------- + +class TestCustomFloatVsCustomFloat: + + def test_float8_fits_in_bfloat16(self): + # float8_e4m3fn has fewer bits than bfloat16 in every dimension. + assert rt(f8_e4m3, bf16) == np.dtype(bf16) + + def test_float8_e5m2_fits_in_bfloat16(self): + # float8_e5m2 has less precision but same exponent range; bfloat16 wins. + assert rt(f8_e5m2, bf16) == np.dtype(bf16) + + def test_float4_fits_in_float8(self): + # float4_e2m1fn has fewer exp and mantissa bits than float8_e4m3fn. + assert rt(f4, f8_e4m3) == np.dtype(f8_e4m3) + + def test_float8_e4m3_vs_float8_e5m2_gives_float32(self): + # Incomparable: e4m3 has more mantissa, e5m2 has more exponent → float32. 
+ assert rt(f8_e4m3, f8_e5m2) == np.dtype(np.float32) + + def test_same_type_gives_same_type(self): + assert rt(f8_e4m3, f8_e4m3) == np.dtype(f8_e4m3) + assert rt(bf16, bf16) == np.dtype(bf16) + + +# --------------------------------------------------------------------------- +# Custom float + custom int +# --------------------------------------------------------------------------- + +class TestCustomFloatVsCustomInt: + + def test_float_beats_int4(self): + assert rt(f8_e4m3, i4) == np.dtype(f8_e4m3) + assert rt(bf16, i4) == np.dtype(bf16) + + def test_symmetry(self): + assert rt(i4, f8_e4m3) == rt(f8_e4m3, i4) + assert rt(i4, bf16) == rt(bf16, i4) + + +# --------------------------------------------------------------------------- +# Custom int + NumPy built-in +# --------------------------------------------------------------------------- + +class TestCustomIntVsNumpy: + + def test_int4_plus_bool_gives_int4(self): + assert rt(i4, np.bool_) == np.dtype(i4) + + def test_int4_plus_int8_gives_int8(self): + assert rt(i4, np.int8) == np.dtype(np.int8) + + def test_int4_plus_int16_gives_int16(self): + assert rt(i4, np.int16) == np.dtype(np.int16) + + def test_int4_plus_float16_gives_float16(self): + assert rt(i4, np.float16) == np.dtype(np.float16) + + def test_int4_plus_float32_gives_float32(self): + assert rt(i4, np.float32) == np.dtype(np.float32) + + def test_int4_plus_complex64_gives_complex64(self): + assert rt(i4, np.complex64) == np.dtype(np.complex64) + + def test_symmetry(self): + for numpy_t in [np.int8, np.int16, np.float32, np.complex64]: + assert rt(i4, numpy_t) == rt(numpy_t, i4) + + +# --------------------------------------------------------------------------- +# Custom int + custom int +# --------------------------------------------------------------------------- + +class TestCustomIntVsCustomInt: + + def test_int4_plus_uint4_gives_int16(self): + # Signed + unsigned 4-bit: neither fits the other → int16. 
+ assert rt(i4, ui4) == np.dtype(np.int16) + + def test_int2_plus_int4_gives_int16(self): + # int2 < int4 by type_num; int4 handles and int2 fits in int4 → int16? + # Actually both are narrow custom ints; falls back to int16. + assert rt(ml_dtypes.int2, i4) == np.dtype(np.int16) + + def test_same_type_gives_same_type(self): + assert rt(i4, i4) == np.dtype(i4) + + +# --------------------------------------------------------------------------- +# Custom complex + NumPy built-in +# --------------------------------------------------------------------------- + +class TestCustomComplexVsNumpy: + + def test_bcomplex32_plus_bool_gives_cfloat(self): + assert rt(bc32, np.bool_) == np.dtype(np.complex64) + + def test_bcomplex32_plus_int8_gives_cfloat(self): + assert rt(bc32, np.int8) == np.dtype(np.complex64) + + def test_bcomplex32_plus_float16_gives_cfloat(self): + assert rt(bc32, np.float16) == np.dtype(np.complex64) + + def test_bcomplex32_plus_float32_gives_cfloat(self): + assert rt(bc32, np.float32) == np.dtype(np.complex64) + + def test_bcomplex32_plus_float64_gives_cdouble(self): + assert rt(bc32, np.float64) == np.dtype(np.complex128) + + def test_bcomplex32_plus_cfloat_gives_cfloat(self): + assert rt(bc32, np.complex64) == np.dtype(np.complex64) + + def test_bcomplex32_plus_cdouble_gives_cdouble(self): + assert rt(bc32, np.complex128) == np.dtype(np.complex128) + + def test_symmetry(self): + for numpy_t in [np.float32, np.float64, np.complex64, np.complex128]: + assert rt(bc32, numpy_t) == rt(numpy_t, bc32) + + +# --------------------------------------------------------------------------- +# Custom complex + custom float / int +# --------------------------------------------------------------------------- + +class TestCustomComplexVsCustom: + + def test_bcomplex32_plus_bfloat16_gives_cfloat(self): + assert rt(bc32, bf16) == np.dtype(np.complex64) + + def test_bcomplex32_plus_float8_gives_cfloat(self): + assert rt(bc32, f8_e4m3) == np.dtype(np.complex64) + + def 
test_bcomplex32_plus_int4_gives_cfloat(self): + assert rt(bc32, i4) == np.dtype(np.complex64) + + def test_bcomplex32_plus_complex32_gives_cfloat(self): + assert rt(bc32, c32) == np.dtype(np.complex64) + + def test_same_type_gives_same_type(self): + assert rt(bc32, bc32) == np.dtype(bc32) + assert rt(c32, c32) == np.dtype(c32) + + +# --------------------------------------------------------------------------- +# Python scalars (abstract types: 0, 0.0, 0.0j) +# --------------------------------------------------------------------------- + +class TestPythonScalars: + """Concrete custom DTypes should dominate abstract Python scalar types.""" + + @pytest.mark.parametrize("dtype", [f8_e4m3, bf16]) + def test_custom_float_dominates_python_int(self, dtype): + assert rt(dtype, 0) == np.dtype(dtype) + + @pytest.mark.parametrize("dtype", [f8_e4m3, bf16]) + def test_custom_float_dominates_python_float(self, dtype): + assert rt(dtype, 0.0) == np.dtype(dtype) + + @pytest.mark.parametrize("dtype", [f8_e4m3, bf16]) + def test_custom_float_plus_python_complex_gives_cfloat(self, dtype): + # Abstract complex + custom float → cfloat (smallest complex containing both). 
+ assert rt(dtype, 0.0j) == np.dtype(np.complex64) + + @pytest.mark.parametrize("dtype", [i4, ui4]) + def test_custom_int_dominates_python_int(self, dtype): + assert rt(dtype, 0) == np.dtype(dtype) + + @pytest.mark.parametrize("dtype", [i4, ui4]) + def test_custom_int_dominates_python_float(self, dtype): + assert rt(dtype, 0.0) == np.dtype(dtype) + + def test_custom_complex_dominates_python_int(self): + assert rt(bc32, 0) == np.dtype(bc32) + + def test_custom_complex_dominates_python_float(self): + assert rt(bc32, 0.0) == np.dtype(bc32) + + def test_custom_complex_dominates_python_complex(self): + assert rt(bc32, 0.0j) == np.dtype(bc32) From 956681b60f33c3745850e08e8c2ae8a2c7f75816 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 20 Apr 2026 13:42:57 +0200 Subject: [PATCH 3/6] parametrize test --- ml_dtypes/tests/result_type_test.py | 388 ++++++++++++++-------------- 1 file changed, 201 insertions(+), 187 deletions(-) diff --git a/ml_dtypes/tests/result_type_test.py b/ml_dtypes/tests/result_type_test.py index f4deee79..bdc9fab6 100644 --- a/ml_dtypes/tests/result_type_test.py +++ b/ml_dtypes/tests/result_type_test.py @@ -18,16 +18,29 @@ import numpy as np import pytest - +# Short aliases for readability in parametrize lists bf16 = ml_dtypes.bfloat16 -f4 = ml_dtypes.float4_e2m1fn -f8_e4m3 = ml_dtypes.float8_e4m3fn -f8_e5m2 = ml_dtypes.float8_e5m2 +f4 = ml_dtypes.float4_e2m1fn +f6_e2m3 = ml_dtypes.float6_e2m3fn +f6_e3m2 = ml_dtypes.float6_e3m2fn +f8_e3m4 = ml_dtypes.float8_e3m4 +f8_e4m3 = ml_dtypes.float8_e4m3 +f8_e4m3fn = ml_dtypes.float8_e4m3fn +f8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +f8_e4m3b11 = ml_dtypes.float8_e4m3b11fnuz +f8_e5m2 = ml_dtypes.float8_e5m2 +f8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +f8_e8m0 = ml_dtypes.float8_e8m0fnu bc32 = ml_dtypes.bcomplex32 -c32 = ml_dtypes.complex32 -i4 = ml_dtypes.int4 -ui4 = ml_dtypes.uint4 -i2 = ml_dtypes.int2 +c32 = ml_dtypes.complex32 +i1, i2, i4 = ml_dtypes.int1, ml_dtypes.int2, ml_dtypes.int4 +u1, u2, u4 = 
ml_dtypes.uint1, ml_dtypes.uint2, ml_dtypes.uint4 + +ALL_CUSTOM_FLOATS = [bf16, f4, f6_e2m3, f6_e3m2, + f8_e3m4, f8_e4m3, f8_e4m3fn, f8_e4m3fnuz, + f8_e4m3b11, f8_e5m2, f8_e5m2fnuz, f8_e8m0] +ALL_INTN = [i1, i2, i4, u1, u2, u4] +ALL_CUSTOM_COMPLEX = [bc32, c32] def rt(a, b): @@ -35,219 +48,220 @@ def rt(a, b): # --------------------------------------------------------------------------- -# Custom float + NumPy built-in +# Custom float vs NumPy built-in types # --------------------------------------------------------------------------- -class TestCustomFloatVsNumpy: - - def test_float8_plus_float16_gives_float16(self): - assert rt(f8_e4m3, np.float16) == np.dtype(np.float16) - - def test_bfloat16_plus_float16_gives_float32(self): - # Neither fits the other: bfloat16 has 7 mantissa bits, float16 has wider - # exponent; float32 contains both. - assert rt(bf16, np.float16) == np.dtype(np.float32) - - def test_float8_plus_float32_gives_float32(self): - assert rt(f8_e4m3, np.float32) == np.dtype(np.float32) - - def test_float8_plus_float64_gives_float64(self): - assert rt(f8_e4m3, np.float64) == np.dtype(np.float64) - - def test_float8_plus_bool_gives_float8(self): - assert rt(f8_e4m3, np.bool_) == np.dtype(f8_e4m3) - - def test_bfloat16_plus_int8_gives_bfloat16(self): - # bfloat16 has enough precision to represent all int8 values. - assert rt(bf16, np.int8) == np.dtype(bf16) - - def test_float8_plus_int8_gives_float64(self): - # float8 cannot represent all int8 values; PyArray_CommonDType defers - # to the integer's required precision → float64. 
- assert rt(f8_e4m3, np.int8) == np.dtype(np.float64) - - def test_float8_plus_complex64_gives_complex64(self): - assert rt(f8_e4m3, np.complex64) == np.dtype(np.complex64) +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool: custom float always wins ---- + (bf16, np.bool_, bf16), + (f8_e4m3fn, np.bool_, f8_e4m3fn), + (f4, np.bool_, f4), + # ---- floats: pick the wider ---- + (f4, np.float16, np.float16), # f4 fits in float16 + (f8_e4m3fn, np.float16, np.float16), # float8 fits in float16 + (f8_e5m2, np.float16, np.float16), # float8 fits in float16 + (bf16, np.float16, np.float32), # incomparable → float32 + (f8_e4m3fn, np.float32, np.float32), # all custom floats fit in float32 + (bf16, np.float32, np.float32), + (bf16, np.float64, np.float64), + (f8_e4m3fn, np.float64, np.float64), + # ---- integers: PyArray_CommonDType decides ---- + (bf16, np.int8, bf16), # bfloat16 has enough precision for int8 + (bf16, np.int16, bf16), # bfloat16 has enough precision for int16 + (f8_e4m3fn, np.int8, np.float64), # float8 can't represent all int8 values + (f8_e4m3fn, np.int32, np.float64), # float8 can't represent all int32 values + # ---- complex: other always wins ---- + (bf16, np.complex64, np.complex64), + (f8_e4m3fn, np.complex64, np.complex64), + (bf16, np.complex128, np.complex128), + (f8_e4m3fn, np.complex128, np.complex128), +]) +def test_custom_float_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric # --------------------------------------------------------------------------- -# Custom float + custom float +# Custom float vs custom float # --------------------------------------------------------------------------- -class TestCustomFloatVsCustomFloat: - - def test_float8_fits_in_bfloat16(self): - # float8_e4m3fn has fewer bits than bfloat16 in every dimension. 
- assert rt(f8_e4m3, bf16) == np.dtype(bf16) - - def test_float8_e5m2_fits_in_bfloat16(self): - # float8_e5m2 has less precision but same exponent range; bfloat16 wins. - assert rt(f8_e5m2, bf16) == np.dtype(bf16) - - def test_float4_fits_in_float8(self): - # float4_e2m1fn has fewer exp and mantissa bits than float8_e4m3fn. - assert rt(f4, f8_e4m3) == np.dtype(f8_e4m3) - - def test_float8_e4m3_vs_float8_e5m2_gives_float32(self): - # Incomparable: e4m3 has more mantissa, e5m2 has more exponent → float32. - assert rt(f8_e4m3, f8_e5m2) == np.dtype(np.float32) - - def test_same_type_gives_same_type(self): - assert rt(f8_e4m3, f8_e4m3) == np.dtype(f8_e4m3) - assert rt(bf16, bf16) == np.dtype(bf16) +@pytest.mark.parametrize("a, b, expected", [ + # ---- same type ---- + (bf16, bf16, bf16), + (f8_e4m3fn, f8_e4m3fn, f8_e4m3fn), + (f4, f4, f4), + # ---- narrower fits safely into wider ---- + (f4, f6_e2m3, f6_e2m3), # f4 ⊂ f6_e2m3 (more exp + mantissa) + (f4, f8_e4m3fn, f8_e4m3fn), # f4 fits in every float8+ + (f4, bf16, bf16), # f4 fits in bfloat16 + (f8_e4m3fn, bf16, bf16), # float8 fits in bfloat16 + (f8_e5m2, bf16, bf16), # float8 fits in bfloat16 + (f8_e3m4, bf16, bf16), # float8 fits in bfloat16 + # ---- incomparable: one has more exp, other more mantissa → float32 ---- + (bf16, f8_e5m2, bf16), # f8_e5m2 fits in bf16 (bf16 > in all dims) + (f8_e4m3fn, f8_e5m2, np.float32), # e4m3 has more mantissa, e5m2 has more exp + (f8_e4m3fn, f8_e4m3fnuz, np.float32), # same bits, different special-value encoding + (f6_e2m3, f6_e3m2, np.float32), # one has more mantissa, other more exp +]) +def test_custom_float_vs_custom_float(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric # --------------------------------------------------------------------------- -# Custom float + custom int +# Custom float vs custom int (float always dominates) # --------------------------------------------------------------------------- 
-class TestCustomFloatVsCustomInt: - - def test_float_beats_int4(self): - assert rt(f8_e4m3, i4) == np.dtype(f8_e4m3) - assert rt(bf16, i4) == np.dtype(bf16) - - def test_symmetry(self): - assert rt(i4, f8_e4m3) == rt(f8_e4m3, i4) - assert rt(i4, bf16) == rt(bf16, i4) +@pytest.mark.parametrize("float_t, int_t", [ + (bf16, i4), + (bf16, u4), + (bf16, i1), + (f8_e4m3fn, i4), + (f8_e4m3fn, u4), + (f8_e5m2, i2), + (f4, i1), +]) +def test_custom_float_beats_custom_int(float_t, int_t): + assert rt(float_t, int_t) == np.dtype(float_t) + assert rt(int_t, float_t) == np.dtype(float_t) # symmetric # --------------------------------------------------------------------------- -# Custom int + NumPy built-in +# Custom int vs NumPy built-in types # --------------------------------------------------------------------------- -class TestCustomIntVsNumpy: - - def test_int4_plus_bool_gives_int4(self): - assert rt(i4, np.bool_) == np.dtype(i4) - - def test_int4_plus_int8_gives_int8(self): - assert rt(i4, np.int8) == np.dtype(np.int8) - - def test_int4_plus_int16_gives_int16(self): - assert rt(i4, np.int16) == np.dtype(np.int16) - - def test_int4_plus_float16_gives_float16(self): - assert rt(i4, np.float16) == np.dtype(np.float16) - - def test_int4_plus_float32_gives_float32(self): - assert rt(i4, np.float32) == np.dtype(np.float32) - - def test_int4_plus_complex64_gives_complex64(self): - assert rt(i4, np.complex64) == np.dtype(np.complex64) - - def test_symmetry(self): - for numpy_t in [np.int8, np.int16, np.float32, np.complex64]: - assert rt(i4, numpy_t) == rt(numpy_t, i4) +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool: custom int always wins ---- + (i4, np.bool_, i4), + (u4, np.bool_, u4), + (i1, np.bool_, i1), + # ---- all other NumPy types: return other (intN is always smaller) ---- + (i4, np.int8, np.int8), + (i4, np.int16, np.int16), + (i4, np.int32, np.int32), + (i4, np.uint8, np.uint8), + (u4, np.int8, np.int8), + (i2, np.int8, np.int8), + (i4, np.float16, 
np.float16), + (i4, np.float32, np.float32), + (i4, np.float64, np.float64), + (i4, np.complex64, np.complex64), + (i4, np.complex128, np.complex128), + (u4, np.float32, np.float32), +]) +def test_custom_int_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric # --------------------------------------------------------------------------- -# Custom int + custom int +# Custom int vs custom int # --------------------------------------------------------------------------- -class TestCustomIntVsCustomInt: - - def test_int4_plus_uint4_gives_int16(self): - # Signed + unsigned 4-bit: neither fits the other → int16. - assert rt(i4, ui4) == np.dtype(np.int16) - - def test_int2_plus_int4_gives_int16(self): - # int2 < int4 by type_num; int4 handles and int2 fits in int4 → int16? - # Actually both are narrow custom ints; falls back to int16. - assert rt(ml_dtypes.int2, i4) == np.dtype(np.int16) - - def test_same_type_gives_same_type(self): - assert rt(i4, i4) == np.dtype(i4) +@pytest.mark.parametrize("a, b, expected", [ + # ---- same type ---- + (i4, i4, i4), + (u4, u4, u4), + # ---- mixed sign: neither fits the other → int16 ---- + (i4, u4, np.int16), + (i2, u2, np.int16), + (i1, u1, np.int16), + # ---- same sign, different width → int16 fallback ---- + (i2, i4, np.int16), + (u2, u4, np.int16), + (i1, i4, np.int16), +]) +def test_custom_int_vs_custom_int(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric # --------------------------------------------------------------------------- -# Custom complex + NumPy built-in +# Custom complex vs NumPy built-in types # --------------------------------------------------------------------------- -class TestCustomComplexVsNumpy: - - def test_bcomplex32_plus_bool_gives_cfloat(self): - assert rt(bc32, np.bool_) == np.dtype(np.complex64) - - def test_bcomplex32_plus_int8_gives_cfloat(self): - assert rt(bc32, 
np.int8) == np.dtype(np.complex64) - - def test_bcomplex32_plus_float16_gives_cfloat(self): - assert rt(bc32, np.float16) == np.dtype(np.complex64) - - def test_bcomplex32_plus_float32_gives_cfloat(self): - assert rt(bc32, np.float32) == np.dtype(np.complex64) - - def test_bcomplex32_plus_float64_gives_cdouble(self): - assert rt(bc32, np.float64) == np.dtype(np.complex128) - - def test_bcomplex32_plus_cfloat_gives_cfloat(self): - assert rt(bc32, np.complex64) == np.dtype(np.complex64) - - def test_bcomplex32_plus_cdouble_gives_cdouble(self): - assert rt(bc32, np.complex128) == np.dtype(np.complex128) - - def test_symmetry(self): - for numpy_t in [np.float32, np.float64, np.complex64, np.complex128]: - assert rt(bc32, numpy_t) == rt(numpy_t, bc32) +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool + integers: wrap in cfloat ---- + (bc32, np.bool_, np.complex64), + (bc32, np.int8, np.complex64), + (bc32, np.int32, np.complex64), + (c32, np.bool_, np.complex64), + (c32, np.int8, np.complex64), + # ---- floats ≤ float32: wrap in cfloat ---- + (bc32, np.float16, np.complex64), + (bc32, np.float32, np.complex64), + (c32, np.float16, np.complex64), + (c32, np.float32, np.complex64), + # ---- float64+: need cdouble ---- + (bc32, np.float64, np.complex128), + (bc32, np.longdouble, np.dtype("clongdouble")), + (c32, np.float64, np.complex128), + # ---- built-in complex: other always wins ---- + (bc32, np.complex64, np.complex64), + (bc32, np.complex128, np.complex128), + (c32, np.complex64, np.complex64), + (c32, np.complex128, np.complex128), +]) +def test_custom_complex_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric # --------------------------------------------------------------------------- -# Custom complex + custom float / int +# Custom complex vs custom float / custom int # --------------------------------------------------------------------------- -class TestCustomComplexVsCustom: 
- - def test_bcomplex32_plus_bfloat16_gives_cfloat(self): - assert rt(bc32, bf16) == np.dtype(np.complex64) - - def test_bcomplex32_plus_float8_gives_cfloat(self): - assert rt(bc32, f8_e4m3) == np.dtype(np.complex64) - - def test_bcomplex32_plus_int4_gives_cfloat(self): - assert rt(bc32, i4) == np.dtype(np.complex64) - - def test_bcomplex32_plus_complex32_gives_cfloat(self): - assert rt(bc32, c32) == np.dtype(np.complex64) - - def test_same_type_gives_same_type(self): - assert rt(bc32, bc32) == np.dtype(bc32) - assert rt(c32, c32) == np.dtype(c32) +@pytest.mark.parametrize("a, b, expected", [ + # ---- custom floats: all fit in cfloat alongside our complex ---- + (bc32, bf16, np.complex64), + (bc32, f8_e4m3fn, np.complex64), + (bc32, f8_e5m2, np.complex64), + (bc32, f4, np.complex64), + (c32, bf16, np.complex64), + (c32, f8_e4m3fn, np.complex64), + # ---- custom ints: all tiny, fit in cfloat ---- + (bc32, i4, np.complex64), + (bc32, u4, np.complex64), + (bc32, i1, np.complex64), + (c32, i4, np.complex64), + # ---- two custom complex types ---- + (bc32, c32, np.complex64), + (bc32, bc32, bc32), + (c32, c32, c32), +]) +def test_custom_complex_vs_custom(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric # --------------------------------------------------------------------------- -# Python scalars (abstract types: 0, 0.0, 0.0j) +# Python scalars: 0, 0.0, 0.0j (abstract types) # --------------------------------------------------------------------------- -class TestPythonScalars: - """Concrete custom DTypes should dominate abstract Python scalar types.""" - - @pytest.mark.parametrize("dtype", [f8_e4m3, bf16]) - def test_custom_float_dominates_python_int(self, dtype): - assert rt(dtype, 0) == np.dtype(dtype) - - @pytest.mark.parametrize("dtype", [f8_e4m3, bf16]) - def test_custom_float_dominates_python_float(self, dtype): - assert rt(dtype, 0.0) == np.dtype(dtype) - - @pytest.mark.parametrize("dtype", 
[f8_e4m3, bf16]) - def test_custom_float_plus_python_complex_gives_cfloat(self, dtype): - # Abstract complex + custom float → cfloat (smallest complex containing both). - assert rt(dtype, 0.0j) == np.dtype(np.complex64) - - @pytest.mark.parametrize("dtype", [i4, ui4]) - def test_custom_int_dominates_python_int(self, dtype): - assert rt(dtype, 0) == np.dtype(dtype) - - @pytest.mark.parametrize("dtype", [i4, ui4]) - def test_custom_int_dominates_python_float(self, dtype): - assert rt(dtype, 0.0) == np.dtype(dtype) - - def test_custom_complex_dominates_python_int(self): - assert rt(bc32, 0) == np.dtype(bc32) - - def test_custom_complex_dominates_python_float(self): - assert rt(bc32, 0.0) == np.dtype(bc32) - - def test_custom_complex_dominates_python_complex(self): - assert rt(bc32, 0.0j) == np.dtype(bc32) +@pytest.mark.parametrize("dtype, scalar, expected", [ + # ---- custom floats dominate Python int and Python float ---- + (bf16, 0, bf16), + (bf16, 0.0, bf16), + (f8_e4m3fn, 0, f8_e4m3fn), + (f8_e4m3fn, 0.0, f8_e4m3fn), + (f4, 0, f4), + (f4, 0.0, f4), + # ---- custom float + Python complex → cfloat ---- + (bf16, 0.0j, np.complex64), + (f8_e4m3fn, 0.0j, np.complex64), + (f4, 0.0j, np.complex64), + # ---- custom ints dominate Python int and Python float ---- + (i4, 0, i4), + (i4, 0.0, i4), + (u4, 0, u4), + (u4, 0.0, u4), + # ---- custom complex dominates all Python scalars ---- + (bc32, 0, bc32), + (bc32, 0.0, bc32), + (bc32, 0.0j, bc32), + (c32, 0, c32), + (c32, 0.0, c32), + (c32, 0.0j, c32), +]) +def test_python_scalars(dtype, scalar, expected): + assert rt(dtype, scalar) == np.dtype(expected) From eff09d7934f5a4ee996de182cdd83b8383f2b7eb Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 20 Apr 2026 14:42:42 +0200 Subject: [PATCH 4/6] Work around NumPy ABI bug and fixup tests (now tested to work with old NumPy versions) --- ml_dtypes/_src/custom_complex.h | 16 +++++++++------- ml_dtypes/_src/custom_float.h | 16 ++++++++-------- ml_dtypes/_src/dtype_compat.h | 
7 +++++++ ml_dtypes/_src/intn_numpy.h | 16 ++++++++-------- ml_dtypes/tests/result_type_test.py | 8 ++++---- 5 files changed, 36 insertions(+), 27 deletions(-) diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index 11775a9b..680b48a3 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -973,10 +973,12 @@ static PyArray_DTypeMeta* NPyCustomComplex_CommonDType( case NPY_HALF: case NPY_FLOAT: Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); return &PyArray_CFloatDType; - case NPY_DOUBLE: case NPY_LONGDOUBLE: + case NPY_DOUBLE: Py_INCREF(reinterpret_cast(&PyArray_CDoubleDType)); return &PyArray_CDoubleDType; - + case NPY_LONGDOUBLE: + Py_INCREF(reinterpret_cast(&PyArray_CLongDoubleDType)); + return &PyArray_CLongDoubleDType; // Built-in complex: our types are smaller, return other. case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: Py_INCREF(other); @@ -1116,15 +1118,15 @@ bool RegisterComplexDtype(PyObject* numpy) { reinterpret_cast(NPyCustomComplex_DefaultDescr)}, {NPY_DT_common_dtype, reinterpret_cast(NPyCustomComplex_CommonDType)}, - {NPY_DT_PyArray_ArrFuncs_getitem, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_getitem), reinterpret_cast(NPyCustomComplex_GetItem)}, - {NPY_DT_PyArray_ArrFuncs_setitem, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_setitem), reinterpret_cast(NPyCustomComplex_SetItem)}, - {NPY_DT_PyArray_ArrFuncs_nonzero, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_nonzero), reinterpret_cast(NPyCustomComplex_NonZero)}, - {NPY_DT_PyArray_ArrFuncs_dotfunc, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_dotfunc), reinterpret_cast(NPyCustomComplex_DotFunc)}, - {NPY_DT_PyArray_ArrFuncs_compare, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_compare), reinterpret_cast(NPyCustomComplex_CompareFunc)}, {0, nullptr}}; PyArrayDTypeMeta_Spec dtype_spec; diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 182d4e04..f64d80d3 100644 --- 
a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -1086,21 +1086,21 @@ bool RegisterFloatDtype(PyObject* numpy) { reinterpret_cast(NPyCustomFloat_DefaultDescr)}, {NPY_DT_common_dtype, reinterpret_cast(NPyCustomFloat_CommonDType)}, - {NPY_DT_PyArray_ArrFuncs_getitem, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_getitem), reinterpret_cast(NPyCustomFloat_GetItem)}, - {NPY_DT_PyArray_ArrFuncs_setitem, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_setitem), reinterpret_cast(NPyCustomFloat_SetItem)}, - {NPY_DT_PyArray_ArrFuncs_nonzero, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_nonzero), reinterpret_cast(NPyCustomFloat_NonZero)}, - {NPY_DT_PyArray_ArrFuncs_fill, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_fill), reinterpret_cast(NPyCustomFloat_Fill)}, - {NPY_DT_PyArray_ArrFuncs_dotfunc, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_dotfunc), reinterpret_cast(NPyCustomFloat_DotFunc)}, - {NPY_DT_PyArray_ArrFuncs_compare, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_compare), reinterpret_cast(NPyCustomFloat_CompareFunc)}, - {NPY_DT_PyArray_ArrFuncs_argmax, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmax), reinterpret_cast(NPyCustomFloat_ArgMaxFunc)}, - {NPY_DT_PyArray_ArrFuncs_argmin, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmin), reinterpret_cast(NPyCustomFloat_ArgMinFunc)}, {0, nullptr}}; PyArrayDTypeMeta_Spec dtype_spec; diff --git a/ml_dtypes/_src/dtype_compat.h b/ml_dtypes/_src/dtype_compat.h index 7a492940..7cb6af2f 100644 --- a/ml_dtypes/_src/dtype_compat.h +++ b/ml_dtypes/_src/dtype_compat.h @@ -41,6 +41,13 @@ limitations under the License. #error "ml_dtypes dtype_compat.h requires NumPy >= 2.0" #endif +#if NPY_TARGET_VERSION >= 0x15 // NUMPY_2_4_API_VERSION +#define ARRFUNCS_OFFSET_FIX(v) (v) +#else +#define ARRFUNCS_OFFSET_FIX(v) \ + (v) - (NPY_DT_PyArray_ArrFuncs_getitem) + 1 + (((PyArray_RUNTIME_VERSION >= 0x15) ? 
(1 << 11) : (1 << 10))) +#endif + namespace ml_dtypes { /* diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index c915aced..52dbfad4 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -962,21 +962,21 @@ bool RegisterIntNDtype(PyObject* numpy) { reinterpret_cast(NPyIntN_DefaultDescr)}, {NPY_DT_common_dtype, reinterpret_cast(NPyIntN_CommonDType)}, - {NPY_DT_PyArray_ArrFuncs_getitem, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_getitem), reinterpret_cast(NPyIntN_GetItem)}, - {NPY_DT_PyArray_ArrFuncs_setitem, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_setitem), reinterpret_cast(NPyIntN_SetItem)}, - {NPY_DT_PyArray_ArrFuncs_nonzero, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_nonzero), reinterpret_cast(NPyIntN_NonZero)}, - {NPY_DT_PyArray_ArrFuncs_fill, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_fill), reinterpret_cast(NPyIntN_Fill)}, - {NPY_DT_PyArray_ArrFuncs_dotfunc, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_dotfunc), reinterpret_cast(NPyIntN_DotFunc)}, - {NPY_DT_PyArray_ArrFuncs_compare, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_compare), reinterpret_cast(NPyIntN_CompareFunc)}, - {NPY_DT_PyArray_ArrFuncs_argmax, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmax), reinterpret_cast(NPyIntN_ArgMaxFunc)}, - {NPY_DT_PyArray_ArrFuncs_argmin, + {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmin), reinterpret_cast(NPyIntN_ArgMinFunc)}, {0, nullptr}}; PyArrayDTypeMeta_Spec dtype_spec; diff --git a/ml_dtypes/tests/result_type_test.py b/ml_dtypes/tests/result_type_test.py index bdc9fab6..8554eb77 100644 --- a/ml_dtypes/tests/result_type_test.py +++ b/ml_dtypes/tests/result_type_test.py @@ -66,8 +66,8 @@ def rt(a, b): (bf16, np.float64, np.float64), (f8_e4m3fn, np.float64, np.float64), # ---- integers: PyArray_CommonDType decides ---- - (bf16, np.int8, bf16), # bfloat16 has enough precision for int8 - (bf16, np.int16, bf16), # bfloat16 has enough precision for int16 + (bf16, np.int8, bf16), # bfloat16 
has 8 sig bits, int8 needs 7 → bf16 wins + (bf16, np.int16, np.float64), # bfloat16 has 8 sig bits, int16 needs 15 → float64 (f8_e4m3fn, np.int8, np.float64), # float8 can't represent all int8 values (f8_e4m3fn, np.int32, np.float64), # float8 can't represent all int32 values # ---- complex: other always wins ---- @@ -100,7 +100,7 @@ def test_custom_float_vs_numpy(a, b, expected): # ---- incomparable: one has more exp, other more mantissa → float32 ---- (bf16, f8_e5m2, bf16), # f8_e5m2 fits in bf16 (bf16 > in all dims) (f8_e4m3fn, f8_e5m2, np.float32), # e4m3 has more mantissa, e5m2 has more exp - (f8_e4m3fn, f8_e4m3fnuz, np.float32), # same bits, different special-value encoding + (f8_e4m3fn, f8_e4m3fnuz, f8_e4m3fn), # same digits/max_exp → numeric_limits match; fn wins (f6_e2m3, f6_e3m2, np.float32), # one has more mantissa, other more exp ]) def test_custom_float_vs_custom_float(a, b, expected): @@ -194,7 +194,7 @@ def test_custom_int_vs_custom_int(a, b, expected): (c32, np.float32, np.complex64), # ---- float64+: need cdouble ---- (bc32, np.float64, np.complex128), - (bc32, np.longdouble, np.dtype("clongdouble")), + (bc32, np.longdouble, np.clongdouble), (c32, np.float64, np.complex128), # ---- built-in complex: other always wins ---- (bc32, np.complex64, np.complex64), From 3514afc1d318641acb57c3def5effe076b8c4631 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 21 Apr 2026 20:57:09 +0200 Subject: [PATCH 5/6] simplify --- ml_dtypes/_src/custom_complex.h | 16 ++------------- ml_dtypes/_src/custom_float.h | 18 ++--------------- ml_dtypes/_src/dtype_compat.h | 36 ++++++++++++++++++++------------- ml_dtypes/_src/intn_numpy.h | 16 ++------------- 4 files changed, 28 insertions(+), 58 deletions(-) diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index 680b48a3..d901b59b 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -914,16 +914,6 @@ bool RegisterComplexUFuncs(PyObject* numpy) { // 
New-style DType slot functions for CustomComplex types // --------------------------------------------------------------------------- -template -static PyObject* NPyCustomComplex_DTypeRepr(PyObject* /*self*/) { - return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); -} - -template -static PyObject* NPyCustomComplex_DTypeStr(PyObject* /*self*/) { - return PyUnicode_FromString(TypeDescriptor::kTypeName); -} - template static PyObject* NPyCustomComplex_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { @@ -1081,8 +1071,6 @@ bool RegisterComplexDtype(PyObject* numpy) { tp->tp_name = TypeDescriptor::kTypeName; tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; - tp->tp_repr = NPyCustomComplex_DTypeRepr; - tp->tp_str = NPyCustomComplex_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -1108,6 +1096,7 @@ bool RegisterComplexDtype(PyObject* numpy) { // Build the new-style DType spec. PyType_Slot dtype_slots[] = { + {300, reinterpret_cast(&descr_proto)}, {NPY_DT_getitem, reinterpret_cast(NPyCustomComplex_NewStyleGetItem)}, {NPY_DT_setitem, @@ -1136,8 +1125,7 @@ bool RegisterComplexDtype(PyObject* numpy) { dtype_spec.slots = dtype_slots; dtype_spec.baseclass = nullptr; - if (PyArrayInitDTypeMeta_FromSpec_WithLegacy(&dm, &dtype_spec, - &descr_proto) < 0) { + if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) { return false; } TypeDescriptor::npy_type = dm.type_num; diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index f64d80d3..33b51880 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -852,18 +852,6 @@ bool RegisterFloatUFuncs(PyObject* numpy) { // New-style DType slot functions for CustomFloat types // --------------------------------------------------------------------------- -// tp_repr / tp_str for the DTypeMeta itself (required by -// PyArrayInitDTypeMeta_FromSpec; must differ from PyArrayDescr_Type's). 
-template -static PyObject* NPyCustomFloat_DTypeRepr(PyObject* /*self*/) { - return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); -} - -template -static PyObject* NPyCustomFloat_DTypeStr(PyObject* /*self*/) { - return PyUnicode_FromString(TypeDescriptor::kTypeName); -} - // New-style getitem: (PyArray_Descr*, char*) -> PyObject* template static PyObject* NPyCustomFloat_NewStyleGetItem(PyArray_Descr* /*descr*/, @@ -1049,8 +1037,6 @@ bool RegisterFloatDtype(PyObject* numpy) { tp->tp_name = TypeDescriptor::kTypeName; tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; - tp->tp_repr = NPyCustomFloat_DTypeRepr; - tp->tp_str = NPyCustomFloat_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -1076,6 +1062,7 @@ bool RegisterFloatDtype(PyObject* numpy) { // Build the new-style DType spec. PyType_Slot dtype_slots[] = { + {300, reinterpret_cast(&descr_proto)}, {NPY_DT_getitem, reinterpret_cast(NPyCustomFloat_NewStyleGetItem)}, {NPY_DT_setitem, @@ -1110,8 +1097,7 @@ bool RegisterFloatDtype(PyObject* numpy) { dtype_spec.slots = dtype_slots; dtype_spec.baseclass = nullptr; - if (PyArrayInitDTypeMeta_FromSpec_WithLegacy(&dm, &dtype_spec, - &descr_proto) < 0) { + if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) { return false; } TypeDescriptor::npy_type = dm.type_num; diff --git a/ml_dtypes/_src/dtype_compat.h b/ml_dtypes/_src/dtype_compat.h index 7cb6af2f..84402e37 100644 --- a/ml_dtypes/_src/dtype_compat.h +++ b/ml_dtypes/_src/dtype_compat.h @@ -98,11 +98,19 @@ static inline int TrivialStridedCopyLoop(PyArrayMethod_Context *context, * * If proto is NULL the function just forwards to PyArrayInitDTypeMeta_FromSpec. 
*/ -static inline int PyArrayInitDTypeMeta_FromSpec_WithLegacy( - PyArray_DTypeMeta *DType, PyArrayDTypeMeta_Spec *spec, - PyArray_DescrProto *proto) { - if (proto == nullptr) { - return PyArrayInitDTypeMeta_FromSpec(DType, spec); +#if NPY_TARGET_VERSION < 0x16 && NPY_TARGET_VERSION >=0x00000012 +#define _PyArrayInitDTypeMeta_FromSpec \ + (*(int (*)(PyArray_DTypeMeta *, PyArrayDTypeMeta_Spec *))PyArray_API[362]) +#undef PyArrayInitDTypeMeta_FromSpec + +static inline int PyArrayInitDTypeMeta_FromSpec( + PyArray_DTypeMeta *DType, PyArrayDTypeMeta_Spec *spec) { + PyArray_DescrProto *proto = nullptr; + if (spec->slots[0].slot == 300) { + proto = reinterpret_cast(spec->slots[0].pfunc); + } + if (proto == nullptr || PyArray_RUNTIME_VERSION >= 0x16) { + return _PyArrayInitDTypeMeta_FromSpec(DType, spec); } /* @@ -111,12 +119,9 @@ static inline int PyArrayInitDTypeMeta_FromSpec_WithLegacy( * pytype-to-DType dict (it bails out on NPY_DT_is_legacy for non-generic * types), regardless of whether the real scalar subclasses np.generic. */ - PyTypeObject *real_typeobj = proto->typeobj; - proto->typeobj = &PyBaseObject_Type; - - int typenum = PyArray_RegisterDataType(proto); - - proto->typeobj = real_typeobj; + PyArray_DescrProto new_proto = *proto; + new_proto.typeobj = &PyBaseObject_Type; + int typenum = PyArray_RegisterDataType(&new_proto); if (typenum < 0) { return -1; } @@ -125,7 +130,9 @@ static inline int PyArrayInitDTypeMeta_FromSpec_WithLegacy( * Step 2: Initialise the user's DType with new-style slots and casts. * type_num stays at -1 / 0 for now; we fix it in step 3. */ - if (PyArrayInitDTypeMeta_FromSpec(DType, spec) < 0) { + PyArrayDTypeMeta_Spec new_spec = *spec; + new_spec.slots = &spec->slots[1]; // skip proto slot. 
+ if (_PyArrayInitDTypeMeta_FromSpec(DType, &new_spec) < 0) { return -1; } @@ -158,9 +165,9 @@ static inline int PyArrayInitDTypeMeta_FromSpec_WithLegacy( /* Fix the descriptor's scalar-type field (it was set to PyBaseObject_Type * in step 1 by PyArray_RegisterDataType copying proto->typeobj). */ - Py_INCREF(real_typeobj); + Py_INCREF(proto->typeobj); Py_XDECREF(descr->typeobj); - descr->typeobj = real_typeobj; + descr->typeobj = proto->typeobj; /* * Patch copyswap/copyswapn into the new DType's legacy f-slots. @@ -178,6 +185,7 @@ static inline int PyArrayInitDTypeMeta_FromSpec_WithLegacy( return 0; } +#endif } // namespace ml_dtypes diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index 52dbfad4..081f88d8 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -784,16 +784,6 @@ bool RegisterIntNUFuncs(PyObject* numpy) { // New-style DType slot functions for IntN types // --------------------------------------------------------------------------- -template -static PyObject* NPyIntN_DTypeRepr(PyObject* /*self*/) { - return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); -} - -template -static PyObject* NPyIntN_DTypeStr(PyObject* /*self*/) { - return PyUnicode_FromString(TypeDescriptor::kTypeName); -} - template static PyObject* NPyIntN_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { return NPyIntN_GetItem(data, /*arr=*/nullptr); @@ -925,8 +915,6 @@ bool RegisterIntNDtype(PyObject* numpy) { tp->tp_name = TypeDescriptor::kTypeName; tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; - tp->tp_repr = NPyIntN_DTypeRepr; - tp->tp_str = NPyIntN_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -952,6 +940,7 @@ bool RegisterIntNDtype(PyObject* numpy) { // Build the new-style DType spec. 
PyType_Slot dtype_slots[] = { + {300, reinterpret_cast(&descr_proto)}, {NPY_DT_getitem, reinterpret_cast(NPyIntN_NewStyleGetItem)}, {NPY_DT_setitem, @@ -986,8 +975,7 @@ bool RegisterIntNDtype(PyObject* numpy) { dtype_spec.slots = dtype_slots; dtype_spec.baseclass = nullptr; - if (PyArrayInitDTypeMeta_FromSpec_WithLegacy(&dm, &dtype_spec, - &descr_proto) < 0) { + if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) { return false; } TypeDescriptor::npy_type = dm.type_num; From c67581e3f33ed63abff12fac9be13926c79089d7 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Wed, 22 Apr 2026 13:29:58 +0200 Subject: [PATCH 6/6] Version based on compiling with NumPy 2.5 (after merging PR there) --- ml_dtypes/_src/custom_complex.h | 25 +++++-- ml_dtypes/_src/custom_float.h | 33 +++++---- ml_dtypes/_src/dtype_compat.h | 122 -------------------------------- ml_dtypes/_src/intn_numpy.h | 31 +++++--- 4 files changed, 60 insertions(+), 151 deletions(-) diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index d901b59b..5b6b85a9 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -1018,6 +1018,15 @@ static PyArray_DTypeMeta* NPyCustomComplex_CommonDType( return reinterpret_cast(Py_NotImplemented); } +template +static PyObject* NPyCustomComplex_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} +template +static PyObject* NPyCustomComplex_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -1064,13 +1073,14 @@ bool RegisterComplexDtype(PyObject* numpy) { descr_proto.typeobj = reinterpret_cast(type); descr_proto.f = &arr_funcs; - // Set up the DTypeMeta. 
PyArray_DTypeMeta& dm = CustomComplexType::dtype_meta; Py_SET_REFCNT(&dm, 1); auto* tp = reinterpret_cast(&dm); tp->tp_name = TypeDescriptor::kTypeName; tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyCustomComplex_DTypeRepr; + tp->tp_str = NPyCustomComplex_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -1096,7 +1106,8 @@ bool RegisterComplexDtype(PyObject* numpy) { // Build the new-style DType spec. PyType_Slot dtype_slots[] = { - {300, reinterpret_cast(&descr_proto)}, + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, {NPY_DT_getitem, reinterpret_cast(NPyCustomComplex_NewStyleGetItem)}, {NPY_DT_setitem, @@ -1107,15 +1118,15 @@ bool RegisterComplexDtype(PyObject* numpy) { reinterpret_cast(NPyCustomComplex_DefaultDescr)}, {NPY_DT_common_dtype, reinterpret_cast(NPyCustomComplex_CommonDType)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_getitem), + {NPY_DT_PyArray_ArrFuncs_getitem, reinterpret_cast(NPyCustomComplex_GetItem)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_setitem), + {NPY_DT_PyArray_ArrFuncs_setitem, reinterpret_cast(NPyCustomComplex_SetItem)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_nonzero), + {NPY_DT_PyArray_ArrFuncs_nonzero, reinterpret_cast(NPyCustomComplex_NonZero)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_dotfunc), + {NPY_DT_PyArray_ArrFuncs_dotfunc, reinterpret_cast(NPyCustomComplex_DotFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_compare), + {NPY_DT_PyArray_ArrFuncs_compare, reinterpret_cast(NPyCustomComplex_CompareFunc)}, {0, nullptr}}; PyArrayDTypeMeta_Spec dtype_spec; diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 33b51880..75128d08 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -982,6 +982,15 @@ static PyArray_DTypeMeta* NPyCustomFloat_CommonDType(PyArray_DTypeMeta* cls, return reinterpret_cast(Py_NotImplemented); } +template +static PyObject* 
NPyCustomFloat_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} +template +static PyObject* NPyCustomFloat_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template bool RegisterFloatDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -1028,15 +1037,14 @@ bool RegisterFloatDtype(PyObject* numpy) { descr_proto.typeobj = reinterpret_cast(type); descr_proto.f = &arr_funcs; - // Set up the DTypeMeta. It must subclass PyArrayDescr_Type and have a - // metaclass of PyArrayDTypeMeta_Type (inherited via PyType_Ready from - // PyArrayDescr_Type.ob_type). PyArray_DTypeMeta& dm = CustomFloatType::dtype_meta; Py_SET_REFCNT(&dm, 1); auto* tp = reinterpret_cast(&dm); tp->tp_name = TypeDescriptor::kTypeName; tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyCustomFloat_DTypeRepr; + tp->tp_str = NPyCustomFloat_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -1062,7 +1070,8 @@ bool RegisterFloatDtype(PyObject* numpy) { // Build the new-style DType spec. 
PyType_Slot dtype_slots[] = { - {300, reinterpret_cast(&descr_proto)}, + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, {NPY_DT_getitem, reinterpret_cast(NPyCustomFloat_NewStyleGetItem)}, {NPY_DT_setitem, @@ -1073,21 +1082,21 @@ bool RegisterFloatDtype(PyObject* numpy) { reinterpret_cast(NPyCustomFloat_DefaultDescr)}, {NPY_DT_common_dtype, reinterpret_cast(NPyCustomFloat_CommonDType)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_getitem), + {NPY_DT_PyArray_ArrFuncs_getitem, reinterpret_cast(NPyCustomFloat_GetItem)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_setitem), + {NPY_DT_PyArray_ArrFuncs_setitem, reinterpret_cast(NPyCustomFloat_SetItem)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_nonzero), + {NPY_DT_PyArray_ArrFuncs_nonzero, reinterpret_cast(NPyCustomFloat_NonZero)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_fill), + {NPY_DT_PyArray_ArrFuncs_fill, reinterpret_cast(NPyCustomFloat_Fill)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_dotfunc), + {NPY_DT_PyArray_ArrFuncs_dotfunc, reinterpret_cast(NPyCustomFloat_DotFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_compare), + {NPY_DT_PyArray_ArrFuncs_compare, reinterpret_cast(NPyCustomFloat_CompareFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmax), + {NPY_DT_PyArray_ArrFuncs_argmax, reinterpret_cast(NPyCustomFloat_ArgMaxFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmin), + {NPY_DT_PyArray_ArrFuncs_argmin, reinterpret_cast(NPyCustomFloat_ArgMinFunc)}, {0, nullptr}}; PyArrayDTypeMeta_Spec dtype_spec; diff --git a/ml_dtypes/_src/dtype_compat.h b/ml_dtypes/_src/dtype_compat.h index 84402e37..a5d51348 100644 --- a/ml_dtypes/_src/dtype_compat.h +++ b/ml_dtypes/_src/dtype_compat.h @@ -41,13 +41,6 @@ limitations under the License. 
#error "ml_dtypes dtype_compat.h requires NumPy >= 2.0" #endif -#if NPY_TARGET_VERSION >= 0x15 // NUMPY_2_4_API_VERSION -#define ARRFUNCS_OFFSET_FIX(v) (v) -#else -#define ARRFUNCS_OFFSET_FIX(v) \ - (v) - (NPY_DT_PyArray_ArrFuncs_getitem) + 1 + (((PyArray_RUNTIME_VERSION >= 0x15) ? (1 << 11) : (1 << 10))) -#endif - namespace ml_dtypes { /* @@ -72,121 +65,6 @@ static inline int TrivialStridedCopyLoop(PyArrayMethod_Context *context, return 0; } -/* - * PyArrayInitDTypeMeta_FromSpec_WithLegacy - * - * Initialises a new-style user DType (via PyArrayInitDTypeMeta_FromSpec) while - * also plumbing in legacy compatibility so that NumPy < 2.5 assigns a - * type_num, singleton, and legacy flag. - * - * Algorithm (for NumPy 2.0 – 2.4): - * - * Step 1 – Legacy registration (for type_num + singleton allocation): - * Temporarily replace proto->typeobj with &PyBaseObject_Type so that - * _PyArray_MapPyTypeToDType sees a non-generic type and hits the - * NPY_DT_is_legacy bail-out, meaning the auto-DTypeMeta created by - * PyArray_RegisterDataType is NOT inserted into the pytype-to-DType dict. - * - * Step 2 – New-style init: - * PyArrayInitDTypeMeta_FromSpec sets up slots, casts, and the - * pytype-to-DType mapping for the real scalar type with no conflict. - * - * Step 3 – Swap: - * Steal type_num + singleton from the old legacy registration and point - * the singleton at the user's new DType. Set the legacy flag so NumPy - * uses legacy descriptor code paths where needed. - * - * If proto is NULL the function just forwards to PyArrayInitDTypeMeta_FromSpec. 
- */ -#if NPY_TARGET_VERSION < 0x16 && NPY_TARGET_VERSION >=0x00000012 -#define _PyArrayInitDTypeMeta_FromSpec \ - (*(int (*)(PyArray_DTypeMeta *, PyArrayDTypeMeta_Spec *))PyArray_API[362]) -#undef PyArrayInitDTypeMeta_FromSpec - -static inline int PyArrayInitDTypeMeta_FromSpec( - PyArray_DTypeMeta *DType, PyArrayDTypeMeta_Spec *spec) { - PyArray_DescrProto *proto = nullptr; - if (spec->slots[0].slot == 300) { - proto = reinterpret_cast(spec->slots[0].pfunc); - } - if (proto == nullptr || PyArray_RUNTIME_VERSION >= 0x16) { - return _PyArrayInitDTypeMeta_FromSpec(DType, spec); - } - - /* - * Step 1: Register old-style with a garbage typeobj so that - * _PyArray_MapPyTypeToDType does NOT add the auto-DTypeMeta to the - * pytype-to-DType dict (it bails out on NPY_DT_is_legacy for non-generic - * types), regardless of whether the real scalar subclasses np.generic. - */ - PyArray_DescrProto new_proto = *proto; - new_proto.typeobj = &PyBaseObject_Type; - int typenum = PyArray_RegisterDataType(&new_proto); - if (typenum < 0) { - return -1; - } - - /* - * Step 2: Initialise the user's DType with new-style slots and casts. - * type_num stays at -1 / 0 for now; we fix it in step 3. - */ - PyArrayDTypeMeta_Spec new_spec = *spec; - new_spec.slots = &spec->slots[1]; // skip proto slot. - if (_PyArrayInitDTypeMeta_FromSpec(DType, &new_spec) < 0) { - return -1; - } - - /* - * Step 3: Steal the singleton descriptor and type_num from the legacy - * registration. Point the descriptor's Python type at the user's DType - * and fix up its typeobj field (which we temporarily set to PyBaseObject_Type - * in step 1). - */ - PyArray_Descr *descr = PyArray_DescrFromType(typenum); - if (descr == nullptr) { - return -1; - } - - /* Save the auto-DTypeMeta so we can decref it after the swap. */ - PyObject *old_meta = reinterpret_cast(Py_TYPE(descr)); - - DType->type_num = typenum; - /* PyArray_DescrFromType returns a new reference; transfer ownership. 
*/ - DType->singleton = descr; - /* Set the legacy flag (bit 0 == _NPY_DT_LEGACY_FLAG) so NumPy uses legacy - * code paths (copyswap, ArrFuncs, etc.) where the new-style API doesn't - * cover them yet. */ - DType->flags |= 1; - - /* Re-type the descriptor so it belongs to the user's DType class. */ - Py_INCREF(reinterpret_cast(DType)); - Py_SET_TYPE(descr, reinterpret_cast(DType)); - Py_DECREF(old_meta); - - /* Fix the descriptor's scalar-type field (it was set to PyBaseObject_Type - * in step 1 by PyArray_RegisterDataType copying proto->typeobj). */ - Py_INCREF(proto->typeobj); - Py_XDECREF(descr->typeobj); - descr->typeobj = proto->typeobj; - - /* - * Patch copyswap/copyswapn into the new DType's legacy f-slots. - * - * copyswap and copyswapn are disabled as public NPY_DT_PyArray_ArrFuncs_* - * spec slots (commented out in dtype_api.h), so dtypemeta_initialize_struct_ - * from_spec leaves them as stubs from default_funcs. PyArray_Scalar and - * other legacy paths call copyswap directly, so we must fill it in. - */ - if (proto->f != nullptr) { - PyArray_ArrFuncs *f = _PyDataType_GetArrFuncs(descr); - f->copyswap = proto->f->copyswap; - f->copyswapn = proto->f->copyswapn; - } - - return 0; -} -#endif - } // namespace ml_dtypes #endif // ML_DTYPES_DTYPE_COMPAT_H_ diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index 081f88d8..80f59c52 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -861,6 +861,15 @@ static PyArray_DTypeMeta* NPyIntN_CommonDType(PyArray_DTypeMeta* cls, return reinterpret_cast(Py_NotImplemented); } +template +static PyObject* NPyIntN_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} +template +static PyObject* NPyIntN_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. 
Change to just pass @@ -908,13 +917,14 @@ bool RegisterIntNDtype(PyObject* numpy) { descr_proto.typeobj = reinterpret_cast(type); descr_proto.f = &arr_funcs; - // Set up the DTypeMeta. PyArray_DTypeMeta& dm = IntNTypeDescriptor::dtype_meta; Py_SET_REFCNT(&dm, 1); auto* tp = reinterpret_cast(&dm); tp->tp_name = TypeDescriptor::kTypeName; tp->tp_base = &PyArrayDescr_Type; tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyIntN_DTypeRepr; + tp->tp_str = NPyIntN_DTypeStr; if (PyType_Ready(tp) < 0) { return false; } @@ -940,7 +950,8 @@ bool RegisterIntNDtype(PyObject* numpy) { // Build the new-style DType spec. PyType_Slot dtype_slots[] = { - {300, reinterpret_cast(&descr_proto)}, + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, {NPY_DT_getitem, reinterpret_cast(NPyIntN_NewStyleGetItem)}, {NPY_DT_setitem, @@ -951,21 +962,21 @@ bool RegisterIntNDtype(PyObject* numpy) { reinterpret_cast(NPyIntN_DefaultDescr)}, {NPY_DT_common_dtype, reinterpret_cast(NPyIntN_CommonDType)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_getitem), + {NPY_DT_PyArray_ArrFuncs_getitem, reinterpret_cast(NPyIntN_GetItem)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_setitem), + {NPY_DT_PyArray_ArrFuncs_setitem, reinterpret_cast(NPyIntN_SetItem)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_nonzero), + {NPY_DT_PyArray_ArrFuncs_nonzero, reinterpret_cast(NPyIntN_NonZero)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_fill), + {NPY_DT_PyArray_ArrFuncs_fill, reinterpret_cast(NPyIntN_Fill)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_dotfunc), + {NPY_DT_PyArray_ArrFuncs_dotfunc, reinterpret_cast(NPyIntN_DotFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_compare), + {NPY_DT_PyArray_ArrFuncs_compare, reinterpret_cast(NPyIntN_CompareFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmax), + {NPY_DT_PyArray_ArrFuncs_argmax, reinterpret_cast(NPyIntN_ArgMaxFunc)}, - {ARRFUNCS_OFFSET_FIX(NPY_DT_PyArray_ArrFuncs_argmin), + {NPY_DT_PyArray_ArrFuncs_argmin, 
reinterpret_cast(NPyIntN_ArgMinFunc)}, {0, nullptr}}; PyArrayDTypeMeta_Spec dtype_spec;