From 71dcd622626322b91a61c2440e6524d52187b433 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 20:06:29 -0700 Subject: [PATCH 1/6] [Refactor] Add cook_dtype calls at C++ boundaries (no-op preparatory refactor) Add cook_dtype() calls at all points where dtype values are passed to C++ code. Make PyQuadrants.default_fp/ip/up into properties that always store DataTypeCxx. Rename shadowed dtype var in create_field_member. All changes are behavioral no-ops with current code, preparing for a future refactor of primitive dtypes into Python classes. --- python/quadrants/lang/expr.py | 14 +++++++-- python/quadrants/lang/impl.py | 37 ++++++++++++++++++++---- python/quadrants/lang/matrix.py | 2 +- python/quadrants/linalg/sparse_matrix.py | 11 ++++--- python/quadrants/linalg/sparse_solver.py | 8 +++-- 5 files changed, 56 insertions(+), 16 deletions(-) diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index 0369349d6..aedacc287 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -7,7 +7,12 @@ from quadrants.lang.common_ops import QuadrantsOperations from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError from quadrants.lang.matrix import make_matrix -from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type +from quadrants.lang.util import ( + cook_dtype, + is_matrix_class, + is_quadrants_class, + to_numpy_type, +) from quadrants.types import primitive_types from quadrants.types.primitive_types import integer_types, real_types @@ -109,12 +114,16 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int: def make_constant_expr(val, dtype): + if dtype is not None: + dtype = cook_dtype(dtype) + if isinstance(val, (bool, np.bool_)): - constant_dtype = primitive_types.u1 + constant_dtype = cook_dtype(primitive_types.u1) return Expr(_qd_core.make_const_expr_bool(constant_dtype, val)) if isinstance(val, (float, np.floating)): constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype + constant_dtype = cook_dtype(constant_dtype) if constant_dtype not in real_types: raise QuadrantsTypeError( "Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`." @@ -123,6 +132,7 @@ def make_constant_expr(val, dtype): if isinstance(val, (int, np.integer)): constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype + constant_dtype = cook_dtype(constant_dtype) if constant_dtype not in integer_types: raise QuadrantsTypeError( "Integer literals must be annotated with a integer type. For type casting, use `qd.cast`." diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 08144b9b1..0036a3bb7 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -83,6 +83,7 @@ def expr_init_shared_array(shape, element_type): ast_builder = get_runtime().compiling_callable.ast_builder() debug_info = _qd_core.DebugInfo(get_runtime().get_current_src_info()) + element_type = cook_dtype(element_type) return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info) @@ -355,9 +356,9 @@ def __init__(self, kernels=None): self.grad_vars = [] self.dual_vars = [] self.matrix_fields = [] - self.default_fp = f32 - self.default_ip = i32 - self.default_up = u32 + self._default_fp = cook_dtype(f32) + self._default_ip = cook_dtype(i32) + self._default_up = cook_dtype(u32) self.print_full_traceback: bool = False self.target_tape = None self.fwd_mode_manager = None @@ -371,6 +372,30 @@ def __init__(self, kernels=None): self.unrolling_limit: int = 0 self.src_ll_cache: bool = True + @property + def default_fp(self) -> DataTypeCxx: + return self._default_fp + + @default_fp.setter + def default_fp(self, value: Any) -> None: + self._default_fp = cook_dtype(value) + + @property + def default_ip(self) -> DataTypeCxx: + return self._default_ip + + @default_ip.setter + def default_ip(self, value: Any) -> None: + self._default_ip = cook_dtype(value) + + @property + def default_up(self) -> DataTypeCxx: + return self._default_up + + @default_up.setter + def default_up(self, value: Any) -> None: + self._default_up = cook_dtype(value) + @property def compiling_callable(self) -> KernelCxx | Kernel | Function: if self._compiling_callable is None: @@ -737,10 +762,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual): if prog.config().debug: # adjoint checkbit x_grad_checkbit = Expr(prog.make_id_expr("")) - dtype = u8 + checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: - dtype = i32 - x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype)) + checkbit_dtype = i32 + x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype)) x_grad_checkbit.ptr.set_name(name + ".grad_checkbit") x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT) x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr) diff --git a/python/quadrants/lang/matrix.py b/python/quadrants/lang/matrix.py index 01bf6eef6..6f44a0947 100644 --- a/python/quadrants/lang/matrix.py +++ b/python/quadrants/lang/matrix.py @@ -176,7 +176,7 @@ def make_matrix(arr, dt=None): if len(arr) == 0: # the only usage of an empty vector is to serve as field indices shape = [0] - dt = primitive_types.i32 + dt = cook_dtype(primitive_types.i32) else: if isinstance(arr[0], Iterable): # matrix shape = [len(arr), len(arr[0])] diff --git a/python/quadrants/linalg/sparse_matrix.py b/python/quadrants/linalg/sparse_matrix.py index 7eb4f40be..09cfd75a3 100644 --- a/python/quadrants/linalg/sparse_matrix.py +++ b/python/quadrants/linalg/sparse_matrix.py @@ -9,6 +9,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.types import f32 @@ -24,11 +25,12 @@ class SparseMatrix: """ def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"): - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if sm is None: self.n = n self.m = m if m else n - self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format) + self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format) else: self.n = sm.num_rows() self.m = sm.num_cols() @@ -247,7 +249,8 @@ def __init__( ): self.num_rows = num_rows self.num_cols = num_cols if num_cols else num_rows - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if num_rows is not None: quadrants_arch = get_runtime().prog.config().arch if quadrants_arch in [ @@ -259,7 +262,7 @@ def __init__( num_rows, num_cols, max_num_triplets, - dtype, + dtype_cxx, storage_format, ) self.ptr.create_ndarray(get_runtime().prog) diff --git a/python/quadrants/linalg/sparse_solver.py b/python/quadrants/linalg/sparse_solver.py index 3544d1a95..e66de69b0 100644 --- a/python/quadrants/linalg/sparse_solver.py +++ b/python/quadrants/linalg/sparse_solver.py @@ -8,6 +8,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.linalg.sparse_matrix import SparseMatrix from quadrants.types.primitive_types import f32 @@ -24,7 +25,8 @@ class SparseSolver: def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): self.matrix = None - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx solver_type_list = ["LLT", "LDLT", "LU"] solver_ordering = ["AMD", "COLAMD"] if solver_type in solver_type_list and ordering in solver_ordering: @@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): or quadrants_arch == _qd_core.Arch.cuda ), "SparseSolver only supports CPU and CUDA for now." if quadrants_arch == _qd_core.Arch.cuda: - self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering) else: - self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering) else: raise QuadrantsRuntimeError( f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported." From 1c6b67335c61ff0e4cd9b4512b70401eb05d844c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:38:28 -0700 Subject: [PATCH 2/6] [Doc] Document checkbit_dtype rename to avoid shadowing outer dtype The previous code overwrote the outer `dtype` parameter in the debug checkbit block, causing x_dual to be created with the wrong dtype. --- python/quadrants/lang/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 0036a3bb7..7ef6d18f0 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -760,7 +760,7 @@ def create_field_member(dtype, name, needs_grad, needs_dual): pyquadrants.grad_vars.append(x_grad) if prog.config().debug: - # adjoint checkbit + # adjoint checkbit — use a separate var to avoid shadowing the outer `dtype` x_grad_checkbit = Expr(prog.make_id_expr("")) checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: From a694d81500b1df037bd323b6f20e149d4119298b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:40:03 -0700 Subject: [PATCH 3/6] [Test] Add regression test for debug-mode dual field dtype shadowing Verify that forward-mode AD produces correct results when debug=True, guarding against the previous bug where the checkbit block's local dtype variable shadowed the outer dtype parameter. --- tests/python/test_ad_basics_fwd.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index fc37ef582..760db9a7c 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -124,3 +124,22 @@ def clear_dual_test(): with qd.ad.FwdMode(loss=loss, param=x): clear_dual_test() assert y.dual[None] == 4.0 + + +@test_utils.test(debug=True) +def test_dual_field_dtype_preserved_in_debug_mode(): + """Regression: debug-mode checkbit must not shadow the outer dtype.""" + x = qd.field(qd.f64, shape=(), needs_dual=True) + loss = qd.field(qd.f64, shape=(), needs_dual=True) + + x[None] = 3.0 + + @qd.kernel + def compute(): + loss[None] = x[None] * x[None] + + with qd.ad.FwdMode(loss=loss, param=x): + compute() + + assert loss[None] == 9.0 + assert loss.dual[None] == 6.0 From a5b414b7484697a09c240196721195050469e924 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:47:55 -0700 Subject: [PATCH 4/6] Remove inline comment about checkbit_dtype rename The regression test covers this; no need for a code comment. --- python/quadrants/lang/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 7ef6d18f0..0036a3bb7 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -760,7 +760,7 @@ def create_field_member(dtype, name, needs_grad, needs_dual): pyquadrants.grad_vars.append(x_grad) if prog.config().debug: - # adjoint checkbit — use a separate var to avoid shadowing the outer `dtype` + # adjoint checkbit x_grad_checkbit = Expr(prog.make_id_expr("")) checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: From 215544f15d1a8296f32d4da35110d095b5a8651d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:52:02 -0700 Subject: [PATCH 5/6] Remove redundant cook_dtype calls in make_constant_expr When dtype is provided it is already cooked at the top of the function, so the per-branch cook_dtype(constant_dtype) was a no-op. Now only the fallback default_fp/default_ip paths are cooked. --- python/quadrants/lang/expr.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index aedacc287..8112aa154 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -122,8 +122,7 @@ def make_constant_expr(val, dtype): return Expr(_qd_core.make_const_expr_bool(constant_dtype, val)) if isinstance(val, (float, np.floating)): - constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype - constant_dtype = cook_dtype(constant_dtype) + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_fp) if constant_dtype not in real_types: raise QuadrantsTypeError( "Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`." @@ -131,8 +130,7 @@ def make_constant_expr(val, dtype): return Expr(_qd_core.make_const_expr_fp(constant_dtype, val)) if isinstance(val, (int, np.integer)): - constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype - constant_dtype = cook_dtype(constant_dtype) + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_ip) if constant_dtype not in integer_types: raise QuadrantsTypeError( "Integer literals must be annotated with a integer type. For type casting, use `qd.cast`." From 2876fa15917ad8760cb55b376fa7aa921c03b08d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 11 Mar 2026 21:52:36 -0700 Subject: [PATCH 6/6] Add comment clarifying cook_dtype strategy in make_constant_expr --- python/quadrants/lang/expr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index 8112aa154..deee4de57 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -114,6 +114,8 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int: def make_constant_expr(val, dtype): + # Normalise dtype once up front so the per-branch fallbacks only need to + # cook the runtime defaults (default_fp / default_ip). if dtype is not None: dtype = cook_dtype(dtype)