diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index c640aee7..5b6b85a9 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -35,8 +35,9 @@ limitations under the License. #include #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_compat.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/complex_types.h" #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build @@ -69,6 +70,9 @@ struct CustomComplexType { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. + static PyArray_DTypeMeta dtype_meta; }; template @@ -79,6 +83,8 @@ template PyArray_DescrProto CustomComplexType::npy_descr_proto; template PyArray_Descr* CustomComplexType::npy_descr = nullptr; +template +PyArray_DTypeMeta CustomComplexType::dtype_meta = {}; // Representation of a Python custom float object. template @@ -904,6 +910,123 @@ bool RegisterComplexUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for CustomComplex types +// --------------------------------------------------------------------------- + +template +static PyObject* NPyCustomComplex_NewStyleGetItem(PyArray_Descr* /*descr*/, + char* data) { + return NPyCustomComplex_GetItem(data, /*arr=*/nullptr); +} + +template +static int NPyCustomComplex_NewStyleSetItem(PyArray_Descr* /*descr*/, + PyObject* item, char* data) { + return NPyCustomComplex_SetItem(item, data, /*arr=*/nullptr); +} + +template +static PyArray_Descr* NPyCustomComplex_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +template +static PyArray_Descr* NPyCustomComplex_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + +template +static PyArray_DTypeMeta* NPyCustomComplex_CommonDType( + PyArray_DTypeMeta* cls, PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType || + other == &PyArray_PyComplexDType) { + Py_INCREF(cls); + return cls; + } + + switch (other->type_num) { + // bool, ints, half, float: wrap in the smallest complex that holds both. + // Our custom complex types all fit in cfloat. + case NPY_BOOL: + case NPY_BYTE: case NPY_SHORT: case NPY_INT: + case NPY_LONG: case NPY_LONGLONG: + case NPY_UBYTE: case NPY_USHORT: case NPY_UINT: + case NPY_ULONG: case NPY_ULONGLONG: + case NPY_HALF: case NPY_FLOAT: + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + case NPY_DOUBLE: + Py_INCREF(reinterpret_cast(&PyArray_CDoubleDType)); + return &PyArray_CDoubleDType; + case NPY_LONGDOUBLE: + Py_INCREF(reinterpret_cast(&PyArray_CLongDoubleDType)); + return &PyArray_CLongDoubleDType; + // Built-in complex: our types are smaller, return other. + case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: + Py_INCREF(other); + return other; + + default: + break; + } + + // ---- Our own custom DTypes ---- + // Custom float or custom int: all fit in cfloat alongside our complex. + if (other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &CustomFloatType::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta) { + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + } + + // Another custom complex: both fit in cfloat. + if (other == &CustomComplexType::dtype_meta || + other == &CustomComplexType::dtype_meta) { + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); + } + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + } + + // Unknown user type: return NotImplemented. + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + +template +static PyObject* NPyCustomComplex_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} +template +static PyObject* NPyCustomComplex_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -928,7 +1051,7 @@ bool RegisterComplexDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = CustomComplexType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomComplex_GetItem; @@ -937,29 +1060,87 @@ bool RegisterComplexDtype(PyObject* numpy) { arr_funcs.copyswapn = NPyCustomComplex_CopySwapN; arr_funcs.copyswap = NPyCustomComplex_CopySwap; arr_funcs.nonzero = NPyCustomComplex_NonZero; - arr_funcs.fill = nullptr; // NPyCustomComplex_Fill; + arr_funcs.fill = nullptr; arr_funcs.dotfunc = NPyCustomComplex_DotFunc; arr_funcs.compare = NPyCustomComplex_CompareFunc; - arr_funcs.argmax = nullptr; // NumPy defines them, but it's shaky + arr_funcs.argmax = nullptr; arr_funcs.argmin = nullptr; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy proto. PyArray_DescrProto& descr_proto = CustomComplexType::npy_descr_proto; descr_proto = GetCustomComplexDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + PyArray_DTypeMeta& dm = CustomComplexType::dtype_meta; + Py_SET_REFCNT(&dm, 1); + auto* tp = reinterpret_cast(&dm); + tp->tp_name = TypeDescriptor::kTypeName; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyCustomComplex_DTypeRepr; + tp->tp_str = NPyCustomComplex_DTypeStr; + if (PyType_Ready(tp) < 0) { + return false; + } - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + // Build the within-dtype self-cast spec. + PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {NPY_METH_unaligned_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "copy"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + // Build the new-style DType spec. + PyType_Slot dtype_slots[] = { + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, + {NPY_DT_getitem, + reinterpret_cast(NPyCustomComplex_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyCustomComplex_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyCustomComplex_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyCustomComplex_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyCustomComplex_CommonDType)}, + {NPY_DT_PyArray_ArrFuncs_getitem, + reinterpret_cast(NPyCustomComplex_GetItem)}, + {NPY_DT_PyArray_ArrFuncs_setitem, + reinterpret_cast(NPyCustomComplex_SetItem)}, + {NPY_DT_PyArray_ArrFuncs_nonzero, + reinterpret_cast(NPyCustomComplex_NonZero)}, + {NPY_DT_PyArray_ArrFuncs_dotfunc, + reinterpret_cast(NPyCustomComplex_DotFunc)}, + {NPY_DT_PyArray_ArrFuncs_compare, + reinterpret_cast(NPyCustomComplex_CompareFunc)}, + {0, nullptr}}; + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = reinterpret_cast(type); + dtype_spec.flags = 0; + dtype_spec.casts = casts; + dtype_spec.slots = dtype_slots; + dtype_spec.baseclass = nullptr; + + if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) { return false; } + TypeDescriptor::npy_type = dm.type_num; - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. CustomComplexType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index bf2568a7..75128d08 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -35,8 +35,9 @@ limitations under the License. #include #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_compat.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build // Possible this has to do with numpy.h being included before @@ -66,6 +67,10 @@ struct CustomFloatType { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. Zero-initialized; fields are filled in + // at registration time before PyType_Ready is called. + static PyArray_DTypeMeta dtype_meta; }; template @@ -76,6 +81,8 @@ template PyArray_DescrProto CustomFloatType::npy_descr_proto; template PyArray_Descr* CustomFloatType::npy_descr = nullptr; +template +PyArray_DTypeMeta CustomFloatType::dtype_meta = {}; // Representation of a Python custom float object. template @@ -841,6 +848,149 @@ bool RegisterFloatUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for CustomFloat types +// --------------------------------------------------------------------------- + +// New-style getitem: (PyArray_Descr*, char*) -> PyObject* +template +static PyObject* NPyCustomFloat_NewStyleGetItem(PyArray_Descr* /*descr*/, + char* data) { + return NPyCustomFloat_GetItem(data, /*arr=*/nullptr); +} + +// New-style setitem: (PyArray_Descr*, PyObject*, char*) -> int +template +static int NPyCustomFloat_NewStyleSetItem(PyArray_Descr* /*descr*/, + PyObject* item, char* data) { + return NPyCustomFloat_SetItem(item, data, /*arr=*/nullptr); +} + +// ensure_canonical: for a non-parametric dtype just return self. +template +static PyArray_Descr* NPyCustomFloat_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +// default_descr: return the singleton. +// This avoids use_new_as_default calling dm() -> arraydescr_new, which fails +// for legacy-flagged DTypes because the legacy-check branch errors out. +template +static PyArray_Descr* NPyCustomFloat_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + +// True if every value of Src is exactly representable in Dst: Dst must have +// at least as many mantissa bits (precision) and at least as much exponent +// range as Src. +template +static constexpr bool CustomFloatSafeTo() { + return std::numeric_limits::digits >= std::numeric_limits::digits && + std::numeric_limits::max_exponent >= + std::numeric_limits::max_exponent; +} + +template +static PyArray_DTypeMeta* NPyCustomFloat_CommonDType(PyArray_DTypeMeta* cls, + PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. (should add complex here) + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) { + Py_INCREF(cls); + return cls; + } + + constexpr bool is_bfloat16 = std::is_same_v; + + if (PyTypeNum_ISINTEGER(other->type_num)) { + if (is_bfloat16 && (other->type_num == NPY_BYTE || other->type_num == NPY_UBYTE)) { + Py_INCREF(cls); + return cls; + } + /* Our precision is irrelevant, the integer one is higher. */ + return PyArray_CommonDType(&PyArray_PyFloatDType, other); + } + + switch (other->type_num) { + case NPY_BOOL: + Py_INCREF(cls); + return cls; + case NPY_HALF: + if (is_bfloat16) { + Py_INCREF(reinterpret_cast(&PyArray_FloatDType)); + return &PyArray_FloatDType; + } + [[fallthrough]]; + case NPY_FLOAT: case NPY_DOUBLE: case NPY_LONGDOUBLE: + [[fallthrough]]; + case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: + Py_INCREF(other); + return other; + default: + break; + } + + // ---- Our own custom DTypes ---- + // Another custom float: use compile-time safe-cast predicate to pick the + // wider type; fall back to float32 when neither contains the other. + // T is known at compile time so all CustomFloatSafeTo calls fold away. +#define TRY_CUSTOM_FLOAT(OtherT) \ + if (other == &CustomFloatType::dtype_meta) { \ + if constexpr (CustomFloatSafeTo()) { \ + Py_INCREF(other); return other; \ + } else if constexpr (CustomFloatSafeTo()) { \ + Py_INCREF(cls); return cls; \ + } else { \ + Py_INCREF(reinterpret_cast(&PyArray_FloatDType)); \ + return &PyArray_FloatDType; \ + } \ + } + TRY_CUSTOM_FLOAT(bfloat16) + TRY_CUSTOM_FLOAT(float8_e3m4) + TRY_CUSTOM_FLOAT(float8_e4m3) + TRY_CUSTOM_FLOAT(float8_e4m3b11fnuz) + TRY_CUSTOM_FLOAT(float8_e4m3fn) + TRY_CUSTOM_FLOAT(float8_e4m3fnuz) + TRY_CUSTOM_FLOAT(float8_e5m2) + TRY_CUSTOM_FLOAT(float8_e5m2fnuz) + TRY_CUSTOM_FLOAT(float6_e2m3fn) + TRY_CUSTOM_FLOAT(float6_e3m2fn) + TRY_CUSTOM_FLOAT(float4_e2m1fn) + TRY_CUSTOM_FLOAT(float8_e8m0fnu) +#undef TRY_CUSTOM_FLOAT + + // Custom int: float dominates. NPyIntN_CommonDType returns NotImplemented + // for user types it can't see, so we handle this side explicitly. + if (other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta) { + Py_INCREF(cls); + return cls; + } + + // Custom complex or unknown user type: swapping will work (NPyCustomComplex + // handles complex+float and returns the appropriate complex result). + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + +template +static PyObject* NPyCustomFloat_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} +template +static PyObject* NPyCustomFloat_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template bool RegisterFloatDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -865,7 +1015,7 @@ bool RegisterFloatDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = CustomFloatType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomFloat_GetItem; @@ -880,23 +1030,89 @@ bool RegisterFloatDtype(PyObject* numpy) { arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc; arr_funcs.argmin = NPyCustomFloat_ArgMinFunc; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy proto (for PyArrayInitDTypeMeta_FromSpec_WithLegacy). PyArray_DescrProto& descr_proto = CustomFloatType::npy_descr_proto; descr_proto = GetCustomFloatDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + PyArray_DTypeMeta& dm = CustomFloatType::dtype_meta; + Py_SET_REFCNT(&dm, 1); + auto* tp = reinterpret_cast(&dm); + tp->tp_name = TypeDescriptor::kTypeName; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyCustomFloat_DTypeRepr; + tp->tp_str = NPyCustomFloat_DTypeStr; + if (PyType_Ready(tp) < 0) { + return false; + } - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + // Build the within-dtype self-cast spec. + PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {NPY_METH_unaligned_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "copy"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + // Build the new-style DType spec. + PyType_Slot dtype_slots[] = { + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, + {NPY_DT_getitem, + reinterpret_cast(NPyCustomFloat_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyCustomFloat_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyCustomFloat_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyCustomFloat_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyCustomFloat_CommonDType)}, + {NPY_DT_PyArray_ArrFuncs_getitem, + reinterpret_cast(NPyCustomFloat_GetItem)}, + {NPY_DT_PyArray_ArrFuncs_setitem, + reinterpret_cast(NPyCustomFloat_SetItem)}, + {NPY_DT_PyArray_ArrFuncs_nonzero, + reinterpret_cast(NPyCustomFloat_NonZero)}, + {NPY_DT_PyArray_ArrFuncs_fill, + reinterpret_cast(NPyCustomFloat_Fill)}, + {NPY_DT_PyArray_ArrFuncs_dotfunc, + reinterpret_cast(NPyCustomFloat_DotFunc)}, + {NPY_DT_PyArray_ArrFuncs_compare, + reinterpret_cast(NPyCustomFloat_CompareFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmax, + reinterpret_cast(NPyCustomFloat_ArgMaxFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmin, + reinterpret_cast(NPyCustomFloat_ArgMinFunc)}, + {0, nullptr}}; + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = reinterpret_cast(type); + dtype_spec.flags = 0; + dtype_spec.casts = casts; + dtype_spec.slots = dtype_slots; + dtype_spec.baseclass = nullptr; + + if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) { return false; } + TypeDescriptor::npy_type = dm.type_num; - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. + // The singleton is owned by dm; grab a borrowed reference for npy_descr. + // PyArray_DescrFromType returns a new reference — intentionally leaked. CustomFloatType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/dtype_compat.h b/ml_dtypes/_src/dtype_compat.h new file mode 100644 index 00000000..a5d51348 --- /dev/null +++ b/ml_dtypes/_src/dtype_compat.h @@ -0,0 +1,70 @@ +/* Copyright 2025 The ml_dtypes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + * Compatibility helper for registering DTypes that work with both the new-style + * NumPy DType API (PyArrayInitDTypeMeta_FromSpec) and the legacy + * PyArray_RegisterDataType path. + * + * The PyArrayInitDTypeMeta_FromSpec_WithLegacy function is a backport hack for + * NumPy 2.0. Starting with NumPy 2.5/2.6 (TBD) there will be native NumPy + * support and this can be replaced. + * + * NOTE: This file requires NumPy >= 2.0. + */ + +#ifndef ML_DTYPES_DTYPE_COMPAT_H_ +#define ML_DTYPES_DTYPE_COMPAT_H_ + +// clang-format off +#include "ml_dtypes/_src/numpy.h" // NOLINT (must be first) +// clang-format on + +#include +#include +#include "numpy/arrayobject.h" +#include "numpy/dtype_api.h" + +#if NPY_ABI_VERSION < 0x02000000 +#error "ml_dtypes dtype_compat.h requires NumPy >= 2.0" +#endif + +namespace ml_dtypes { + +/* + * Within-dtype copy strided loop: a plain memcpy per element. + * Registered as the within_dtype_castingimpl for every fixed-size dtype. + * MUST carry NPY_METH_SUPPORTS_UNALIGNED (NumPy enforces this for self-casts). + */ +static inline int TrivialStridedCopyLoop(PyArrayMethod_Context *context, + char *const data[], + npy_intp const dimensions[], + npy_intp const strides[], + NpyAuxData * /*auxdata*/) { + const npy_intp N = dimensions[0]; + const npy_intp elsize = context->descriptors[0]->elsize; + const char *in = data[0]; + char *out = data[1]; + for (npy_intp i = 0; i < N; ++i) { + std::memcpy(out, in, elsize); + in += strides[0]; + out += strides[1]; + } + return 0; +} + +} // namespace ml_dtypes + +#endif // ML_DTYPES_DTYPE_COMPAT_H_ diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index 8e32a63c..80f59c52 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -25,8 +25,9 @@ limitations under the License. // clang-format on #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_compat.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/intn.h" #if NPY_ABI_VERSION < 0x02000000 @@ -56,6 +57,9 @@ struct IntNTypeDescriptor { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. + static PyArray_DTypeMeta dtype_meta; }; template @@ -66,6 +70,8 @@ template PyArray_DescrProto IntNTypeDescriptor::npy_descr_proto; template PyArray_Descr* IntNTypeDescriptor::npy_descr = nullptr; +template +PyArray_DTypeMeta IntNTypeDescriptor::dtype_meta = {}; // Representation of a Python custom integer object. template @@ -774,6 +780,96 @@ bool RegisterIntNUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for IntN types +// --------------------------------------------------------------------------- + +template +static PyObject* NPyIntN_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { + return NPyIntN_GetItem(data, /*arr=*/nullptr); +} + +template +static int NPyIntN_NewStyleSetItem(PyArray_Descr* /*descr*/, PyObject* item, + char* data) { + return NPyIntN_SetItem(item, data, /*arr=*/nullptr); +} + +template +static PyArray_Descr* NPyIntN_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +template +static PyArray_Descr* NPyIntN_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + +template +static PyArray_DTypeMeta* NPyIntN_CommonDType(PyArray_DTypeMeta* cls, + PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) { + Py_INCREF(cls); + return cls; + } + else if (other == &PyArray_PyFloatDType) { + Py_INCREF(&PyArray_DoubleDType); + return &PyArray_DoubleDType; + } + else if (other == &PyArray_PyComplexDType) { + Py_INCREF(&PyArray_CDoubleDType); + return &PyArray_CDoubleDType; + } + + // Our intN types are smaller than every NumPy built-in except bool. + if (other->type_num == NPY_BOOL) { + Py_INCREF(cls); + return cls; + } + if (!PyTypeNum_ISUSERDEF(other->type_num)) { + Py_INCREF(other); + return other; + } + + // ---- Our own custom DTypes ---- + // Another custom int: lower type_num defers (swapping will work). + if (other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta || + other == &IntNTypeDescriptor::dtype_meta) { + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); + } + // No cross-custom-int safe casts are registered; int16 contains all. + Py_INCREF(reinterpret_cast(&PyArray_Int16DType)); + return &PyArray_Int16DType; + } + + // Custom float or custom complex: swapping will work (NPyCustomFloat handles + // float+int, NPyCustomComplex handles complex+int). + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + +template +static PyObject* NPyIntN_DTypeRepr(PyObject* /*self*/) { + return PyUnicode_FromFormat("dtype(%s)", TypeDescriptor::kTypeName); +} +template +static PyObject* NPyIntN_DTypeStr(PyObject* /*self*/) { + return PyUnicode_FromString(TypeDescriptor::kTypeName); +} + template bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -799,7 +895,7 @@ bool RegisterIntNDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = IntNTypeDescriptor::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyIntN_GetItem; @@ -814,22 +910,87 @@ bool RegisterIntNDtype(PyObject* numpy) { arr_funcs.argmax = NPyIntN_ArgMaxFunc; arr_funcs.argmin = NPyIntN_ArgMinFunc; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy proto. PyArray_DescrProto& descr_proto = IntNTypeDescriptor::npy_descr_proto; descr_proto = GetIntNDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + PyArray_DTypeMeta& dm = IntNTypeDescriptor::dtype_meta; + Py_SET_REFCNT(&dm, 1); + auto* tp = reinterpret_cast(&dm); + tp->tp_name = TypeDescriptor::kTypeName; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_repr = NPyIntN_DTypeRepr; + tp->tp_str = NPyIntN_DTypeStr; + if (PyType_Ready(tp) < 0) { + return false; + } + + // Build the within-dtype self-cast spec. + PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {NPY_METH_unaligned_strided_loop, + reinterpret_cast(TrivialStridedCopyLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "copy"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + // Build the new-style DType spec. + PyType_Slot dtype_slots[] = { + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, + {NPY_DT_getitem, + reinterpret_cast(NPyIntN_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyIntN_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyIntN_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyIntN_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyIntN_CommonDType)}, + {NPY_DT_PyArray_ArrFuncs_getitem, + reinterpret_cast(NPyIntN_GetItem)}, + {NPY_DT_PyArray_ArrFuncs_setitem, + reinterpret_cast(NPyIntN_SetItem)}, + {NPY_DT_PyArray_ArrFuncs_nonzero, + reinterpret_cast(NPyIntN_NonZero)}, + {NPY_DT_PyArray_ArrFuncs_fill, + reinterpret_cast(NPyIntN_Fill)}, + {NPY_DT_PyArray_ArrFuncs_dotfunc, + reinterpret_cast(NPyIntN_DotFunc)}, + {NPY_DT_PyArray_ArrFuncs_compare, + reinterpret_cast(NPyIntN_CompareFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmax, + reinterpret_cast(NPyIntN_ArgMaxFunc)}, + {NPY_DT_PyArray_ArrFuncs_argmin, + reinterpret_cast(NPyIntN_ArgMinFunc)}, + {0, nullptr}}; + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = reinterpret_cast(type); + dtype_spec.flags = 0; + dtype_spec.casts = casts; + dtype_spec.slots = dtype_slots; + dtype_spec.baseclass = nullptr; + + if (PyArrayInitDTypeMeta_FromSpec(&dm, &dtype_spec) < 0) { + return false; + } + TypeDescriptor::npy_type = dm.type_num; - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { - return false; - } - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. IntNTypeDescriptor::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/numpy.h b/ml_dtypes/_src/numpy.h index 8b55e4d9..f9946d42 100644 --- a/ml_dtypes/_src/numpy.h +++ b/ml_dtypes/_src/numpy.h @@ -22,6 +22,8 @@ limitations under the License. // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#define NPY_TARGET_VERSION NPY_2_0_API_VERSION + // We import_array in the ml_dtypes init function only. #define PY_ARRAY_UNIQUE_SYMBOL _ml_dtypes_numpy_api diff --git a/ml_dtypes/tests/result_type_test.py b/ml_dtypes/tests/result_type_test.py new file mode 100644 index 00000000..8554eb77 --- /dev/null +++ b/ml_dtypes/tests/result_type_test.py @@ -0,0 +1,267 @@ +# Copyright 2026 The ml_dtypes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for np.result_type() across ml_dtypes custom DTypes.""" + +import ml_dtypes +import numpy as np +import pytest + +# Short aliases for readability in parametrize lists +bf16 = ml_dtypes.bfloat16 +f4 = ml_dtypes.float4_e2m1fn +f6_e2m3 = ml_dtypes.float6_e2m3fn +f6_e3m2 = ml_dtypes.float6_e3m2fn +f8_e3m4 = ml_dtypes.float8_e3m4 +f8_e4m3 = ml_dtypes.float8_e4m3 +f8_e4m3fn = ml_dtypes.float8_e4m3fn +f8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +f8_e4m3b11 = ml_dtypes.float8_e4m3b11fnuz +f8_e5m2 = ml_dtypes.float8_e5m2 +f8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +f8_e8m0 = ml_dtypes.float8_e8m0fnu +bc32 = ml_dtypes.bcomplex32 +c32 = ml_dtypes.complex32 +i1, i2, i4 = ml_dtypes.int1, ml_dtypes.int2, ml_dtypes.int4 +u1, u2, u4 = ml_dtypes.uint1, ml_dtypes.uint2, ml_dtypes.uint4 + +ALL_CUSTOM_FLOATS = [bf16, f4, f6_e2m3, f6_e3m2, + f8_e3m4, f8_e4m3, f8_e4m3fn, f8_e4m3fnuz, + f8_e4m3b11, f8_e5m2, f8_e5m2fnuz, f8_e8m0] +ALL_INTN = [i1, i2, i4, u1, u2, u4] +ALL_CUSTOM_COMPLEX = [bc32, c32] + + +def rt(a, b): + return np.result_type(a, b) + + +# --------------------------------------------------------------------------- +# Custom float vs NumPy built-in types +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool: custom float always wins ---- + (bf16, np.bool_, bf16), + (f8_e4m3fn, np.bool_, f8_e4m3fn), + (f4, np.bool_, f4), + # ---- floats: pick the wider ---- + (f4, np.float16, np.float16), # f4 fits in float16 + (f8_e4m3fn, np.float16, np.float16), # float8 fits in float16 + (f8_e5m2, np.float16, np.float16), # float8 fits in float16 + (bf16, np.float16, np.float32), # incomparable → float32 + (f8_e4m3fn, np.float32, np.float32), # all custom floats fit in float32 + (bf16, np.float32, np.float32), + (bf16, np.float64, np.float64), + (f8_e4m3fn, np.float64, np.float64), + # ---- integers: PyArray_CommonDType decides ---- + (bf16, np.int8, bf16), # bfloat16 has 8 sig bits, int8 needs 7 → bf16 wins + (bf16, np.int16, np.float64), # bfloat16 has 8 sig bits, int16 needs 15 → float64 + (f8_e4m3fn, np.int8, np.float64), # float8 can't represent all int8 values + (f8_e4m3fn, np.int32, np.float64), # float8 can't represent all int32 values + # ---- complex: other always wins ---- + (bf16, np.complex64, np.complex64), + (f8_e4m3fn, np.complex64, np.complex64), + (bf16, np.complex128, np.complex128), + (f8_e4m3fn, np.complex128, np.complex128), +]) +def test_custom_float_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom float vs custom float +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- same type ---- + (bf16, bf16, bf16), + (f8_e4m3fn, f8_e4m3fn, f8_e4m3fn), + (f4, f4, f4), + # ---- narrower fits safely into wider ---- + (f4, f6_e2m3, f6_e2m3), # f4 ⊂ f6_e2m3 (more exp + mantissa) + (f4, f8_e4m3fn, f8_e4m3fn), # f4 fits in every float8+ + (f4, bf16, bf16), # f4 fits in bfloat16 + (f8_e4m3fn, bf16, bf16), # float8 fits in bfloat16 + (f8_e5m2, bf16, bf16), # float8 fits in bfloat16 + (f8_e3m4, bf16, bf16), # float8 fits in bfloat16 + # ---- incomparable: one has more exp, other more mantissa → float32 ---- + (bf16, f8_e5m2, bf16), # f8_e5m2 fits in bf16 (bf16 > in all dims) + (f8_e4m3fn, f8_e5m2, np.float32), # e4m3 has more mantissa, e5m2 has more exp + (f8_e4m3fn, f8_e4m3fnuz, f8_e4m3fn), # same digits/max_exp → numeric_limits match; fn wins + (f6_e2m3, f6_e3m2, np.float32), # one has more mantissa, other more exp +]) +def test_custom_float_vs_custom_float(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom float vs custom int (float always dominates) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("float_t, int_t", [ + (bf16, i4), + (bf16, u4), + (bf16, i1), + (f8_e4m3fn, i4), + (f8_e4m3fn, u4), + (f8_e5m2, i2), + (f4, i1), +]) +def test_custom_float_beats_custom_int(float_t, int_t): + assert rt(float_t, int_t) == np.dtype(float_t) + assert rt(int_t, float_t) == np.dtype(float_t) # symmetric + + +# --------------------------------------------------------------------------- +# Custom int vs NumPy built-in types +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool: custom int always wins ---- + (i4, np.bool_, i4), + (u4, np.bool_, u4), + (i1, np.bool_, i1), + # ---- all other NumPy types: return other (intN is always smaller) ---- + (i4, np.int8, np.int8), + (i4, np.int16, np.int16), + (i4, np.int32, np.int32), + (i4, np.uint8, np.uint8), + (u4, np.int8, np.int8), + (i2, np.int8, np.int8), + (i4, np.float16, np.float16), + (i4, np.float32, np.float32), + (i4, np.float64, np.float64), + (i4, np.complex64, np.complex64), + (i4, np.complex128, np.complex128), + (u4, np.float32, np.float32), +]) +def test_custom_int_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom int vs custom int +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- same type ---- + (i4, i4, i4), + (u4, u4, u4), + # ---- mixed sign: neither fits the other → int16 ---- + (i4, u4, np.int16), + (i2, u2, np.int16), + (i1, u1, np.int16), + # ---- same sign, different width → int16 fallback ---- + (i2, i4, np.int16), + (u2, u4, np.int16), + (i1, i4, np.int16), +]) +def test_custom_int_vs_custom_int(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom complex vs NumPy built-in types +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool + integers: wrap in cfloat ---- + (bc32, np.bool_, np.complex64), + (bc32, np.int8, np.complex64), + (bc32, np.int32, np.complex64), + (c32, np.bool_, np.complex64), + (c32, np.int8, np.complex64), + # ---- floats ≤ float32: wrap in cfloat ---- + (bc32, np.float16, np.complex64), + (bc32, np.float32, np.complex64), + (c32, np.float16, np.complex64), + (c32, np.float32, np.complex64), + # ---- float64+: need cdouble ---- + (bc32, np.float64, np.complex128), + (bc32, np.longdouble, np.clongdouble), + (c32, np.float64, np.complex128), + # ---- built-in complex: other always wins ---- + (bc32, np.complex64, np.complex64), + (bc32, np.complex128, np.complex128), + (c32, np.complex64, np.complex64), + (c32, np.complex128, np.complex128), +]) +def test_custom_complex_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom complex vs custom float / custom int +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- custom floats: all fit in cfloat alongside our complex ---- + (bc32, bf16, np.complex64), + (bc32, f8_e4m3fn, np.complex64), + (bc32, f8_e5m2, np.complex64), + (bc32, f4, np.complex64), + (c32, bf16, np.complex64), + (c32, f8_e4m3fn, np.complex64), + # ---- custom ints: all tiny, fit in cfloat ---- + (bc32, i4, np.complex64), + (bc32, u4, np.complex64), + (bc32, i1, np.complex64), + (c32, i4, np.complex64), + # ---- two custom complex types ---- + (bc32, c32, np.complex64), + (bc32, bc32, bc32), + (c32, c32, c32), +]) +def test_custom_complex_vs_custom(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Python scalars: 0, 0.0, 0.0j (abstract types) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype, scalar, expected", [ + # ---- custom floats dominate Python int and Python float ---- + (bf16, 0, bf16), + (bf16, 0.0, bf16), + (f8_e4m3fn, 0, f8_e4m3fn), + (f8_e4m3fn, 0.0, f8_e4m3fn), + (f4, 0, f4), + (f4, 0.0, f4), + # ---- custom float + Python complex → cfloat ---- + (bf16, 0.0j, np.complex64), + (f8_e4m3fn, 0.0j, np.complex64), + (f4, 0.0j, np.complex64), + # ---- custom ints dominate Python int and Python float ---- + (i4, 0, i4), + (i4, 0.0, i4), + (u4, 0, u4), + (u4, 0.0, u4), + # ---- custom complex dominates all Python scalars ---- + (bc32, 0, bc32), + (bc32, 0.0, bc32), + (bc32, 0.0j, bc32), + (c32, 0, c32), + (c32, 0.0, c32), + (c32, 0.0j, c32), +]) +def test_python_scalars(dtype, scalar, expected): + assert rt(dtype, scalar) == np.dtype(expected)