diff --git a/python/quadrants/lang/expr.py b/python/quadrants/lang/expr.py index 0369349d6..deee4de57 100644 --- a/python/quadrants/lang/expr.py +++ b/python/quadrants/lang/expr.py @@ -7,7 +7,12 @@ from quadrants.lang.common_ops import QuadrantsOperations from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError from quadrants.lang.matrix import make_matrix -from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type +from quadrants.lang.util import ( + cook_dtype, + is_matrix_class, + is_quadrants_class, + to_numpy_type, +) from quadrants.types import primitive_types from quadrants.types.primitive_types import integer_types, real_types @@ -109,12 +114,17 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int: def make_constant_expr(val, dtype): + # Normalise dtype once up front so the per-branch fallbacks only need to + # cook the runtime defaults (default_fp / default_ip). + if dtype is not None: + dtype = cook_dtype(dtype) + if isinstance(val, (bool, np.bool_)): - constant_dtype = primitive_types.u1 + constant_dtype = cook_dtype(primitive_types.u1) return Expr(_qd_core.make_const_expr_bool(constant_dtype, val)) if isinstance(val, (float, np.floating)): - constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_fp) if constant_dtype not in real_types: raise QuadrantsTypeError( "Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`." @@ -122,7 +132,7 @@ def make_constant_expr(val, dtype): return Expr(_qd_core.make_const_expr_fp(constant_dtype, val)) if isinstance(val, (int, np.integer)): - constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype + constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_ip) if constant_dtype not in integer_types: raise QuadrantsTypeError( "Integer literals must be annotated with a integer type. For type casting, use `qd.cast`." diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 08144b9b1..0036a3bb7 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -83,6 +83,7 @@ def expr_init_shared_array(shape, element_type): ast_builder = get_runtime().compiling_callable.ast_builder() debug_info = _qd_core.DebugInfo(get_runtime().get_current_src_info()) + element_type = cook_dtype(element_type) return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info) @@ -355,9 +356,9 @@ def __init__(self, kernels=None): self.grad_vars = [] self.dual_vars = [] self.matrix_fields = [] - self.default_fp = f32 - self.default_ip = i32 - self.default_up = u32 + self._default_fp = cook_dtype(f32) + self._default_ip = cook_dtype(i32) + self._default_up = cook_dtype(u32) self.print_full_traceback: bool = False self.target_tape = None self.fwd_mode_manager = None @@ -371,6 +372,30 @@ def __init__(self, kernels=None): self.unrolling_limit: int = 0 self.src_ll_cache: bool = True + @property + def default_fp(self) -> DataTypeCxx: + return self._default_fp + + @default_fp.setter + def default_fp(self, value: Any) -> None: + self._default_fp = cook_dtype(value) + + @property + def default_ip(self) -> DataTypeCxx: + return self._default_ip + + @default_ip.setter + def default_ip(self, value: Any) -> None: + self._default_ip = cook_dtype(value) + + @property + def default_up(self) -> DataTypeCxx: + return self._default_up + + @default_up.setter + def default_up(self, value: Any) -> None: + self._default_up = cook_dtype(value) + @property def compiling_callable(self) -> KernelCxx | Kernel | Function: if self._compiling_callable is None: @@ -737,10 +762,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual): if prog.config().debug: # adjoint checkbit x_grad_checkbit = Expr(prog.make_id_expr("")) - dtype = u8 + checkbit_dtype = u8 if prog.config().arch == _qd_core.vulkan: - dtype = i32 - x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype)) + checkbit_dtype = i32 + x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype)) x_grad_checkbit.ptr.set_name(name + ".grad_checkbit") x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT) x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr) diff --git a/python/quadrants/lang/matrix.py b/python/quadrants/lang/matrix.py index 01bf6eef6..6f44a0947 100644 --- a/python/quadrants/lang/matrix.py +++ b/python/quadrants/lang/matrix.py @@ -176,7 +176,7 @@ def make_matrix(arr, dt=None): if len(arr) == 0: # the only usage of an empty vector is to serve as field indices shape = [0] - dt = primitive_types.i32 + dt = cook_dtype(primitive_types.i32) else: if isinstance(arr[0], Iterable): # matrix shape = [len(arr), len(arr[0])] diff --git a/python/quadrants/linalg/sparse_matrix.py b/python/quadrants/linalg/sparse_matrix.py index 7eb4f40be..09cfd75a3 100644 --- a/python/quadrants/linalg/sparse_matrix.py +++ b/python/quadrants/linalg/sparse_matrix.py @@ -9,6 +9,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.types import f32 @@ -24,11 +25,12 @@ class SparseMatrix: """ def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"): - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if sm is None: self.n = n self.m = m if m else n - self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format) + self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format) else: self.n = sm.num_rows() self.m = sm.num_cols() @@ -247,7 +249,8 @@ def __init__( ): self.num_rows = num_rows self.num_cols = num_cols if num_cols else num_rows - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx if num_rows is not None: quadrants_arch = get_runtime().prog.config().arch if quadrants_arch in [ @@ -259,7 +262,7 @@ def __init__( num_rows, num_cols, max_num_triplets, - dtype, + dtype_cxx, storage_format, ) self.ptr.create_ndarray(get_runtime().prog) diff --git a/python/quadrants/linalg/sparse_solver.py b/python/quadrants/linalg/sparse_solver.py index 3544d1a95..e66de69b0 100644 --- a/python/quadrants/linalg/sparse_solver.py +++ b/python/quadrants/linalg/sparse_solver.py @@ -8,6 +8,7 @@ from quadrants.lang.exception import QuadrantsRuntimeError from quadrants.lang.field import Field from quadrants.lang.impl import get_runtime +from quadrants.lang.util import cook_dtype from quadrants.linalg.sparse_matrix import SparseMatrix from quadrants.types.primitive_types import f32 @@ -24,7 +25,8 @@ class SparseSolver: def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): self.matrix = None - self.dtype = dtype + dtype_cxx = cook_dtype(dtype) + self.dtype = dtype_cxx solver_type_list = ["LLT", "LDLT", "LU"] solver_ordering = ["AMD", "COLAMD"] if solver_type in solver_type_list and ordering in solver_ordering: @@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): or quadrants_arch == _qd_core.Arch.cuda ), "SparseSolver only supports CPU and CUDA for now." if quadrants_arch == _qd_core.Arch.cuda: - self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering) else: - self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering) + self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering) else: raise QuadrantsRuntimeError( f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported." diff --git a/tests/python/test_ad_basics_fwd.py b/tests/python/test_ad_basics_fwd.py index fc37ef582..760db9a7c 100644 --- a/tests/python/test_ad_basics_fwd.py +++ b/tests/python/test_ad_basics_fwd.py @@ -124,3 +124,22 @@ def clear_dual_test(): with qd.ad.FwdMode(loss=loss, param=x): clear_dual_test() assert y.dual[None] == 4.0 + + +@test_utils.test(debug=True) +def test_dual_field_dtype_preserved_in_debug_mode(): + """Regression: debug-mode checkbit must not shadow the outer dtype.""" + x = qd.field(qd.f64, shape=(), needs_dual=True) + loss = qd.field(qd.f64, shape=(), needs_dual=True) + + x[None] = 3.0 + + @qd.kernel + def compute(): + loss[None] = x[None] * x[None] + + with qd.ad.FwdMode(loss=loss, param=x): + compute() + + assert loss[None] == 9.0 + assert loss.dual[None] == 6.0