Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/quadrants/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from quadrants.types.enums import SNodeGradType
from quadrants.types.ndarray_type import NdarrayType
from quadrants.types.primitive_types import (
PrimitiveBase,
all_types,
f16,
f32,
Expand Down Expand Up @@ -110,6 +111,8 @@ def expr_init(rhs):
return dict((key, expr_init(val)) for key, val in rhs.items())
if isinstance(rhs, _qd_core.DataTypeCxx):
return rhs
if isinstance(rhs, type) and issubclass(rhs, PrimitiveBase):
return rhs.cxx
if isinstance(rhs, _qd_core.Arch):
return rhs
if isinstance(rhs, _Ndrange):
Expand Down
107 changes: 78 additions & 29 deletions python/quadrants/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,59 @@
from quadrants.lang import impl
from quadrants.types import Template
from quadrants.types.primitive_types import (
PrimitiveBase,
all_types,
f16,
f16_cxx,
f32,
f32_cxx,
f64,
f64_cxx,
i8,
i8_cxx,
i16,
i16_cxx,
i32,
i32_cxx,
i64,
i64_cxx,
u1,
u1_cxx,
u8,
u8_cxx,
u16,
u16_cxx,
u32,
u32_cxx,
u64,
u64_cxx,
)

MAP_TYPE_IDS = {id(dtype): dtype for dtype in all_types}
MAP_TYPE_IDS: dict[int, Any] = {id(dtype): dtype for dtype in all_types}
_all_cxx_objs = (
f16_cxx,
f32_cxx,
f64_cxx,
i8_cxx,
i16_cxx,
i32_cxx,
i64_cxx,
u1_cxx,
u8_cxx,
u16_cxx,
u32_cxx,
u64_cxx,
)
for _cxx in _all_cxx_objs:
MAP_TYPE_IDS[id(_cxx)] = _cxx

# Pre-computed id-based cache for cook_dtype hot path.
# Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result.
_cook_cache: dict[int, _qd_core.DataTypeCxx] = {}
for _cls in (f16, f32, f64, i8, i16, i32, i64, u1, u8, u16, u32, u64):
_cook_cache[id(_cls)] = _cls.cxx
for _cxx in _all_cxx_objs:
_cook_cache[id(_cxx)] = _cxx


def has_pytorch():
Expand Down Expand Up @@ -177,71 +214,74 @@ def to_quadrants_type(dt):
dt (DataType): The desired data type to convert.

Returns:
DataType: The counterpart data type in quadrants.
DataTypeCxx: The counterpart data type in quadrants (always returns DataTypeCxx).

"""
_type = type(dt)
if _type is int:
return MAP_TYPE_IDS[dt]
return cook_dtype(MAP_TYPE_IDS[dt])

if isinstance(dt, type) and issubclass(dt, PrimitiveBase):
return dt.cxx

if issubclass(_type, _qd_core.DataTypeCxx):
return dt

if dt == np.float32:
return f32
return f32.cxx
if dt == np.float64:
return f64
return f64.cxx
if dt == np.int32:
return i32
return i32.cxx
if dt == np.int64:
return i64
return i64.cxx
if dt == np.int8:
return i8
return i8.cxx
if dt == np.int16:
return i16
return i16.cxx
if dt == np.bool_:
return u1
return u1.cxx
if dt == np.uint8:
return u8
return u8.cxx
if dt == np.uint16:
return u16
return u16.cxx
if dt == np.uint32:
return u32
return u32.cxx
if dt == np.uint64:
return u64
return u64.cxx
if dt == np.half:
return f16
return f16.cxx

if has_pytorch():
import torch # pylint: disable=C0415

# pylint: disable=E1101
if dt == torch.float32:
return f32
return f32.cxx
if dt == torch.float64:
return f64
return f64.cxx
if dt == torch.int32:
return i32
return i32.cxx
if dt == torch.int64:
return i64
return i64.cxx
if dt == torch.int8:
return i8
return i8.cxx
if dt == torch.int16:
return i16
return i16.cxx
if dt == torch.bool:
return u1
return u1.cxx
if dt == torch.uint8:
return u8
return u8.cxx
if dt == torch.float16:
return f16
return f16.cxx

if hasattr(torch, "uint16"):
if dt == torch.uint16:
return u16
return u16.cxx
if dt == torch.uint32:
return u32
return u32.cxx
if dt == torch.uint64:
return u64
return u64.cxx

raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")

Expand All @@ -264,8 +304,17 @@ def __hash__(self):


def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
# Convert Python dtype to CPP dtype
"""Convert Python dtype to C++ DataTypeCxx.

Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances,
and Python builtins (float, int, bool). Uses id-based cache for hot paths.
"""
cached = _cook_cache.get(id(dtype))
if cached is not None:
return cached
_type = type(dtype)
if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase):
return dtype.cxx
if issubclass(_type, _qd_core.DataTypeCxx):
return dtype
if issubclass(_type, _qd_core.Type):
Expand All @@ -275,7 +324,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
if dtype is int:
return impl.get_runtime().default_ip
if dtype is bool:
return u1
return u1.cxx
raise ValueError(f"Invalid data type {dtype}")


Expand Down
Loading
Loading