Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions python/quadrants/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from quadrants.lang.common_ops import QuadrantsOperations
from quadrants.lang.exception import QuadrantsCompilationError, QuadrantsTypeError
from quadrants.lang.matrix import make_matrix
from quadrants.lang.util import is_matrix_class, is_quadrants_class, to_numpy_type
from quadrants.lang.util import (
cook_dtype,
is_matrix_class,
is_quadrants_class,
to_numpy_type,
)
from quadrants.types import primitive_types
from quadrants.types.primitive_types import integer_types, real_types

Expand Down Expand Up @@ -109,20 +114,25 @@ def _clamp_unsigned_to_range(npty, val: np.integer | int) -> np.integer | int:


def make_constant_expr(val, dtype):
# Normalise dtype once up front so the per-branch fallbacks only need to
# cook the runtime defaults (default_fp / default_ip).
if dtype is not None:
dtype = cook_dtype(dtype)

if isinstance(val, (bool, np.bool_)):
constant_dtype = primitive_types.u1
constant_dtype = cook_dtype(primitive_types.u1)
return Expr(_qd_core.make_const_expr_bool(constant_dtype, val))

if isinstance(val, (float, np.floating)):
constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype
constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_fp)
if constant_dtype not in real_types:
raise QuadrantsTypeError(
"Floating-point literals must be annotated with a floating-point type. For type casting, use `qd.cast`."
)
return Expr(_qd_core.make_const_expr_fp(constant_dtype, val))

if isinstance(val, (int, np.integer)):
constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype
constant_dtype = dtype if dtype is not None else cook_dtype(impl.get_runtime().default_ip)
if constant_dtype not in integer_types:
raise QuadrantsTypeError(
"Integer literals must be annotated with a integer type. For type casting, use `qd.cast`."
Expand Down
37 changes: 31 additions & 6 deletions python/quadrants/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
def expr_init_shared_array(shape, element_type):
    """Allocate a shared array of `shape` inside the currently-compiling
    callable and return the resulting alloca expression.

    `element_type` is cooked into a canonical dtype handle before it is
    handed to the AST builder.
    """
    builder = get_runtime().compiling_callable.ast_builder()
    dbg = _qd_core.DebugInfo(get_runtime().get_current_src_info())
    cooked_type = cook_dtype(element_type)
    return builder.expr_alloca_shared_array(shape, cooked_type, dbg)


Expand Down Expand Up @@ -355,9 +356,9 @@ def __init__(self, kernels=None):
self.grad_vars = []
self.dual_vars = []
self.matrix_fields = []
self.default_fp = f32
self.default_ip = i32
self.default_up = u32
self._default_fp = cook_dtype(f32)
self._default_ip = cook_dtype(i32)
self._default_up = cook_dtype(u32)
self.print_full_traceback: bool = False
self.target_tape = None
self.fwd_mode_manager = None
Expand All @@ -371,6 +372,30 @@ def __init__(self, kernels=None):
self.unrolling_limit: int = 0
self.src_ll_cache: bool = True

@property
def default_fp(self) -> DataTypeCxx:
    """Default floating-point dtype, always a cooked (canonical) handle."""
    return self._default_fp

@default_fp.setter
def default_fp(self, value: Any) -> None:
    # Cook on assignment so every reader sees the canonical DataType handle
    # rather than whichever Python-side alias the caller passed in.
    self._default_fp = cook_dtype(value)

@property
def default_ip(self) -> DataTypeCxx:
    """Default integer dtype, always a cooked (canonical) handle."""
    return self._default_ip

@default_ip.setter
def default_ip(self, value: Any) -> None:
    # Normalise eagerly; getters then return the cooked handle unchanged.
    self._default_ip = cook_dtype(value)

@property
def default_up(self) -> DataTypeCxx:
    # Default "up" dtype (presumably unsigned integer — initialised from u32;
    # confirm against callers). Always a cooked (canonical) handle.
    return self._default_up

@default_up.setter
def default_up(self, value: Any) -> None:
    # Normalise eagerly; getters then return the cooked handle unchanged.
    self._default_up = cook_dtype(value)

@property
def compiling_callable(self) -> KernelCxx | Kernel | Function:
if self._compiling_callable is None:
Expand Down Expand Up @@ -737,10 +762,10 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
if prog.config().debug:
# adjoint checkbit
x_grad_checkbit = Expr(prog.make_id_expr(""))
dtype = u8
checkbit_dtype = u8
if prog.config().arch == _qd_core.vulkan:
dtype = i32
x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype))
checkbit_dtype = i32
x_grad_checkbit.ptr = _qd_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype))
x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT)
x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr)
Expand Down
2 changes: 1 addition & 1 deletion python/quadrants/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def make_matrix(arr, dt=None):
if len(arr) == 0:
# the only usage of an empty vector is to serve as field indices
shape = [0]
dt = primitive_types.i32
dt = cook_dtype(primitive_types.i32)
else:
if isinstance(arr[0], Iterable): # matrix
shape = [len(arr), len(arr[0])]
Expand Down
11 changes: 7 additions & 4 deletions python/quadrants/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from quadrants.lang.exception import QuadrantsRuntimeError
from quadrants.lang.field import Field
from quadrants.lang.impl import get_runtime
from quadrants.lang.util import cook_dtype
from quadrants.types import f32


