From 9bbc16bc810861e161b2b9b7a59bbd1b9f1462a1 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:08:38 -0700 Subject: [PATCH] [Refactor] Make primitive dtypes Python classes wrapping DataTypeCxx Convert primitive dtypes (f32, i32, etc.) from bare DataTypeCxx module-level variables into Python classes with a PrimitiveMeta metaclass. Each class has a .cxx attribute holding the underlying DataTypeCxx, and the metaclass delegates __eq__, __hash__, __getattr__ for backward compatibility. Update cook_dtype, to_quadrants_type, MAP_TYPE_IDS, and type utility functions to handle the new class-based types. Add PrimitiveBase checks in expr_init and quant.py. --- python/quadrants/lang/impl.py | 3 + python/quadrants/lang/util.py | 107 +++++++--- python/quadrants/types/primitive_types.py | 248 +++++++++++++--------- python/quadrants/types/quant.py | 22 +- python/quadrants/types/utils.py | 32 ++- tests/python/test_binding.py | 4 +- 6 files changed, 272 insertions(+), 144 deletions(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 0036a3bb7..701ec7c37 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -62,6 +62,7 @@ from quadrants.types.enums import SNodeGradType from quadrants.types.ndarray_type import NdarrayType from quadrants.types.primitive_types import ( + PrimitiveBase, all_types, f16, f32, @@ -110,6 +111,8 @@ def expr_init(rhs): return dict((key, expr_init(val)) for key, val in rhs.items()) if isinstance(rhs, _qd_core.DataTypeCxx): return rhs + if isinstance(rhs, type) and issubclass(rhs, PrimitiveBase): + return rhs.cxx if isinstance(rhs, _qd_core.Arch): return rhs if isinstance(rhs, _Ndrange): diff --git a/python/quadrants/lang/util.py b/python/quadrants/lang/util.py index 93536ce9d..17b4333a8 100644 --- a/python/quadrants/lang/util.py +++ b/python/quadrants/lang/util.py @@ -12,22 +12,59 @@ from quadrants.lang import impl from quadrants.types import Template from quadrants.types.primitive_types import ( + PrimitiveBase, all_types, f16, + f16_cxx, f32, + f32_cxx, f64, + f64_cxx, i8, + i8_cxx, i16, + i16_cxx, i32, + i32_cxx, i64, + i64_cxx, u1, + u1_cxx, u8, + u8_cxx, u16, + u16_cxx, u32, + u32_cxx, u64, + u64_cxx, ) -MAP_TYPE_IDS = {id(dtype): dtype for dtype in all_types} +MAP_TYPE_IDS: dict[int, Any] = {id(dtype): dtype for dtype in all_types} +_all_cxx_objs = ( + f16_cxx, + f32_cxx, + f64_cxx, + i8_cxx, + i16_cxx, + i32_cxx, + i64_cxx, + u1_cxx, + u8_cxx, + u16_cxx, + u32_cxx, + u64_cxx, +) +for _cxx in _all_cxx_objs: + MAP_TYPE_IDS[id(_cxx)] = _cxx + +# Pre-computed id-based cache for cook_dtype hot path. +# Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result. +_cook_cache: dict[int, _qd_core.DataTypeCxx] = {} +for _cls in (f16, f32, f64, i8, i16, i32, i64, u1, u8, u16, u32, u64): + _cook_cache[id(_cls)] = _cls.cxx +for _cxx in _all_cxx_objs: + _cook_cache[id(_cxx)] = _cxx def has_pytorch(): @@ -177,71 +214,74 @@ def to_quadrants_type(dt): dt (DataType): The desired data type to convert. Returns: - DataType: The counterpart data type in quadrants. + DataTypeCxx: The counterpart data type in quadrants (always returns DataTypeCxx). """ _type = type(dt) if _type is int: - return MAP_TYPE_IDS[dt] + return cook_dtype(MAP_TYPE_IDS[dt]) + + if isinstance(dt, type) and issubclass(dt, PrimitiveBase): + return dt.cxx if issubclass(_type, _qd_core.DataTypeCxx): return dt if dt == np.float32: - return f32 + return f32.cxx if dt == np.float64: - return f64 + return f64.cxx if dt == np.int32: - return i32 + return i32.cxx if dt == np.int64: - return i64 + return i64.cxx if dt == np.int8: - return i8 + return i8.cxx if dt == np.int16: - return i16 + return i16.cxx if dt == np.bool_: - return u1 + return u1.cxx if dt == np.uint8: - return u8 + return u8.cxx if dt == np.uint16: - return u16 + return u16.cxx if dt == np.uint32: - return u32 + return u32.cxx if dt == np.uint64: - return u64 + return u64.cxx if dt == np.half: - return f16 + return f16.cxx if has_pytorch(): import torch # pylint: disable=C0415 # pylint: disable=E1101 if dt == torch.float32: - return f32 + return f32.cxx if dt == torch.float64: - return f64 + return f64.cxx if dt == torch.int32: - return i32 + return i32.cxx if dt == torch.int64: - return i64 + return i64.cxx if dt == torch.int8: - return i8 + return i8.cxx if dt == torch.int16: - return i16 + return i16.cxx if dt == torch.bool: - return u1 + return u1.cxx if dt == torch.uint8: - return u8 + return u8.cxx if dt == torch.float16: - return f16 + return f16.cxx if hasattr(torch, "uint16"): if dt == torch.uint16: - return u16 + return u16.cxx if dt == torch.uint32: - return u32 + return u32.cxx if dt == torch.uint64: - return u64 + return u64.cxx raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") @@ -264,8 +304,17 @@ def __hash__(self): def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx: - # Convert Python dtype to CPP dtype + """Convert Python dtype to C++ DataTypeCxx. + + Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances, + and Python builtins (float, int, bool). Uses id-based cache for hot paths. + """ + cached = _cook_cache.get(id(dtype)) + if cached is not None: + return cached _type = type(dtype) + if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase): + return dtype.cxx if issubclass(_type, _qd_core.DataTypeCxx): return dtype if issubclass(_type, _qd_core.Type): @@ -275,7 +324,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx: if dtype is int: return impl.get_runtime().default_ip if dtype is bool: - return u1 + return u1.cxx raise ValueError(f"Invalid data type {dtype}") diff --git a/python/quadrants/types/primitive_types.py b/python/quadrants/types/primitive_types.py index 04b8fc4cb..c6fcdc856 100644 --- a/python/quadrants/types/primitive_types.py +++ b/python/quadrants/types/primitive_types.py @@ -1,159 +1,193 @@ -from typing import Union +from typing import ClassVar, Union from quadrants._lib import core as qd_python_core +from quadrants._lib.core.quadrants_python import DataTypeCxx # ======================================== -# real types +# Raw C++ DataType instances (internal use) +# ======================================== + +f16_cxx = qd_python_core.DataType_f16 +f32_cxx = qd_python_core.DataType_f32 +f64_cxx = qd_python_core.DataType_f64 + +i8_cxx = qd_python_core.DataType_i8 +i16_cxx = qd_python_core.DataType_i16 +i32_cxx = qd_python_core.DataType_i32 +i64_cxx = qd_python_core.DataType_i64 + +u1_cxx = qd_python_core.DataType_u1 +u8_cxx = qd_python_core.DataType_u8 +u16_cxx = qd_python_core.DataType_u16 +u32_cxx = qd_python_core.DataType_u32 +u64_cxx = qd_python_core.DataType_u64 + + +# ======================================== +# Metaclass and base class for Python dtype wrappers +# ======================================== -# ---------------------------------------- -float16 = qd_python_core.DataType_f16 -"""16-bit precision floating point data type. -""" +class PrimitiveMeta(type): + """Metaclass that makes dtype classes behave like DataTypeCxx objects. -# ---------------------------------------- + Delegates attribute access and comparisons to the underlying .cxx object, + allowing existing code that does e.g. dtype.to_string() to keep working. + """ -f16 = float16 -"""Alias for :const:`~quadrants.types.primitive_types.float16` -""" + def __eq__(cls, other): + if isinstance(other, PrimitiveMeta): + return cls is other + if isinstance(other, DataTypeCxx): + return cls.cxx == other + return NotImplemented -# ---------------------------------------- + def __ne__(cls, other): + if isinstance(other, PrimitiveMeta): + return cls is not other + if isinstance(other, DataTypeCxx): + return cls.cxx != other + return NotImplemented -float32 = qd_python_core.DataType_f32 -"""32-bit single precision floating point data type. -""" + def __hash__(cls): + return hash(cls.cxx) -# ---------------------------------------- + def __repr__(cls): + return cls.cxx.to_string() -f32 = float32 -"""Alias for :const:`~quadrants.types.primitive_types.float32` -""" + def __getattr__(cls, name): + try: + return getattr(cls.cxx, name) + except AttributeError: + raise AttributeError(f"type object '{cls.__name__}' has no attribute '{name}'") from None -# ---------------------------------------- -float64 = qd_python_core.DataType_f64 -"""64-bit double precision floating point data type. -""" +class PrimitiveBase(metaclass=PrimitiveMeta): + """Base class for all primitive dtype classes. -# ---------------------------------------- + Each subclass has a `cxx` class variable holding the corresponding DataTypeCxx instance. + Subclasses auto-register themselves in the _registry for reverse lookup (DataTypeCxx -> Python class). + """ + + cxx: ClassVar[DataTypeCxx] + _registry: ClassVar[dict[DataTypeCxx, "type[PrimitiveBase]"]] = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if hasattr(cls, "cxx"): + PrimitiveBase._registry[cls.cxx] = cls + + +def cxx_to_py(dtype_cxx: DataTypeCxx) -> "type[PrimitiveBase]": + """Convert a DataTypeCxx to its corresponding Python dtype class.""" + return PrimitiveBase._registry[dtype_cxx] -f64 = float64 -"""Alias for :const:`~quadrants.types.primitive_types.float64` -""" -# ---------------------------------------- # ======================================== -# Integer types +# Floating point types +# ======================================== + + +class f16(PrimitiveBase): + """16-bit precision floating point data type.""" + + cxx = f16_cxx -# ---------------------------------------- -int8 = qd_python_core.DataType_i8 -"""8-bit signed integer data type. -""" +class f32(PrimitiveBase): + """32-bit single precision floating point data type.""" -# ---------------------------------------- + cxx = f32_cxx -i8 = int8 -"""Alias for :const:`~quadrants.types.primitive_types.int8` -""" -# ---------------------------------------- +class f64(PrimitiveBase): + """64-bit double precision floating point data type.""" -int16 = qd_python_core.DataType_i16 -"""16-bit signed integer data type. -""" + cxx = f64_cxx -# ---------------------------------------- -i16 = int16 -"""Alias for :const:`~quadrants.types.primitive_types.int16` -""" +float16 = f16 +float32 = f32 +float64 = f64 -# ---------------------------------------- +# ======================================== +# Signed integer types +# ======================================== + + +class i8(PrimitiveBase): + """8-bit signed integer data type.""" -int32 = qd_python_core.DataType_i32 -"""32-bit signed integer data type. -""" + cxx = i8_cxx -# ---------------------------------------- -i32 = int32 -"""Alias for :const:`~quadrants.types.primitive_types.int32` -""" +class i16(PrimitiveBase): + """16-bit signed integer data type.""" -# ---------------------------------------- + cxx = i16_cxx -int64 = qd_python_core.DataType_i64 -"""64-bit signed integer data type. -""" -# ---------------------------------------- +class i32(PrimitiveBase): + """32-bit signed integer data type.""" -i64 = int64 -"""Alias for :const:`~quadrants.types.primitive_types.int64` -""" + cxx = i32_cxx -# ---------------------------------------- -uint8 = qd_python_core.DataType_u8 -"""8-bit unsigned integer data type. -""" +class i64(PrimitiveBase): + """64-bit signed integer data type.""" -# ---------------------------------------- + cxx = i64_cxx + + +int8 = i8 +int16 = i16 +int32 = i32 +int64 = i64 + +# ======================================== +# Unsigned integer types +# ======================================== -uint1 = qd_python_core.DataType_u1 -"""1-bit unsigned integer data type. Same as booleans. -""" -# ---------------------------------------- +class u1(PrimitiveBase): + """1-bit unsigned integer data type. Same as booleans.""" -u1 = uint1 -"""Alias for :const:`~quadrants.types.primitive_types.uint1` -""" + cxx = u1_cxx -# ---------------------------------------- -u8 = uint8 -"""Alias for :const:`~quadrants.types.primitive_types.uint8` -""" +class u8(PrimitiveBase): + """8-bit unsigned integer data type.""" -# ---------------------------------------- + cxx = u8_cxx -uint16 = qd_python_core.DataType_u16 -"""16-bit unsigned integer data type. -""" -# ---------------------------------------- +class u16(PrimitiveBase): + """16-bit unsigned integer data type.""" -u16 = uint16 -"""Alias for :const:`~quadrants.types.primitive_types.uint16` -""" + cxx = u16_cxx -# ---------------------------------------- -uint32 = qd_python_core.DataType_u32 -"""32-bit unsigned integer data type. -""" +class u32(PrimitiveBase): + """32-bit unsigned integer data type.""" -# ---------------------------------------- + cxx = u32_cxx -u32 = uint32 -"""Alias for :const:`~quadrants.types.primitive_types.uint32` -""" -# ---------------------------------------- +class u64(PrimitiveBase): + """64-bit unsigned integer data type.""" -uint64 = qd_python_core.DataType_u64 -"""64-bit unsigned integer data type. -""" + cxx = u64_cxx -# ---------------------------------------- -u64 = uint64 -"""Alias for :const:`~quadrants.types.primitive_types.uint64` -""" +uint1 = u1 +uint8 = u8 +uint16 = u16 +uint32 = u32 +uint64 = u64 -# ---------------------------------------- +# ======================================== +# Ref type (unchanged) +# ======================================== class RefType: @@ -165,6 +199,10 @@ def ref(tp): return RefType(tp) +# ======================================== +# Type sets for fast lookup +# ======================================== + real_types = {f16, f32, f64, float} real_type_ids = {id(t) for t in real_types} @@ -172,7 +210,13 @@ def ref(tp): integer_type_ids = {id(t) for t in integer_types} all_types = real_types | integer_types -type_ids = {id(t) for t in all_types} +_py_type_ids = {id(t) for t in all_types} + +_all_cxx = {f16_cxx, f32_cxx, f64_cxx, i8_cxx, i16_cxx, i32_cxx, i64_cxx, u1_cxx, u8_cxx, u16_cxx, u32_cxx, u64_cxx} +cxx_type_ids = {id(t) for t in _all_cxx} + +# Combined set: matches both Python classes and DataTypeCxx instances +type_ids = _py_type_ids | cxx_type_ids _python_primitive_types = Union[int, float, bool, str, None] diff --git a/python/quadrants/types/quant.py b/python/quadrants/types/quant.py index b67e79ea1..780c2f208 100644 --- a/python/quadrants/types/quant.py +++ b/python/quadrants/types/quant.py @@ -3,12 +3,23 @@ For more details, read https://yuanming.quadrants.graphics/publication/2021-quanquadrants/quanquadrants.pdf. """ +from typing import Any + from quadrants._lib.utils import qd_python_core as _qd_python_core -from quadrants.types.primitive_types import i32 +from quadrants.types.primitive_types import PrimitiveBase, i32 _type_factory = _qd_python_core.get_type_factory_instance() +def _to_ptr(compute: Any) -> Any: + """Convert a dtype (Python class or DataTypeCxx) to a Type pointer for C++ APIs.""" + if isinstance(compute, type) and issubclass(compute, PrimitiveBase): + compute = compute.cxx + if isinstance(compute, _qd_python_core.DataTypeCxx): + return compute.get_ptr() + return compute + + def int(bits, signed=True, compute=None): # pylint: disable=W0622 """Generates a quantized type for integers. @@ -24,8 +35,7 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622 from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_ip if signed else impl.get_runtime().default_up - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) return _type_factory.get_quant_int_type(bits, signed, compute) @@ -46,8 +56,7 @@ def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None): from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_fp - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) # TODO: handle cases with bits > 32 underlying_type = int(bits=bits, signed=signed, compute=i32) if scale is None: @@ -74,8 +83,7 @@ def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622 from quadrants.lang import impl # pylint: disable=C0415 compute = impl.get_runtime().default_fp - if isinstance(compute, _qd_python_core.DataTypeCxx): - compute = compute.get_ptr() + compute = _to_ptr(compute) # Exponent is always unsigned exp_type = int(bits=exp, signed=False, compute=i32) # TODO: handle cases with frac > 32 diff --git a/python/quadrants/types/utils.py b/python/quadrants/types/utils.py index 0803085e2..e268279d8 100644 --- a/python/quadrants/types/utils.py +++ b/python/quadrants/types/utils.py @@ -1,11 +1,35 @@ +from typing import Any + from quadrants._lib import core as qd_python_core +from quadrants._lib.core.quadrants_python import DataTypeCxx +from quadrants.types.primitive_types import PrimitiveBase + +_is_signed_cxx = qd_python_core.is_signed +_is_integral_cxx = qd_python_core.is_integral +_is_real_cxx = qd_python_core.is_real +_is_tensor_cxx = qd_python_core.is_tensor + + +def _cook_if_needed(dt: Any) -> DataTypeCxx: + if isinstance(dt, type) and issubclass(dt, PrimitiveBase): + return dt.cxx + return dt # type: ignore[return-value] + + +def is_signed(dt: Any) -> bool: + return _is_signed_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] + + +def is_integral(dt: Any) -> bool: + return _is_integral_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] + -is_signed = qd_python_core.is_signed +def is_real(dt: Any) -> bool: + return _is_real_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] -is_integral = qd_python_core.is_integral -is_real = qd_python_core.is_real +def is_tensor(dt: Any) -> bool: + return _is_tensor_cxx(_cook_if_needed(dt)) # type: ignore[arg-type] -is_tensor = qd_python_core.is_tensor __all__ = ["is_signed", "is_integral", "is_real", "is_tensor"] diff --git a/tests/python/test_binding.py b/tests/python/test_binding.py index 4bf7fdb75..0e6338787 100644 --- a/tests/python/test_binding.py +++ b/tests/python/test_binding.py @@ -5,7 +5,7 @@ def test_binding(): qd.init() quadrants_lang = qd._lib.core print(quadrants_lang.BinaryOpType.mul) - one = quadrants_lang.make_const_expr_int(qd.i32, 1) - two = quadrants_lang.make_const_expr_int(qd.i32, 2) + one = quadrants_lang.make_const_expr_int(qd.i32.cxx, 1) + two = quadrants_lang.make_const_expr_int(qd.i32.cxx, 2) expr = quadrants_lang.make_binary_op_expr(quadrants_lang.BinaryOpType.add, one, two) print(quadrants_lang.make_global_store_stmt(None, None))