Skip to content

Commit 9b5e1a8

Browse files
committed
[Refactor] Add cook_dtype calls at C++ boundaries (no-op preparatory refactor)
Add cook_dtype() calls at all points where dtype values are passed to C++ code. Make PyQuadrants.default_fp/ip/up into properties that always store DataTypeCxx. Rename shadowed dtype var in create_field_member. All changes are behavioral no-ops with current code, preparing for a future refactor of primitive dtypes into Python classes.
1 parent a402c99 commit 9b5e1a8

5 files changed

Lines changed: 56 additions & 16 deletions

File tree

python/quadrants/lang/expr.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from quadrants.lang.common_ops import QuadrantsOperations
88
from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError
99
from quadrants.lang.matrix import make_matrix
10-
from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type
10+
from quadrants.lang.util import (
11+
cook_dtype,
12+
is_matrix_class,
13+
is_quadrants_class,
14+
to_numpy_type,
15+
)
1116
from quadrants.types import primitive_types
1217
from quadrants.types.primitive_types import integer_types, real_types
1318

@@ -109,12 +114,16 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int:
109114

110115

111116
def make_constant_expr(val, dtype):
117+
if dtype is not None:
118+
dtype = cook_dtype(dtype)
119+
112120
if isinstance(val, (bool, np.bool_)):
113-
constant_dtype = primitive_types.u1
121+
constant_dtype = cook_dtype(primitive_types.u1)
114122
return Expr(_qd_core.make_const_expr_bool(constant_dtype, val))
115123

116124
if isinstance(val, (float, np.floating)):
117125
constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype
126+
constant_dtype = cook_dtype(constant_dtype)
118127
if constant_dtype not in real_types:
119128
raise QuadrantsTypeError(
120129
"Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`."
@@ -123,6 +132,7 @@ def make_constant_expr(val, dtype):
123132

124133
if isinstance(val, (int, np.integer)):
125134
constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype
135+
constant_dtype = cook_dtype(constant_dtype)
126136
if constant_dtype not in integer_types:
127137
raise QuadrantsTypeError(
128138
"Integer literals must be annotated with a integer type. For type casting, use `qd.cast`."

python/quadrants/lang/impl.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
def expr_init_shared_array(shape, element_type):
8484
ast_builder = get_runtime().compiling_callable.ast_builder()
8585
debug_info = _qd_core.DebugInfo(get_runtime().get_current_src_info())
86+
element_type = cook_dtype(element_type)
8687
return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info)
8788

8889

@@ -355,9 +356,9 @@ def __init__(self, kernels=None):
355356
self.grad_vars = []
356357
self.dual_vars = []
357358
self.matrix_fields = []
358-
self.default_fp = f32
359-
self.default_ip = i32
360-
self.default_up = u32
359+
self._default_fp = cook_dtype(f32)
360+
self._default_ip = cook_dtype(i32)
361+
self._default_up = cook_dtype(u32)
361362
self.print_full_traceback: bool = False
362363
self.target_tape = None
363364
self.fwd_mode_manager = None
@@ -371,6 +372,30 @@ def __init__(self, kernels=None):
371372
self.unrolling_limit: int = 0
372373
self.src_ll_cache: bool = True
373374

375+
@property
376+
def default_fp(self) -> DataTypeCxx:
377+
return self._default_fp
378+
379+
@default_fp.setter
380+
def default_fp(self, value: Any) -> None:
381+
self._default_fp = cook_dtype(value)
382+
383+
@property
384+
def default_ip(self) -> DataTypeCxx:
385+
return self._default_ip
386+
387+
@default_ip.setter
388+
def default_ip(self, value: Any) -> None:
389+
self._default_ip = cook_dtype(value)
390+
391+
@property
392+
def default_up(self) -> DataTypeCxx:
393+
return self._default_up
394+
395+
@default_up.setter
396+
def default_up(self, value: Any) -> None:
397+
self._default_up = cook_dtype(value)
398+
374399
@property
375400
def compiling_callable(self) -> KernelCxx | Kernel | Function:
376401
if self._compiling_callable is None:
@@ -737,10 +762,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
737762
if prog.config().debug:
738763
# adjoint checkbit
739764
x_grad_checkbit = Expr(prog.make_id_expr(""))
740-
dtype = u8
765+
checkbit_dtype = u8
741766
if prog.config().arch == _qd_core.vulkan:
742-
dtype = i32
743-
x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype))
767+
checkbit_dtype = i32
768+
x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype))
744769
x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
745770
x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT)
746771
x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr)

python/quadrants/lang/matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def make_matrix(arr, dt=None):
176176
if len(arr) == 0:
177177
# the only usage of an empty vector is to serve as field indices
178178
shape = [0]
179-
dt = primitive_types.i32
179+
dt = cook_dtype(primitive_types.i32)
180180
else:
181181
if isinstance(arr[0], Iterable): # matrix
182182
shape = [len(arr), len(arr[0])]

python/quadrants/linalg/sparse_matrix.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from quadrants.lang.exception import QuadrantsRuntimeError
1010
from quadrants.lang.field import Field
1111
from quadrants.lang.impl import get_runtime
12+
from quadrants.lang.util import cook_dtype
1213
from quadrants.types import f32
1314

1415

@@ -24,11 +25,12 @@ class SparseMatrix:
2425
"""
2526

2627
def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"):
27-
self.dtype = dtype
28+
dtype_cxx = cook_dtype(dtype)
29+
self.dtype = dtype_cxx
2830
if sm is None:
2931
self.n = n
3032
self.m = m if m else n
31-
self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format)
33+
self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format)
3234
else:
3335
self.n = sm.num_rows()
3436
self.m = sm.num_cols()
@@ -247,7 +249,8 @@ def __init__(
247249
):
248250
self.num_rows = num_rows
249251
self.num_cols = num_cols if num_cols else num_rows
250-
self.dtype = dtype
252+
dtype_cxx = cook_dtype(dtype)
253+
self.dtype = dtype_cxx
251254
if num_rows is not None:
252255
quadrants_arch = get_runtime().prog.config().arch
253256
if quadrants_arch in [
@@ -259,7 +262,7 @@ def __init__(
259262
num_rows,
260263
num_cols,
261264
max_num_triplets,
262-
dtype,
265+
dtype_cxx,
263266
storage_format,
264267
)
265268
self.ptr.create_ndarray(get_runtime().prog)

python/quadrants/linalg/sparse_solver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from quadrants.lang.exception import QuadrantsRuntimeError
99
from quadrants.lang.field import Field
1010
from quadrants.lang.impl import get_runtime
11+
from quadrants.lang.util import cook_dtype
1112
from quadrants.linalg.sparse_matrix import SparseMatrix
1213
from quadrants.types.primitive_types import f32
1314

@@ -24,7 +25,8 @@ class SparseSolver:
2425

2526
def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
2627
self.matrix = None
27-
self.dtype = dtype
28+
dtype_cxx = cook_dtype(dtype)
29+
self.dtype = dtype_cxx
2830
solver_type_list = ["LLT", "LDLT", "LU"]
2931
solver_ordering = ["AMD", "COLAMD"]
3032
if solver_type in solver_type_list and ordering in solver_ordering:
@@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
3537
or quadrants_arch == _qd_core.Arch.cuda
3638
), "SparseSolver only supports CPU and CUDA for now."
3739
if quadrants_arch == _qd_core.Arch.cuda:
38-
self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering)
40+
self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering)
3941
else:
40-
self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering)
42+
self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering)
4143
else:
4244
raise QuadrantsRuntimeError(
4345
f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported."

0 commit comments

Comments
 (0)