Expand All @@ -24,11 +25,12 @@ class SparseMatrix:
"""

def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"):
self.dtype = dtype
dtype_cxx = cook_dtype(dtype)
self.dtype = dtype_cxx
if sm is None:
self.n = n
self.m = m if m else n
self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format)
self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype_cxx, storage_format)
else:
self.n = sm.num_rows()
self.m = sm.num_cols()
Expand Down Expand Up @@ -247,7 +249,8 @@ def __init__(
):
self.num_rows = num_rows
self.num_cols = num_cols if num_cols else num_rows
self.dtype = dtype
dtype_cxx = cook_dtype(dtype)
self.dtype = dtype_cxx
if num_rows is not None:
quadrants_arch = get_runtime().prog.config().arch
if quadrants_arch in [
Expand All @@ -259,7 +262,7 @@ def __init__(
num_rows,
num_cols,
max_num_triplets,
dtype,
dtype_cxx,
storage_format,
)
self.ptr.create_ndarray(get_runtime().prog)
Expand Down
8 changes: 5 additions & 3 deletions python/quadrants/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from quadrants.lang.exception import QuadrantsRuntimeError
from quadrants.lang.field import Field
from quadrants.lang.impl import get_runtime
from quadrants.lang.util import cook_dtype
from quadrants.linalg.sparse_matrix import SparseMatrix
from quadrants.types.primitive_types import f32

Expand All @@ -24,7 +25,8 @@ class SparseSolver:

def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
self.matrix = None
self.dtype = dtype
dtype_cxx = cook_dtype(dtype)
self.dtype = dtype_cxx
solver_type_list = ["LLT", "LDLT", "LU"]
solver_ordering = ["AMD", "COLAMD"]
if solver_type in solver_type_list and ordering in solver_ordering:
Expand All @@ -35,9 +37,9 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
or quadrants_arch == _qd_core.Arch.cuda
), "SparseSolver only supports CPU and CUDA for now."
if quadrants_arch == _qd_core.Arch.cuda:
self.solver = _qd_core.make_cusparse_solver(dtype, solver_type, ordering)
self.solver = _qd_core.make_cusparse_solver(dtype_cxx, solver_type, ordering)
else:
self.solver = _qd_core.make_sparse_solver(dtype, solver_type, ordering)
self.solver = _qd_core.make_sparse_solver(dtype_cxx, solver_type, ordering)
else:
raise QuadrantsRuntimeError(
f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported."
Expand Down
19 changes: 19 additions & 0 deletions tests/python/test_ad_basics_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,22 @@ def clear_dual_test():
with qd.ad.FwdMode(loss=loss, param=x):
clear_dual_test()
assert y.dual[None] == 4.0


@test_utils.test(debug=True)
def test_dual_field_dtype_preserved_in_debug_mode():
    """Regression: debug-mode checkbit must not shadow the outer dtype."""
    # f64 fields with duals: if the debug-mode adjoint checkbit clobbered the
    # field's dtype, these would silently be created with the wrong type.
    x = qd.field(qd.f64, shape=(), needs_dual=True)
    loss = qd.field(qd.f64, shape=(), needs_dual=True)

    @qd.kernel
    def square_loss():
        loss[None] = x[None] * x[None]

    x[None] = 3.0
    with qd.ad.FwdMode(loss=loss, param=x):
        square_loss()

    # loss = x^2, so at x = 3: loss = 9 and d(loss)/dx = 2x = 6.
    assert loss[None] == 9.0
    assert loss.dual[None] == 6.0
Loading