From c7d673724ddd4616e86270bcd84c334b2f0f0c3a Mon Sep 17 00:00:00 2001 From: hugh Date: Sun, 17 May 2026 16:25:15 +0000 Subject: [PATCH 01/46] [wip] preserved baseline: stable_members mitigations + new failing test Baseline state of branch is the perf-mitigations work (cache bound callable, opt-in stable_members short-circuit for spec-key and args-hash walks, skip per-call _BoundedDifferentiableMethod alloc) plus a new test file pinning down the failure mode when calling a @qd.func taking a typed-dataclass arg from inside a @qd.data_oriented method that passes self.dataclass_member. The baseline (typed-dataclass kernel arg + @qd.func) passes. The four data_oriented variants all fail. --- python/quadrants/lang/_quadrants_callable.py | 19 +- python/quadrants/lang/_template_mapper.py | 9 + .../lang/_template_mapper_hotpath.py | 5 + python/quadrants/lang/kernel.py | 10 +- python/quadrants/lang/kernel_impl.py | 40 +++- .../test_data_oriented_qd_func_dataclass.py | 179 ++++++++++++++++++ 6 files changed, 251 insertions(+), 11 deletions(-) create mode 100644 tests/python/test_data_oriented_qd_func_dataclass.py diff --git a/python/quadrants/lang/_quadrants_callable.py b/python/quadrants/lang/_quadrants_callable.py index ba7e7b8217..1bd50efe04 100644 --- a/python/quadrants/lang/_quadrants_callable.py +++ b/python/quadrants/lang/_quadrants_callable.py @@ -90,15 +90,32 @@ def __init__(self, fn: Callable, wrapper: Callable) -> None: self._adjoint: "Kernel | None" = None self.grad: "Kernel | None" = None self.is_pure: bool = False + self._attr_name: str | None = None update_wrapper(self, fn) + def __set_name__(self, owner: type, name: str) -> None: + # Captured at class-body time. ``data_oriented.make_kernel_indirect`` sets this + # explicitly on its replacement callable since setattr-after-class doesn't trigger + # __set_name__. + self._attr_name = name + def __call__(self, *args, **kwargs): return self.wrapper.__call__(*args, **kwargs) def __get__(self, instance, owner): if instance is None: return self - return BoundQuadrantsCallable(instance, self) + bound = BoundQuadrantsCallable(instance, self) + # Non-data descriptor (no __set__): a __dict__ entry on the instance wins over the + # descriptor on subsequent attribute lookups. Stash the bound callable there so future + # ``instance.method`` accesses skip __get__ allocation entirely (~0.6-1.2 us/call). + # Skip if the class uses __slots__ (no __dict__) or the attribute name wasn't captured. + name = self._attr_name + if name is not None: + inst_dict = getattr(instance, "__dict__", None) + if inst_dict is not None: + inst_dict[name] = bound + return bound class BoundQuadrantsCallable: diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index fd45b9913e..a700000d65 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -98,6 +98,15 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl nd_ids: list = [] for arg in args: if is_data_oriented(arg): + # Opt-out: classes that promise their ndarray members never reassign between calls + # (set ``_qd_stable_members = True`` on the class, or use + # ``@qd.data_oriented(stable_members=True)``) skip the per-call walk. The spec key + # then falls back to weakref(arg) alone — see _extract_arg's data_oriented branch. + # Saves ~1.1-1.5 us per kernel call on Genesis-style containers. Reassigning a + # member on a stable-marked instance is silently undefined behaviour: the cached + # kernel for the prior shape will be reused. + if type(arg).__dict__.get("_qd_stable_members"): + continue _collect_data_oriented_nd_ids(arg, nd_ids) if nd_ids: args_hash = args_hash + tuple(nd_ids) diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index 6df1b54358..18c00e2d62 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -214,6 +214,11 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati # # Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is # a no-op for the existing data_oriented + qd.field workloads (genesis field-backend). + # + # Opt-out: ``_qd_stable_members = True`` on the class (or + # ``@qd.data_oriented(stable_members=True)``) skips the per-call descriptor walk. + if type(arg).__dict__.get("_qd_stable_members"): + return weakref.ref(arg) nd_descriptors: list = [] _collect_struct_nd_descriptors(arg, nd_descriptors) if nd_descriptors: diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 6b636e717d..be8c96eca4 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -476,10 +476,18 @@ def launch_kernel( if self._struct_ndarray_launch_info_by_key: struct_nd_info = self._struct_ndarray_launch_info_by_key.get(key) if struct_nd_info: + # Data_oriented containers marked ``_qd_stable_members = True`` (or decorated + # with ``@qd.data_oriented(stable_members=True)``) promise their ndarray + # members are never reassigned, so we exclude them from the per-call + # ``_resolve_struct_ndarray`` walk that builds ``args_hash``. self._mutable_nd_cached_val = [ (idx, chain) for _, idx, chain in struct_nd_info - if type(args[idx]).__hash__ is None or is_data_oriented(args[idx]) + if type(args[idx]).__hash__ is None + or ( + is_data_oriented(args[idx]) + and not type(args[idx]).__dict__.get("_qd_stable_members") + ) ] else: self._mutable_nd_cached_val = [] diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 01c74a256d..9270050e74 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -275,7 +275,7 @@ def grad(self, *args, **kwargs) -> "Kernel": return self._adjoint(self._kernel_owner, *args, **kwargs) -def data_oriented(cls): +def data_oriented(cls=None, *, stable_members: bool = False): """Marks a class as Quadrants compatible. To allow for modularized code, Quadrants provides this decorator so that @@ -299,21 +299,41 @@ def data_oriented(cls): >>> a.inc() Args: - cls (Class): the class to be decorated + cls (Class): the class to be decorated. + stable_members (bool): if ``True``, declares that the class's ndarray-typed members are + allocated once and never reassigned between kernel calls. Quadrants will skip a + per-call walk of the instance's attributes (~1-2 us/call savings on Genesis-style + containers with several ndarray attrs). Reassigning a member on a ``stable_members`` + class is undefined behaviour — the previously-compiled kernel will be reused even if + the new ndarray has different dtype/ndim/layout. May also be set as a class-level + attribute ``_qd_stable_members = True`` (equivalent). Returns: - The decorated class. + The decorated class (or, when called with arguments, a decorator). """ + if cls is None: + return lambda c: data_oriented(c, stable_members=stable_members) + + def make_kernel_indirect(fun, is_property, attr_name): + # Capture the primal at decoration time so the per-call path skips the + # ``_BoundedDifferentiableMethod`` allocation. The class itself is validated when + # ``_BoundedDifferentiableMethod`` is invoked via the `.grad()` path; for the common + # primal call here we replicate the check inline. + primal = fun._primal - def make_kernel_indirect(fun, is_property): @wraps(fun) def _kernel_indirect(self, *args, **kwargs): - nonlocal fun - ret = _BoundedDifferentiableMethod(self, fun) - ret.__name__ = fun.__name__ # type: ignore - return ret(*args, **kwargs) + try: + return primal(self, *args, **kwargs) + except (QuadrantsCompilationError, QuadrantsRuntimeError) as e: + if impl.get_runtime().print_full_traceback: + raise e + raise type(e)("\n" + str(e)) from None ret = QuadrantsCallable(fun, _kernel_indirect) + # setattr-after-class doesn't trigger __set_name__; set the name explicitly so + # QuadrantsCallable.__get__ can cache the BoundQuadrantsCallable on instance.__dict__. + ret._attr_name = attr_name if is_property: ret = property(ret) return ret @@ -331,8 +351,10 @@ def _kernel_indirect(self, *args, **kwargs): if isinstance(fun, (BoundQuadrantsCallable, QuadrantsCallable)): if fun._is_wrapped_kernel: if fun._is_classkernel and attr_type is not staticmethod: - setattr(cls, name, make_kernel_indirect(fun, is_property)) + setattr(cls, name, make_kernel_indirect(fun, is_property, name)) cls._data_oriented = True + if stable_members: + cls._qd_stable_members = True return cls diff --git a/tests/python/test_data_oriented_qd_func_dataclass.py b/tests/python/test_data_oriented_qd_func_dataclass.py new file mode 100644 index 0000000000..3733d54982 --- /dev/null +++ b/tests/python/test_data_oriented_qd_func_dataclass.py @@ -0,0 +1,179 @@ +"""Tests for calling @qd.func that takes a typed-dataclass arg, from a @qd.kernel +method of a @qd.data_oriented class, passing ``self.dataclass_member`` as the arg. + +Genesis's @qd.func helpers declare typed-dataclass parameters (e.g. +``def func(links_state: LinksState, ...):``) and are designed to be called from kernels +that also take typed-dataclass kernel args (so the dataclass is flattened into per-leaf +kernel-locals on both sides of the call boundary). + +When migrating Genesis modules to @qd.data_oriented, we'd like to call the same @qd.func +helpers from a data_oriented kernel method, passing ``self.links_state`` as the arg. +Today this fails at AST resolution: + + Missing argument '__qd_links_state__qd_cinr_inertial'. + Unexpected argument 'links_state'. + +These tests pin down the failure modes so we can fix them. +""" + +import dataclasses + +import numpy as np +import pytest + +import quadrants as qd +from tests import test_utils + + +# ----- typed-dataclass kernel-arg baseline (works) ---------------------------- + +@test_utils.test(arch=qd.cpu) +def test_baseline_typed_dataclass_kernel_arg_calls_qd_func(): + """Baseline: typed-dataclass kernel arg + qd.func taking same dataclass type — works.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + y: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.kernel + def run(state: State): + for i in range(N): + write_x(state, i, i * 3) + + state = State( + x=qd.ndarray(qd.i32, shape=(N,)), + y=qd.ndarray(qd.i32, shape=(N,)), + ) + run(state) + np.testing.assert_array_equal(state.x.to_numpy(), np.arange(N) * 3) + + +# ----- data_oriented self-method calling qd.func (the broken case) ----------- + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_dataclass_member(): + """data_oriented holds a dataclass; self-kernel calls a @qd.func taking that dataclass.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + y: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State( + x=qd.ndarray(qd.i32, shape=(N,)), + y=qd.ndarray(qd.i32, shape=(N,)), + ) + + @qd.kernel + def run(self): + for i in range(N): + write_x(self.state, i, i * 5) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 5) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_keyword_dataclass_member(): + """Same as above but the qd.func arg is passed by keyword (Genesis pattern).""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_x(state=self.state, i=i, v=i * 7) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 7) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_stable_members_method_calls_qd_func_with_dataclass_member(): + """Same as above but with stable_members=True (the FPS-relevant case).""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_x(state=self.state, i=i, v=i * 11) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 11) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_two_dataclass_members(): + """Two dataclass members, qd.func takes both — Genesis-shaped scenario.""" + N = 4 + + @dataclasses.dataclass + class StateA: + a: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class StateB: + b: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_both(sa: StateA, sb: StateB, i: qd.i32, va: qd.i32, vb: qd.i32): + sa.a[i] = va + sb.b[i] = vb + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.sa = StateA(a=qd.ndarray(qd.i32, shape=(N,))) + self.sb = StateB(b=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_both(sa=self.sa, sb=self.sb, i=i, va=i * 2, vb=i * 13) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.sa.a.to_numpy(), np.arange(N) * 2) + np.testing.assert_array_equal(solver.sb.b.to_numpy(), np.arange(N) * 13) From 93f597e3b7b452911034e719982ddbb95ccd63c1 Mon Sep 17 00:00:00 2001 From: hugh Date: Sun, 17 May 2026 17:12:10 +0000 Subject: [PATCH 02/46] [fix] Option A: expand dataclass-instance args in @qd.func calls from data_oriented self @qd.func helpers with typed-dataclass parameters were unreachable from @qd.data_oriented kernel methods that wanted to pass self.dataclass_member: the caller-side AST expansion in _expand_Call_dataclass_args / _expand_Call_dataclass_kwargs only fired for dataclass *types* attached to bare ast.Name nodes (typed kernel args), not for dataclass *instances* attached to ast.Attribute nodes (self.X access). Extend both expansion paths to recognise the instance-of-dataclass case and emit per-leaf ast.Attribute children. The positional path additionally threads the callee parameter name and callee_needed set through, so callee-side pruning of unused dataclass fields stays consistent with caller-side emission. Tests in tests/python/test_data_oriented_qd_func_dataclass.py: - baseline typed-arg + qd.func call (passes today) - data_oriented method + qd.func with positional dataclass member - ... with keyword dataclass member - ... with stable_members=True - ... with two dataclass members (Genesis-shaped) All 5 pass. Design doc: perso_hugh/doc/data_oriented_qd_func_dataclass.md (Option A chosen). --- .../ast/ast_transformers/call_transformer.py | 118 +++++++++++++++++- 1 file changed, 114 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformers/call_transformer.py b/python/quadrants/lang/ast/ast_transformers/call_transformer.py index 0d709ebd01..d9ae1a66b7 100644 --- a/python/quadrants/lang/ast/ast_transformers/call_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/call_transformer.py @@ -166,17 +166,26 @@ def _canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywo @staticmethod def _expand_Call_dataclass_args( - ctx: ASTTransformerFuncContext, args: tuple[ast.stmt, ...] + ctx: ASTTransformerFuncContext, + args: tuple[ast.stmt, ...], + called_needed: set[str] | None = None, + callee_arg_names: list[str] | None = None, ) -> tuple[tuple[ast.stmt, ...], tuple[ast.stmt, ...]]: """ We require that each node has a .ptr attribute added to it, that contains - the associated Python object + the associated Python object. + + ``called_needed`` and ``callee_arg_names`` are used only for the + attribute-accessed-instance branch (Option A for data_oriented @qd.func calls): + the caller cannot construct a flat name from its own ``arg.id`` (the arg is + an ast.Attribute), so we look up pruning against the callee's parameter name + at the same positional index. """ args_new = [] added_args = [] pruning = ctx.global_context.pruning func_id = ctx.func.func_id - for arg in args: + for arg_idx, arg in enumerate(args): val = arg.ptr if dataclasses.is_dataclass(val) and isinstance(val, type): dataclass_type = val @@ -204,6 +213,56 @@ def _expand_Call_dataclass_args( else: args_new.append(arg_node) added_args.append(arg_node) + elif dataclasses.is_dataclass(val) and not isinstance(val, type): + # Dataclass *instance* passed positionally (e.g. ``self.state`` inside a + # @qd.data_oriented kernel method). Expand into per-leaf attribute accesses + # against the same AST node, mirroring the typed-arg (instance-of-type) path + # above but emitting ``ast.Attribute`` children rather than ``ast.Name``. + # ``added_args`` items must not carry ``.ptr`` (build_stmt populates it + # downstream); only the intermediate node used for recursion does. + dataclass_type = type(val) + # For pruning, match the callee's flat name (it may have pruned unused + # fields). Use the callee's parameter name at this positional index. + callee_param = ( + callee_arg_names[arg_idx] + if (called_needed is not None and callee_arg_names is not None and arg_idx < len(callee_arg_names)) + else None + ) + for field in dataclasses.fields(dataclass_type): + if called_needed is not None and callee_param is not None: + callee_flat_name = create_flat_name(callee_param, field.name) + if callee_flat_name not in called_needed: + continue + child_val = getattr(val, field.name) + load_ctx = ast.Load() + child_node = ast.Attribute( + value=arg, + attr=field.name, + ctx=load_ctx, + lineno=arg.lineno, + end_lineno=arg.end_lineno, + col_offset=arg.col_offset, + end_col_offset=arg.end_col_offset, + ) + if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): + child_node.ptr = child_val + # Recurse, threading the renamed scope: the callee's expanded flat + # name (e.g. ``__qd_state__inner``) is the synthetic param name for + # the nested level. + nested_callee_param = ( + create_flat_name(callee_param, field.name) if callee_param is not None else None + ) + _added_args, _args_new = CallTransformer._expand_Call_dataclass_args( + ctx, + (child_node,), + called_needed=called_needed, + callee_arg_names=[nested_callee_param] if nested_callee_param is not None else None, + ) + args_new.extend(_args_new) + added_args.extend(_added_args) + else: + args_new.append(child_node) + added_args.append(child_node) else: args_new.append(arg) return tuple(added_args), tuple(args_new) @@ -261,6 +320,47 @@ def _expand_Call_dataclass_kwargs( else: kwargs_new.append(kwarg_node) added_kwargs.append(kwarg_node) + elif dataclasses.is_dataclass(val) and not isinstance(val, type): + # Dataclass *instance* passed as a keyword arg (e.g. + # ``write(state=self.state)`` inside a @qd.data_oriented kernel method). + # Expand into per-leaf keyword args whose values are attribute accesses + # against the original value node (e.g. ``__qd_state__x=self.state.x``). + dataclass_type = type(val) + for field in dataclasses.fields(dataclass_type): + child_name = create_flat_name(kwarg.arg, field.name) + if used_args is not None and child_name not in used_args: + continue + child_val = getattr(val, field.name) + load_ctx = ast.Load() + src_node = ast.Attribute( + value=kwarg.value, + attr=field.name, + ctx=load_ctx, + lineno=kwarg.lineno, + end_lineno=kwarg.end_lineno, + col_offset=kwarg.col_offset, + end_col_offset=kwarg.end_col_offset, + ) + src_node.ptr = child_val + kwarg_node = ast.keyword( + arg=child_name, + value=src_node, + ctx=load_ctx, + lineno=kwarg.lineno, + end_lineno=kwarg.end_lineno, + col_offset=kwarg.col_offset, + end_col_offset=kwarg.end_col_offset, + ) + if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): + kwarg_node.ptr = {child_name: child_val} + _added_kwargs, _kwargs_new = CallTransformer._expand_Call_dataclass_kwargs( + ctx, [kwarg_node], used_args + ) + kwargs_new.extend(_kwargs_new) + added_kwargs.extend(_added_kwargs) + else: + kwargs_new.append(kwarg_node) + added_kwargs.append(kwarg_node) else: kwargs_new.append(kwarg) return added_kwargs, kwargs_new @@ -286,11 +386,21 @@ def build_Call(ctx: ASTTransformerFuncContext, node: ast.Call, build_stmt, build is_func_base_wrapper = func_type in {QuadrantsCallable, BoundQuadrantsCallable} pruning = ctx.global_context.pruning called_needed = None + callee_arg_names: list[str] | None = None if pruning.enforcing and is_func_base_wrapper: called_func_id_ = func.wrapper.func_id # type: ignore called_needed = pruning.used_vars_by_func_id[called_func_id_] + if is_func_base_wrapper: + # callee param names (used by the attribute-instance positional-expansion path + # so it can match the callee's already-pruned flat names). + try: + callee_arg_names = [m.name for m in func.wrapper.arg_metas] # type: ignore[attr-defined] + except AttributeError: + callee_arg_names = None - added_args, node_args = CallTransformer._expand_Call_dataclass_args(ctx, node.args) + added_args, node_args = CallTransformer._expand_Call_dataclass_args( + ctx, node.args, called_needed=called_needed, callee_arg_names=callee_arg_names + ) added_keywords, node_keywords = CallTransformer._expand_Call_dataclass_kwargs(ctx, node.keywords, called_needed) # Create variables for the now-expanded dataclass members. From c25f49cc164023ec59ef705d0bde1dc8dbb5d318 Mon Sep 17 00:00:00 2001 From: hugh Date: Sun, 17 May 2026 17:26:37 +0000 Subject: [PATCH 03/46] [test] nested dataclasses + chained @qd.func calls from data_oriented self Adds 4 tests: - nested dataclass (Outer{Inner{ndarray}}) passed via self.outer, positional - nested dataclass passed via self.outer, kwarg (stable_members=True) - two-step @qd.func chain (outer_write -> inner_write) with self.state - combined: nested dataclass threaded through a 2-step @qd.func chain All pass. The outermost data_oriented call site uses the new instance-of-dataclass branch (with recursion threading callee_param); inner qd.func -> qd.func calls use the original typed-arg expansion path unchanged. --- .../test_data_oriented_qd_func_dataclass.py | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/tests/python/test_data_oriented_qd_func_dataclass.py b/tests/python/test_data_oriented_qd_func_dataclass.py index 3733d54982..4aee084f28 100644 --- a/tests/python/test_data_oriented_qd_func_dataclass.py +++ b/tests/python/test_data_oriented_qd_func_dataclass.py @@ -177,3 +177,146 @@ def run(self): solver.run() np.testing.assert_array_equal(solver.sa.a.to_numpy(), np.arange(N) * 2) np.testing.assert_array_equal(solver.sb.b.to_numpy(), np.arange(N) * 13) + + +# ----- nested dataclass -------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_nested_dataclass_member(): + """data_oriented holds an Outer{ Inner{ ndarray } } and passes ``self.outer`` to a + @qd.func that expands the nested dataclass into flat leaves on both sides.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def write_inner_x(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + write_inner_x(self.outer, i, i * 17) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 17) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_nested_dataclass_kwarg(): + """Same as above but the dataclass arg is passed by keyword.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def write_inner_x(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + write_inner_x(outer=self.outer, i=i, v=i * 19) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 19) + + +# ----- chained @qd.func calls (qd.func -> qd.func, dataclass threaded through) --- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_qd_func_chain_with_dataclass_member(): + """data_oriented kernel calls outer @qd.func, which in turn calls inner @qd.func, + threading the same dataclass arg through. Both qd.funcs have the typed-dataclass + parameter; only the outermost call site (data_oriented method body) uses self.X. + The two inner call sites use the typed-arg path that already worked.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def inner_write(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.func + def outer_write(state: State, i: qd.i32, v: qd.i32): + inner_write(state, i, v) + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + outer_write(self.state, i, i * 23) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 23) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_qd_func_chain_with_nested_dataclass_member(): + """Combination: nested dataclass passed through a chain of two @qd.func calls + from a @qd.data_oriented self-method via self.outer.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def inner_write(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.func + def outer_write(outer: Outer, i: qd.i32, v: qd.i32): + inner_write(outer, i, v) + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + outer_write(self.outer, i, i * 29) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 29) From 8f64016afbbc7c2bc161098033ef19c036e96173 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 17 May 2026 11:31:59 -0700 Subject: [PATCH 04/46] [Perf] Prune unused @qd.data_oriented ndarrays via existing pruning machinery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a @qd.data_oriented `self` is passed as a `qd.template()` kernel arg, `_predeclare_struct_ndarrays` walks the entire object graph and registers every reachable ndarray as a kernel parameter. For real-world classes (e.g. Genesis's RigidSolver) that's hundreds of ndarrays per kernel, even when the kernel only touches a few — every extra arg slows down each launch's launch-context population. Hook into the same 2-pass compile machinery that prunes typed-dataclass arg flat-names: - Pass 0 (non-enforcing): `_predeclare_struct_ndarrays` registers every reachable ndarray as today. `_promote_ndarray_if_declared` now records `id(ndarray)` in `pruning.used_struct_ndarray_ids` whenever an attribute chain like `self.x.y` resolves to one of these pre-declared ndarrays — both for direct accesses in the kernel body and for accesses inside inlined `@qd.func` bodies. - Pass 1 (enforcing): `_predeclare_struct_ndarrays` only registers ndarrays whose id was observed in pass 0. Unused ndarrays are dropped from the kernel's parameter list and from `struct_ndarray_launch_info`, so neither compile nor each launch pays for them. On a Genesis non-batched single-Franka CPU rigid step with `RigidSolver` migrated to `@qd.data_oriented(stable_members=True)`: - step_1 ndarray-args: 326 -> 217 (-109) - step_2 ndarray-args: 326 -> 145 (-181) - steady-state step time: 493 us -> 403 us (FPS 2030 -> 2482) Fastcache hit (pass-0 skipped) is gated via `pruning.pass_0_ran`: the set is unreliable in that case so we fall back to registering every reachable ndarray, matching historical behavior. --- python/quadrants/lang/_pruning.py | 11 +++++++++++ python/quadrants/lang/ast/ast_transformer.py | 17 ++++++++++++++--- .../function_def_transformer.py | 17 +++++++++++++++++ python/quadrants/lang/kernel.py | 2 ++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index 3289365767..b5b1f97a27 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -39,6 +39,17 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_used_parameters) # only needed for args, not kwargs self.callee_param_by_caller_arg_name_by_func_id: dict[int, dict[str, str]] = defaultdict(dict) + # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. + # Populated by the AST builder when a chain like ``self.x.y`` resolves to an ndarray + # that was pre-declared by ``_predeclare_struct_ndarrays``. On the second (enforcing) + # pass, ``_predeclare_struct_ndarrays`` only registers ndarrays whose id is in this set + # — dropping every reachable-but-unused ndarray from the kernel's parameter list. + self.used_struct_ndarray_ids: set[int] = set() + # Whether the non-enforcing first pass actually ran for this kernel materialize. + # When fastcache hits, we skip pass 0 entirely and ``used_struct_ndarray_ids`` is + # therefore unreliable — in that case ``_predeclare_struct_ndarrays`` falls back to + # registering every reachable ndarray (same as the historical behavior). + self.pass_0_ran: bool = False def mark_used(self, func_id: int, parameter_flat_name: str) -> None: assert not self.enforcing diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 263a4a11a3..f22543f32a 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -656,14 +656,25 @@ def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerFuncContext, n @staticmethod def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> Any: """If *value* is a bare ``Ndarray`` that was pre-declared as a kernel arg (in ``_predeclare_struct_ndarrays``), - return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged.""" + return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged. + + Also records the ndarray id in ``pruning.used_struct_ndarray_ids`` on the non-enforcing + first pass, so that the enforcing second-pass ``_predeclare_struct_ndarrays`` can skip + ndarrays that the kernel never actually accesses. + """ from quadrants.lang._ndarray import Ndarray # pylint: disable=C0415 if not isinstance(value, Ndarray): return value cache = ctx.global_context.ndarray_to_any_array - arr = cache.get(id(value)) - return arr if arr is not None else value + key = id(value) + arr = cache.get(key) + if arr is not None: + pruning = ctx.global_context.pruning + if not pruning.enforcing: + pruning.used_struct_ndarray_ids.add(key) + return arr + return value @staticmethod def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 1bdd14dbd8..40f3750163 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -227,11 +227,26 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: Also stores ``(arg_id, template_arg_idx, attr_chain)`` tuples in ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the launch context. + + Pruning: in the enforcing (second) compile pass, ``pruning.used_struct_ndarray_ids`` + contains the set of ``id(ndarray)`` values that ``_promote_ndarray_if_declared`` + observed being accessed during the first pass (directly in the kernel body, or + transitively through ``@qd.func`` inlining). We register only those, dropping every + unused ndarray from the kernel's parameter list. On the first pass the set is empty + / not yet populated, so we register everything as today (correctness: the first + pass needs every reachable ndarray in the cache for ``build_Attribute`` to resolve + the accesses that *will* populate the set). """ from quadrants.lang.util import cook_dtype # pylint: disable=C0415 cache = ctx.global_context.ndarray_to_any_array launch_info = ctx.global_context.struct_ndarray_launch_info + pruning = ctx.global_context.pruning + used_ids = getattr(pruning, "used_struct_ndarray_ids", None) + # Only prune on the enforcing pass when we actually ran pass 0 to populate the + # used-ndarray set. On a fastcache hit pass 0 is skipped and the set is empty — + # fall back to registering every reachable ndarray. + prune = pruning.enforcing and used_ids is not None and getattr(pruning, "pass_0_ran", False) def _walk_obj(obj, arg_idx, path): if dataclasses.is_dataclass(obj) and not isinstance(obj, type): @@ -258,6 +273,8 @@ def _register_ndarray(nd, arg_idx, attr_chain): key = id(nd) if key in cache: return + if prune and key not in used_ids: + return from quadrants._lib import core as _qd_core # pylint: disable=C0415 element_type = cook_dtype(nd.element_type) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index be8c96eca4..5ce7a936b4 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -403,6 +403,8 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . range_begin = 0 if used_py_dataclass_parameters is None else 1 runtime = impl.get_runtime() for _pass in range(range_begin, 2): + if _pass == 0: + pruning.pass_0_ran = True if _pass >= 1: pruning.enforce() tree, ctx = self.get_tree_and_ctx( From aa9a88fb7c368145b92d1e73e81018053850ff19 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 17 May 2026 13:37:42 -0700 Subject: [PATCH 05/46] [Fix] Fastcache hasher: skip QuadrantsCallable/BoundQuadrantsCallable in data_oriented walk Mitigation 1 (perf branch) stashes a per-instance BoundQuadrantsCallable in instance.__dict__ on first instance.method access so subsequent lookups skip __get__ allocation. The fastcache args-hasher's @qd.data_oriented walk iterates over obj.__dict__ and previously fell through to the [FASTCACHE][PARAM_INVALID] warning when it encountered that cached entry, disabling fastcache for the whole call (reproduced by test_fastcache_kernel_parameter). These descriptor-cache entries are not data; skip them in the walk so the fastcache key only reflects real members. --- python/quadrants/lang/_fast_caching/args_hasher.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 1a949d3007..455d51667c 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -12,6 +12,7 @@ from quadrants.types.annotations import Template from .._ndarray import ScalarNdarray +from .._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from ..field import ScalarField from ..kernel_arguments import ArgMetadata from ..matrix import MatrixField, MatrixNdarray, VectorNdarray @@ -182,6 +183,13 @@ def stringify_obj_type( except AttributeError: _dict = obj.__dict__ for k, v in _dict.items(): + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` + # stashes the per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so + # that subsequent ``instance.method`` lookups skip the descriptor allocation; + # those entries are not data and must not invalidate the fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) if _child_repr is None: if _should_warn: From fd8c4402a8ba7a92a49fe0360f2a12a68cd04f0a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 17 May 2026 13:55:46 -0700 Subject: [PATCH 06/46] [Perf] Don't over-mark ndarrays during @qd.func dataclass-arg expansion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mitigation 5's first cut over-conservatively marked every ndarray reachable from a wholesale-passed dataclass: Option A in call_transformer expands func(self.dc) to per-leaf children func(self.dc.x, self.dc.y, ...), build_stmt runs on each, and _promote_ndarray_if_declared was marking the id as used regardless of whether the callee actually touches it. This left ~205 unused ndarray args still registered per step in the Genesis rigid_solver migration. Two coordinated fixes: 1. Mirror build_Name's expanding_dataclass_call_parameters gate in _promote_ndarray_if_declared. The leaf accesses synthesized by Option A don't represent the kernel body actually touching the ndarray — only the callee body's own accesses (which run with the flag = False) should count. 2. Tag each pre-declared ndarray's AnyArray proxy with _qd_source_ndarray_id. After Option A's expansion, the callee's typed-arg flat-name locals are bound to already-promoted AnyArrays, so when the inlined callee body accesses them, the value reaching _promote_ndarray_if_declared isn't an Ndarray anymore. Tagging lets us mark via the AnyArray too. On Genesis non-batched single-Franka CPU with rigid_solver migrated to @qd.data_oriented(stable_members=True): - step_1 ndarray-args: 217 -> 120 (matches baseline exactly) - step_2 ndarray-args: 145 -> 37 (matches baseline exactly) - total ndarray-args/step: 644 -> 439 (matches baseline exactly) - steady-state step time: 403 us -> 337 us (vs baseline 338-345 us) The migration is now performance-neutral (was -33% FPS, then -22%, now ~0%). 1173 tests pass; the same 8 quadrants-main pre-existing failures remain (4x test_ad_global_data_access_rule_checker, etc.). --- python/quadrants/lang/ast/ast_transformer.py | 39 +++++++++++++------ .../function_def_transformer.py | 4 ++ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index f22543f32a..977eec7fca 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -658,22 +658,37 @@ def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> """If *value* is a bare ``Ndarray`` that was pre-declared as a kernel arg (in ``_predeclare_struct_ndarrays``), return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged. - Also records the ndarray id in ``pruning.used_struct_ndarray_ids`` on the non-enforcing - first pass, so that the enforcing second-pass ``_predeclare_struct_ndarrays`` can skip - ndarrays that the kernel never actually accesses. + Also records the source ndarray id in ``pruning.used_struct_ndarray_ids`` on the + non-enforcing first pass, so that the enforcing second-pass + ``_predeclare_struct_ndarrays`` can skip ndarrays that the kernel never actually + accesses. Both ``Ndarray`` instances and pre-existing ``AnyArray`` proxies (tagged + with ``_qd_source_ndarray_id``) are handled — the latter is the case for accesses + in inlined ``@qd.func`` bodies whose params were bound to already-promoted proxies + by Option A in ``call_transformer``. """ from quadrants.lang._ndarray import Ndarray # pylint: disable=C0415 - if not isinstance(value, Ndarray): + pruning = ctx.global_context.pruning + # Mirror ``build_Name``'s mark_used gate: only mark on the non-enforcing first pass + # and not during synthetic per-leaf argument expansion for ``@qd.func`` calls. The + # callee body's own accesses (which run with ``expanding_dataclass_call_parameters + # = False``) are what we want to count. + should_mark = not pruning.enforcing and not ctx.expanding_dataclass_call_parameters + if isinstance(value, Ndarray): + cache = ctx.global_context.ndarray_to_any_array + key = id(value) + arr = cache.get(key) + if arr is not None: + if should_mark: + pruning.used_struct_ndarray_ids.add(key) + return arr return value - cache = ctx.global_context.ndarray_to_any_array - key = id(value) - arr = cache.get(key) - if arr is not None: - pruning = ctx.global_context.pruning - if not pruning.enforcing: - pruning.used_struct_ndarray_ids.add(key) - return arr + # Pre-promoted ``AnyArray`` flowing through an inlined ``@qd.func`` body. Mark the + # underlying ndarray as used so it survives the enforcing-pass pruning. + if should_mark: + src_id = getattr(value, "_qd_source_ndarray_id", None) + if src_id is not None: + pruning.used_struct_ndarray_ids.add(src_id) return value @staticmethod diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 40f3750163..ba74bddc7a 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -289,6 +289,10 @@ def _register_ndarray(nd, arg_idx, attr_chain): _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), _qd_layout=layout, ) + # Tag the AnyArray with the source ndarray id so ``_promote_ndarray_if_declared`` + # can mark this ndarray as used even when the access reaches it via an already- + # promoted AnyArray (e.g. callee bodies bound to per-leaf args by Option A). + arr._qd_source_ndarray_id = key cache[key] = arr launch_info.append((arg_id_vec[0], arg_idx, attr_chain)) From e3a3d88ba0dc0eb10b61503fb02ab6175cf5dc50 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 17 May 2026 17:12:46 -0700 Subject: [PATCH 07/46] [Perf] TemplateMapper.lookup: only walk template-slot args, cache per-class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The args_hash data_oriented walker added in a0db648b2 ([Fix] args_hash invalidates when data_oriented ndarray member is reassigned) ran unconditionally for every arg of every kernel call. Even after 93893e5f2 cached the per-class attribute paths, the per-call ``is_data_oriented(arg)`` + ``type(arg).__dict__.get`` chain still cost ~15% FPS on small-step CPU benches (anymal_zero CPU bs=0: 7231 -> 5955 FPS = -17.6% vs the pre-branch reference). Two coordinated optimisations: 1. Only iterate ``self.template_slot_locations`` instead of all args. Typed-dataclass args carry a specific dataclass type by construction and a data_oriented class is never a dataclass, so the only positions where a data_oriented container can appear are the ``qd.template()`` annotated ones — already tracked by the kernel decorator. Genesis main ``kernel_step_1`` has 4 template positions of 16 args; reduces the per-call work proportionally. 2. Per-``type(arg)`` precomputed dispatch: ``_arg_nd_paths_or_none`` maps each seen type to either the cached path list to walk, or ``None`` (skip — covers primitive templates, non-data_oriented composites, ``_qd_stable_members`` data_oriented, and data_oriented with zero ndarrays). One ``dict.get`` per candidate per call after warmup, replacing the previous ``is_data_oriented`` + ``__dict__.get`` + ``_struct_nd_paths_for`` chain. Measured on cluster ``rtx-mid`` single process, ``test_speed[anymal_zero-None-None- 0-cpu]``, 3-run median, Genesis main + Quadrants branch: - pre-fix tip (02e566040): 5955 FPS (-17.6% vs a22cc2ded reference 7231) - after this commit: 6935 FPS (-4.1% vs reference) Recovery: +16.5% FPS on Genesis main; +11.2% on Genesis ``hp/data-oriented-rigid- solver`` (6315 -> 7020). Brings CPU bs=0 within ~3-4% of the pre-branch baseline. Other Quadrants tests (test_data_oriented_ndarray, test_data_oriented_qd_func_dataclass, test_callable_template_mapper, test_kernel_templates, test_template_typing) still pass. --- python/quadrants/lang/_template_mapper.py | 72 +++++++++++++++-------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index a700000d65..9acebd4db2 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -16,17 +16,28 @@ ) -def _collect_data_oriented_nd_ids(arg: Any, out: list) -> None: - """Append ``id(ndarray)`` for every ndarray reachable from ``arg``, using the per-class path cache in - ``_template_mapper_hotpath._struct_nd_paths_for`` so the first call walks ``vars(arg)`` once and subsequent calls - are just ``getattr`` chains. Empty path list short-circuits with zero work — critical for genesis's - ``@qd.data_oriented`` Solver passed as ``self`` to every kernel. - """ - for chain in _struct_nd_paths_for(arg): - v = arg - for a in chain: - v = getattr(v, a) - out.append(id(v)) +# Per-``type(arg)`` precomputed dispatch for the args_hash ndarray-id walk in ``TemplateMapper.lookup``. Each entry is +# either the cached attribute path list (when the class is data_oriented, opted into ndarray tracking, and actually +# holds ndarrays) or ``None`` (when the per-call walk is a no-op — covers the common case of typed-dataclass args, +# non-data_oriented composite args, primitives, AND data_oriented classes with ``_qd_stable_members = True`` or with +# no ndarray members). One dict lookup per arg per call, ~30 ns, replacing the previous unconditional +# ``is_data_oriented(arg)`` + ``type(arg).__dict__.get`` chain. +_arg_nd_paths_or_none: dict[type, "list[tuple] | None"] = {} +_UNCLASSIFIED = object() + + +def _classify_for_args_hash(arg: Any) -> "list[tuple] | None": + """First-sighting classification for ``type(arg)`` in the args_hash walk. Returns the path list to walk (when the + arg is a data_oriented container without ``_qd_stable_members`` that actually contains ndarrays), or ``None`` to + skip subsequent per-call work for this type.""" + if not is_data_oriented(arg): + return None + if type(arg).__dict__.get("_qd_stable_members"): + return None + paths = _struct_nd_paths_for(arg) + if not paths: + return None + return paths Key: TypeAlias = tuple[Any, ...] @@ -93,21 +104,32 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl # ``@qd.data_oriented`` containers can have their member ndarrays reassigned between calls on the same instance # (``state.x = other_ndarray``). The id(arg) alone does not capture that, so the spec-key cache below would # serve a stale entry and the new ndarray's dtype/ndim would be wrong. Fold the reachable ndarray ids into the - # hash. No-op for data_oriented containers that hold no ndarrays — the walker returns an empty list. See - # ``_collect_data_oriented_nd_ids``. + # hash for the (small) set of arg positions that need it. + # + # The kernel's ``template_slot_locations`` already gives us the subset of arg positions annotated as + # ``qd.template()`` — the only positions where a data_oriented container could appear (typed-dataclass args + # carry a specific dataclass type by construction and a data_oriented class is never a dataclass). So we only + # iterate ``template_slot_locations`` instead of all args (Genesis main kernel_step_1: 4 template positions + # of 16 args; Genesis branch step_1/step_2: 4 of 4). + # + # For each candidate position, a per-class cache in ``_arg_nd_paths_or_none`` maps ``type(arg)`` to either the + # cached ndarray-path list to walk or ``None`` to skip (typical for primitive template-args, stable_members + # data_oriented, and data_oriented with zero ndarrays). One dict.get per candidate per call after warmup. nd_ids: list = [] - for arg in args: - if is_data_oriented(arg): - # Opt-out: classes that promise their ndarray members never reassign between calls - # (set ``_qd_stable_members = True`` on the class, or use - # ``@qd.data_oriented(stable_members=True)``) skip the per-call walk. The spec key - # then falls back to weakref(arg) alone — see _extract_arg's data_oriented branch. - # Saves ~1.1-1.5 us per kernel call on Genesis-style containers. Reassigning a - # member on a stable-marked instance is silently undefined behaviour: the cached - # kernel for the prior shape will be reused. - if type(arg).__dict__.get("_qd_stable_members"): - continue - _collect_data_oriented_nd_ids(arg, nd_ids) + for i in self.template_slot_locations: + arg = args[i] + cls = type(arg) + paths = _arg_nd_paths_or_none.get(cls, _UNCLASSIFIED) + if paths is _UNCLASSIFIED: + paths = _classify_for_args_hash(arg) + _arg_nd_paths_or_none[cls] = paths + if paths is None: + continue + for chain in paths: + v = arg + for a in chain: + v = getattr(v, a) + nd_ids.append(id(v)) if nd_ids: args_hash = args_hash + tuple(nd_ids) try: From 067a471b3f64e89012270a2b9bdeab65a910ba5c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 17 May 2026 23:42:35 -0700 Subject: [PATCH 08/46] [Fix] Walker robustness: cycle-safe + Pydantic-metaclass-safe is_data_oriented MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Genesis unit tests on cluster hit RecursionError (118 instances across test_rigid_physics, test_fem, test_hybrid, test_render, ...). Two independent root causes, both in the recursive ndarray-graph walkers used to discover ndarray members of ``@qd.data_oriented`` / ``dataclass`` kernel args: 1. ``is_data_oriented(obj)`` did ``getattr(type(obj), "_data_oriented", False)``. For Genesis containers like ``RigidOptions`` (a ``pydantic.BaseModel`` subclass), the metaclass ``ModelMetaclass.__getattr__`` recurses infinitely on missing class attribute names, blowing the stack on every call. Fix: walk MRO and look up ``_data_oriented`` directly in each class's ``__dict__`` — never goes through ``getattr`` / ``__getattr__`` so it's immune to pathological metaclasses. ``@qd.data_oriented`` sets the flag directly on the decorated class so the MRO walk still finds it. 2. ``_build_struct_nd_paths`` (in ``_template_mapper_hotpath.py``) and ``_walk_obj`` (in ``function_def_transformer.py``) had no cycle detection. Genesis object graphs have cross-references (e.g. ``solver <-> scene <-> sim <-> solver``) so the walkers recurse forever on real workloads. Fix: track ``id(obj)`` in a per-traversal ``seen`` set and skip re-entering a node we've already expanded. Adds ``test_is_data_oriented_safe_on_pydantic_like_metaclass``, ``test_data_oriented_with_pydantic_like_child``, and ``test_data_oriented_with_cyclic_attr_graph`` to pin both fixes. --- .../lang/_template_mapper_hotpath.py | 14 ++- .../function_def_transformer.py | 22 ++++- python/quadrants/lang/util.py | 14 ++- tests/python/test_data_oriented_ndarray.py | 92 +++++++++++++++++++ 4 files changed, 132 insertions(+), 10 deletions(-) diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index 18c00e2d62..1d4b08de9e 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -82,7 +82,13 @@ _struct_nd_paths_cache: dict[type, list[tuple]] = {} -def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None: +def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: "set[int] | None" = None) -> None: + # Cycle-safe walker. Genesis object graphs have cross-references (e.g. ``solver -> scene -> sim -> solver``) and + # Pydantic-options-style children. ``_seen`` tracks ``id(obj)`` for the current traversal to avoid re-entering a + # node we've already expanded. Cheap (one ``set`` op per frame, only allocated when we actually start recursing) + # and bounds the walk to a finite depth regardless of the graph shape. + if _seen is None: + _seen = {id(obj)} if dataclasses.is_dataclass(obj) and not isinstance(obj, type): children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj)) else: @@ -102,7 +108,11 @@ def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None: if issubclass(v_type, Ndarray): out.append(chain) elif is_data_oriented(v) or (dataclasses.is_dataclass(v) and not isinstance(v, type)): - _build_struct_nd_paths(v, chain, out) + v_id = id(v) + if v_id in _seen: + continue + _seen.add(v_id) + _build_struct_nd_paths(v, chain, out, _seen) def _struct_nd_paths_for(arg: Any) -> list[tuple]: diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index ba74bddc7a..a2011ca49b 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -248,7 +248,11 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: # fall back to registering every reachable ndarray. prune = pruning.enforcing and used_ids is not None and getattr(pruning, "pass_0_ran", False) - def _walk_obj(obj, arg_idx, path): + # Cycle-safe walker: Genesis object graphs have cross-references (e.g. solver <-> scene <-> sim) so we must + # avoid re-entering the same node. ``seen`` is shared across the whole arg's traversal — ``id(obj)`` is + # stable for the duration of this compile and we never need to revisit a node since the ndarray-set rooted at + # it doesn't depend on the path we took to reach it. + def _walk_obj(obj, arg_idx, path, seen): if dataclasses.is_dataclass(obj) and not isinstance(obj, type): for field in dataclasses.fields(obj): child = getattr(obj, field.name) @@ -257,7 +261,11 @@ def _walk_obj(obj, arg_idx, path): if isinstance(child, _ndarray.Ndarray): _register_ndarray(child, arg_idx, (*path, field.name)) elif (dataclasses.is_dataclass(child) and not isinstance(child, type)) or is_data_oriented(child): - _walk_obj(child, arg_idx, (*path, field.name)) + child_id = id(child) + if child_id in seen: + continue + seen.add(child_id) + _walk_obj(child, arg_idx, (*path, field.name), seen) else: for attr_name, attr_val in vars(obj).items(): if isinstance(attr_val, _TensorClass): @@ -267,7 +275,11 @@ def _walk_obj(obj, arg_idx, path): elif (dataclasses.is_dataclass(attr_val) and not isinstance(attr_val, type)) or is_data_oriented( attr_val ): - _walk_obj(attr_val, arg_idx, (*path, attr_name)) + attr_id = id(attr_val) + if attr_id in seen: + continue + seen.add(attr_id) + _walk_obj(attr_val, arg_idx, (*path, attr_name), seen) def _register_ndarray(nd, arg_idx, attr_chain): key = id(nd) @@ -309,9 +321,9 @@ def _register_ndarray(nd, arg_idx, attr_chain): if isinstance(val, _ndarray.Ndarray): continue if dataclasses.is_dataclass(val) and not isinstance(val, type): - _walk_obj(val, i, ()) + _walk_obj(val, i, (), {id(val)}) elif hasattr(val, "__dict__"): - _walk_obj(val, i, ()) + _walk_obj(val, i, (), {id(val)}) @staticmethod def _unwrap_tensor(data: Any) -> Any: diff --git a/python/quadrants/lang/util.py b/python/quadrants/lang/util.py index a9f2f4bc07..e346f4f6d3 100644 --- a/python/quadrants/lang/util.py +++ b/python/quadrants/lang/util.py @@ -350,9 +350,17 @@ def get_traceback(stacklevel=1): def is_data_oriented(obj: Any) -> bool: - # Use getattr on class instead of object to bypass custom __getattr__ method that is - # overwritten at instance level and very slow. - return getattr(type(obj), "_data_oriented", False) + # Look up ``_data_oriented`` directly via ``__dict__`` on each class in the MRO, never through ``getattr``. Some + # third-party metaclasses (notably Pydantic's ``ModelMetaclass``) override ``__getattr__`` and recurse infinitely + # on missing attributes when probed for arbitrary names — ``getattr(type(obj), "_data_oriented", False)`` blows + # the stack on a Genesis ``RigidOptions`` instance. The MRO walk via ``__dict__`` skips any descriptor / + # ``__getattr__`` machinery; ``@qd.data_oriented`` always sets the flag directly on the decorated class so this + # finds it via ``cls.__dict__["_data_oriented"]`` without ever touching the metaclass attribute protocol. + for klass in type(obj).__mro__: + flag = klass.__dict__.get("_data_oriented") + if flag is not None: + return flag + return False def is_qd_template(annotation: Any) -> bool: diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index 083fd1a19d..4b6f64c11e 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -888,3 +888,95 @@ def run(s: qd.template()): # Run a second time on the same instance — should reuse the same compiled kernel. run(state) + + +# --------------------------------------------------------------------------- +# 22. Robustness: object graphs with Pydantic-style metaclass ``__getattr__`` recursion, +# and cyclic attribute references. Real-world container classes (notably Genesis's +# ``RigidOptions`` / ``SimOptions``) inherit from ``pydantic.BaseModel`` whose +# ``ModelMetaclass.__getattr__`` recurses infinitely on missing class attributes. +# Quadrants' walker must not blow the stack when it traverses a ``data_oriented`` arg +# that contains such an object, or that contains a back-reference to itself / its +# parent (e.g. ``solver.scene.solver``). +# --------------------------------------------------------------------------- + + +def test_is_data_oriented_safe_on_pydantic_like_metaclass(): + """``is_data_oriented`` must not invoke ``__getattr__`` on the class (or metaclass), + so it stays safe in the presence of pathological metaclasses whose ``__getattr__`` + blows the Python recursion limit on arbitrary attribute lookups (e.g. Pydantic's + ``ModelMetaclass`` when probed for a name not in its private-attrs cache).""" + + from quadrants.lang.util import is_data_oriented + + class RecursingMeta(type): + def __getattr__(cls, item): + return cls.__getattr__(item) + + class Pathological(metaclass=RecursingMeta): + pass + + # Pre-fix this raised RecursionError; with the MRO+__dict__ lookup it just returns False. + assert is_data_oriented(Pathological()) is False + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_with_pydantic_like_child(): + """A ``@qd.data_oriented`` class holding a child whose metaclass has the recursing + ``__getattr__`` (Pydantic-style). Walker must classify the child as non-data-oriented + and continue without blowing the stack.""" + N = 4 + + class RecursingMeta(type): + def __getattr__(cls, item): + return cls.__getattr__(item) + + class Options(metaclass=RecursingMeta): + pass + + @qd.data_oriented + class State: + def __init__(self, x, opts): + self.x = x + self.opts = opts + + x = qd.ndarray(qd.i32, shape=(N,)) + state = State(x=x, opts=Options()) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + run(state) + np.testing.assert_array_equal(x.to_numpy(), np.arange(1, N + 1)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_with_cyclic_attr_graph(): + """A ``@qd.data_oriented`` class whose attribute graph contains a cycle + (``parent.child.parent is parent``). Walker must not re-enter the cycle.""" + N = 4 + + @qd.data_oriented + class Child: + def __init__(self): + self.parent = None + + @qd.data_oriented + class Parent: + def __init__(self, x): + self.x = x + self.child = Child() + self.child.parent = self # cycle + + x = qd.ndarray(qd.i32, shape=(N,)) + p = Parent(x=x) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 10 + + run(p) + np.testing.assert_array_equal(x.to_numpy(), np.arange(10, 10 + N)) From cc1e3805e9500e5804ccc75c0d4b8603c9ca6adc Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sun, 17 May 2026 23:47:23 -0700 Subject: [PATCH 09/46] [Style] Apply pre-commit (black + ruff): import order, single-line conditionals --- python/quadrants/lang/_template_mapper.py | 1 - python/quadrants/lang/kernel.py | 5 +---- tests/python/test_data_oriented_qd_func_dataclass.py | 5 +++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index 9acebd4db2..1a3db47ebd 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -15,7 +15,6 @@ _struct_nd_paths_for, ) - # Per-``type(arg)`` precomputed dispatch for the args_hash ndarray-id walk in ``TemplateMapper.lookup``. Each entry is # either the cached attribute path list (when the class is data_oriented, opted into ndarray tracking, and actually # holds ndarrays) or ``None`` (when the per-call walk is a no-op — covers the common case of typed-dataclass args, diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 5ce7a936b4..e5865da677 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -486,10 +486,7 @@ def launch_kernel( (idx, chain) for _, idx, chain in struct_nd_info if type(args[idx]).__hash__ is None - or ( - is_data_oriented(args[idx]) - and not type(args[idx]).__dict__.get("_qd_stable_members") - ) + or (is_data_oriented(args[idx]) and not type(args[idx]).__dict__.get("_qd_stable_members")) ] else: self._mutable_nd_cached_val = [] diff --git a/tests/python/test_data_oriented_qd_func_dataclass.py b/tests/python/test_data_oriented_qd_func_dataclass.py index 4aee084f28..20d620aa97 100644 --- a/tests/python/test_data_oriented_qd_func_dataclass.py +++ b/tests/python/test_data_oriented_qd_func_dataclass.py @@ -19,14 +19,14 @@ import dataclasses import numpy as np -import pytest import quadrants as qd -from tests import test_utils +from tests import test_utils # ----- typed-dataclass kernel-arg baseline (works) ---------------------------- + @test_utils.test(arch=qd.cpu) def test_baseline_typed_dataclass_kernel_arg_calls_qd_func(): """Baseline: typed-dataclass kernel arg + qd.func taking same dataclass type — works.""" @@ -56,6 +56,7 @@ def run(state: State): # ----- data_oriented self-method calling qd.func (the broken case) ----------- + @test_utils.test(arch=qd.cpu) def test_data_oriented_method_calls_qd_func_with_dataclass_member(): """data_oriented holds a dataclass; self-kernel calls a @qd.func taking that dataclass.""" From 34f85325eab9277fc547c27b4966b9f35b055437 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 00:32:25 -0700 Subject: [PATCH 10/46] [Fix] stable_members: tolerate opaque members in fastcache hasher + cache-stale leaves MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two related robustness fixes surfaced by the Genesis ``hp/data-oriented- rigid-solver`` migration on cluster unit tests. ## Problem 1: ``_uid: UID`` disables fastcache on stable_members classes After Genesis migrated ``kernel_step_1`` / ``kernel_step_2`` to methods on ``@qd.data_oriented(stable_members=True) class RigidSolver``, the fastcache args-hasher walks ``RigidSolver.__dict__``, encounters ``_uid`` of type ``genesis.utils.uid.UID``, can't recognise it, and disables fastcache for the whole call: [FASTCACHE][PARAM_INVALID] Parameter with path ('0', '_uid') and type not allowed by fast cache. [FASTCACHE][INVALID_FUNC] The pure function step_1 could not be fast cached, because one or more parameter types were invalid Causes 5 ``test_quadrants.py`` failures (``test_num_envs``, ``test_ndarray_no_compile`` on both backends) that all assert fastcache fires for ndarray-backend ``RigidSolver`` invocations. ``stable_members=True`` already promises the class's member set / types don't change after construction. Under that contract, opaque metadata (``UID``, etc.) is inert from fastcache's perspective: it doesn't affect kernel codegen. Treat ``stable_members=True`` containers as tolerant — skip unrecognised members silently and continue, instead of returning None and killing fastcache. Also silence the per-member ``[FASTCACHE][PARAM_INVALID]`` log inside a stable_members walk via a depth counter, so the user doesn't see warnings for members they explicitly opted out of caring about. ## Problem 2: cached ndarray-path leaves can be stale across instances ``_struct_nd_paths_cache`` is keyed on ``type(arg)`` and assumes the set of ndarray-reachable attribute chains is stable across instances. That's the common case but breaks on polymorphic Genesis solvers: ``FEMSolver`` / ``MPMSolver`` / ``SPHSolver`` can hold a ``qd.Tensor`` whose underlying impl swaps between an ``Ndarray`` and a ``MatrixField`` between instances. ``_collect_struct_nd_descriptors`` then walks a cached path to a ``MatrixField`` and crashes with:: AttributeError: 'MatrixField' object has no attribute 'element_type' Fix: defensively check ``isinstance(v, Ndarray)`` after the tensor-wrapper unwrap and skip stale entries silently. ``element_type`` / ``shape`` / ``_qd_layout`` are Ndarray-only; non-Ndarray leaves can't contribute a meaningful descriptor anyway, and the per-instance ``weakref(arg)`` part of the spec key still ensures cache discrimination. Adds ``test_data_oriented_polymorphic_attr_across_instances`` to pin the cache-stale-leaf behaviour. --- .../lang/_fast_caching/args_hasher.py | 73 ++++++++++++++----- .../lang/_template_mapper_hotpath.py | 9 +++ tests/python/test_data_oriented_ndarray.py | 42 +++++++++++ 3 files changed, 104 insertions(+), 20 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 455d51667c..658b0cf2ed 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -53,6 +53,15 @@ class FastcacheSkip(enum.Enum): _should_warn = False +# Counter set by the data_oriented walker when entering a ``_qd_stable_members`` object. While nonzero, the +# unknown-type branch of ``stringify_obj_type`` returns ``None`` silently instead of logging +# ``[FASTCACHE][PARAM_INVALID]``. ``stable_members=True`` is the user's promise that the class's member set / types +# don't change after construction — under that promise, opaque members like ``RigidSolver._uid`` (a +# ``genesis.utils.uid.UID``) don't affect kernel codegen so they can be skipped silently rather than killing +# fastcache for the whole call. Single-threaded by construction (the hasher only runs during JIT compile). +_skip_unknown_warn_depth = 0 + + def _mark_warn_if_not_tensor_annotation(arg_meta: ArgMetadata | None) -> None: """Flag that a warning is needed if the Field didn't arrive through a qd.Tensor annotation.""" global _should_warn # pylint: disable=global-statement @@ -174,6 +183,14 @@ def stringify_obj_type( if dataclasses.is_dataclass(obj): return dataclass_to_repr(raise_on_templated_floats, path, obj) if is_data_oriented(obj): + # ``@qd.data_oriented(stable_members=True)``: the class promises its member *set* and *types* don't change + # after construction. Under that contract, unrecognised member types (e.g. Genesis's ``RigidSolver._uid`` of + # type ``genesis.utils.uid.UID``, or any other opaque metadata) are treated as inert from fastcache's + # perspective: they don't affect kernel codegen so they can be skipped silently rather than killing fastcache + # for the whole call. Without this, migrating a kernel from a standalone ``@qd.kernel`` function to a method + # on a ``@qd.data_oriented`` class disables fastcache the moment the class holds any opaque metadata, even + # though the kernel's compiled output would be identical. + stable_members = bool(type(obj).__dict__.get("_qd_stable_members")) child_repr_l = ["da"] _dict = {} try: @@ -182,26 +199,37 @@ def stringify_obj_type( _dict = _asdict() except AttributeError: _dict = obj.__dict__ - for k, v in _dict.items(): - # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` - # stashes the per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so - # that subsequent ``instance.method`` lookups skip the descriptor allocation; - # those entries are not data and must not invalidate the fastcache key. - v_type = type(v) - if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: - continue - _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) - if _child_repr is None: - if _should_warn: - _logging.warn( - f"A kernel that has been marked as eligible for fast cache was passed 1 or more parameters " - f"that are not, in fact, eligible for fast cache: one of the parameters was a " - f"@qd.data_oriented object, and one of its children was not eligible. The data oriented " - f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " - f"information, the path of the value was {path}." - ) - return None - child_repr_l.append(f"{k}: {_child_repr}") + global _skip_unknown_warn_depth # pylint: disable=global-statement + if stable_members: + _skip_unknown_warn_depth += 1 + try: + for k, v in _dict.items(): + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` + # stashes the per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so + # that subsequent ``instance.method`` lookups skip the descriptor allocation; + # those entries are not data and must not invalidate the fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue + _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) + if _child_repr is None: + if stable_members: + # Member is opaque to fastcache; under the stable_members contract it's inert and skipping + # is safe. Don't kill fastcache for the whole call. + continue + if _should_warn: + _logging.warn( + f"A kernel that has been marked as eligible for fast cache was passed 1 or more " + f"parameters that are not, in fact, eligible for fast cache: one of the parameters was a " + f"@qd.data_oriented object, and one of its children was not eligible. The data oriented " + f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " + f"information, the path of the value was {path}." + ) + return None + child_repr_l.append(f"{k}: {_child_repr}") + finally: + if stable_members: + _skip_unknown_warn_depth -= 1 return ", ".join(child_repr_l) if issubclass(arg_type, (numbers.Number, np.number)): if _is_template(arg_meta): @@ -218,6 +246,11 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" + if _skip_unknown_warn_depth > 0: + # Inside a ``stable_members=True`` data_oriented walk: opaque members are tolerated by contract, so don't log + # the per-member ``[FASTCACHE][PARAM_INVALID]`` warning. The data_oriented walker reads the returned ``None`` + # and skips this member. + return None _mark_should_warn() # The bit in caps should not be modified without updating corresponding test # The rest of free text can be freely modified diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index 1d4b08de9e..d124328f81 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -140,12 +140,21 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding ndarrays — see the data_oriented branch in ``_extract_arg``. """ + # The path cache is keyed on ``type(arg)`` and assumes the *set* of ndarray-reachable attribute chains is stable + # across instances of the same class. That holds for the typical ``@qd.data_oriented`` container, but Genesis + # ``FEMSolver`` / ``MPMSolver`` / ``SPHSolver`` and similar can hold polymorphic children (e.g. ``self.material`` + # of a different concrete subclass) or swap a ``qd.Tensor``'s underlying impl between an ``Ndarray`` and a + # ``MatrixField``. When the leaf at a cached path is no longer an ``Ndarray`` we silently skip it: ``v.element_type`` + # / ``v.shape`` / ``v._qd_layout`` are Ndarray-only accessors. The per-instance ``weakref(arg)`` part of the spec + # key still ensures correct cache discrimination across instances. for chain in _struct_nd_paths_for(arg): v = arg for a in chain: v = getattr(v, a) if type(v) in _TENSOR_WRAPPER_TYPES: v = v._unwrap() + if not isinstance(v, Ndarray): + continue type_id = id(v.element_type) element_type = type_id if type_id in primitive_types.type_ids else v.element_type out.append((".".join(chain), element_type, len(v.shape), v.grad is not None, v._qd_layout)) diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index 4b6f64c11e..7c73a714b4 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -952,6 +952,48 @@ def run(s: qd.template()): np.testing.assert_array_equal(x.to_numpy(), np.arange(1, N + 1)) +@test_utils.test(arch=qd.cpu) +def test_data_oriented_polymorphic_attr_across_instances(): + """The path cache in ``_struct_nd_paths_cache`` is keyed on ``type(arg)`` and assumes the set of + ndarray-reachable attribute chains is stable across instances. Some real-world ``@qd.data_oriented`` + containers (Genesis FEMSolver / MPMSolver / SPHSolver, etc.) hold polymorphic children whose + types differ between instances — e.g. ``self.material.x`` is an ``Ndarray`` on instance A and + a ``qd.field`` (``MatrixField``) on instance B. ``_collect_struct_nd_descriptors`` walks cached + paths verbatim and must not crash with ``'MatrixField' object has no attribute 'element_type'`` + when a path's leaf is no longer an ``Ndarray``; it should silently skip the stale entry.""" + N = 4 + + @qd.data_oriented + class State: + def __init__(self, x): + self.x = x + + # First instance: ``self.x`` is an Ndarray. The walker emits path ``('x',)`` and caches it. + x_nd = qd.ndarray(qd.i32, shape=(N,)) + state_a = State(x=x_nd) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + run(state_a) + np.testing.assert_array_equal(x_nd.to_numpy(), np.arange(1, N + 1)) + + # Second instance of the SAME class, ``self.x`` is now a ``qd.field`` (MatrixField via Vector.field). + # The cached path ``('x',)`` from instance A points to a non-Ndarray on this instance — the descriptor + # walk must skip it cleanly rather than crash on ``v.element_type``. + f = qd.Vector.field(2, qd.i32, shape=(N,)) + state_b = State(x=f) + + @qd.kernel + def run_field(s: qd.template()): + for i in range(N): + s.x[i] = [i, i + 1] + + run_field(state_b) + + @test_utils.test(arch=qd.cpu) def test_data_oriented_with_cyclic_attr_graph(): """A ``@qd.data_oriented`` class whose attribute graph contains a cycle From 5e5490256009cf4b500b926016d563cd22f4694a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 00:53:19 -0700 Subject: [PATCH 11/46] [Fix] stable_members fastcache: only tolerate truly-opaque members, fail on Field/MatrixField MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit My previous commit ``5add57b6a`` was too loose: it silently skipped *any* member that ``stringify_obj_type`` returned ``None`` for, including ``Field`` / ``MatrixField``. That broke ``test_quadrants.test_num_envs[ False-*]`` (field backend), which pins the contract that fastcache must fail when an arg's subtree contains a recognised-but-unsupported tensor-like type (whose value affects kernel codegen). Differentiate two reasons ``stringify_obj_type`` returns ``None``: (a) RECOGNISED-BUT-UNSUPPORTED: ``ScalarField`` / ``MatrixField`` (and any future type explicitly hitting ``_mark_warn_if_not_tensor_ annotation``). These now also call ``_mark_hit_recognised_ unsupported()`` to flip a module-level flag. The flag bubbles up naturally through nested dataclass / data_oriented walkers since they propagate ``None``. (b) TRULY-OPAQUE: unknown types falling through to the ``[FASTCACHE][PARAM_INVALID]`` branch (``RigidSolver._uid: UID``, etc.). These don't set the flag. The ``stable_members=True`` data_oriented walker snapshots the flag around each child's recursive call. If a child returned ``None`` AND the flag was set, fastcache fails (any tensor-like leaf in the subtree invalidates the hash). If the flag was clear, the child is truly opaque metadata — skip it silently under the user's stability contract. ``_hit_recognised_unsupported`` is reset at the top of ``hash_args`` and before each child probe so the snapshot reflects only the just-completed recursion. --- .../lang/_fast_caching/args_hasher.py | 41 +++++++++++++++++-- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 658b0cf2ed..6dcfe9a8b0 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -41,6 +41,20 @@ _DC_REPR_NONE = object() +# Set by ``stringify_obj_type`` when it encounters a recognised-but-unsupported tensor-like type (``Field`` / +# ``MatrixField``) anywhere in the traversal — including nested under a dataclass or another data_oriented object. +# The ``stable_members=True`` data_oriented walker uses this to differentiate two reasons a child returned ``None``: +# truly-opaque metadata (``RigidSolver._uid: UID``, etc.) which is inert and can be skipped, vs a tensor-like type +# whose value affects kernel codegen and must invalidate fastcache for the whole call. Reset at the top of each +# ``hash_args``; snapshotted/restored around each nested ``stringify_obj_type`` call inside the data_oriented walker. +_hit_recognised_unsupported = False + + +def _mark_hit_recognised_unsupported() -> None: + global _hit_recognised_unsupported # pylint: disable=global-statement + _hit_recognised_unsupported = True + + class FastcacheSkip(enum.Enum): """Why fastcache does not apply to this call.""" @@ -167,6 +181,7 @@ def stringify_obj_type( # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) + _mark_hit_recognised_unsupported() return None if isinstance(obj, MatrixNdarray): return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] @@ -179,6 +194,7 @@ def stringify_obj_type( # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) + _mark_hit_recognised_unsupported() return None if dataclasses.is_dataclass(obj): return dataclass_to_repr(raise_on_templated_floats, path, obj) @@ -211,11 +227,27 @@ def stringify_obj_type( v_type = type(v) if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: continue + # Snapshot the recognised-but-unsupported flag around the recursive call so we can tell whether + # *this child's* subtree hit a ``Field`` / ``MatrixField`` (in which case we must fail fastcache + # even under ``stable_members``). + global _hit_recognised_unsupported # pylint: disable=global-statement + _hit_recognised_unsupported = False _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) + child_hit_field = _hit_recognised_unsupported if _child_repr is None: - if stable_members: - # Member is opaque to fastcache; under the stable_members contract it's inert and skipping - # is safe. Don't kill fastcache for the whole call. + # Differentiate two reasons ``stringify_obj_type`` returns None: + # + # (a) RECOGNISED-BUT-UNSUPPORTED: ``Field`` / ``MatrixField`` somewhere in this child's + # subtree. These are *known* tensor-like types whose values affect kernel codegen but + # which fastcache doesn't yet handle. Killing fastcache for the whole call is the + # intended contract — ``test_num_envs[False-...]`` pins this behaviour for the field + # backend. + # (b) TRULY-OPAQUE: anything that falls through to the ``[FASTCACHE][PARAM_INVALID]`` + # warning at the bottom of ``stringify_obj_type`` (``RigidSolver._uid`` of type + # ``UID``, etc.). For ``stable_members=True`` containers, opaque metadata is inert by + # the user's contract and can be skipped without invalidating the hash for the rest + # of the members. + if stable_members and not child_hit_field: continue if _should_warn: _logging.warn( @@ -265,8 +297,9 @@ def hash_args( raise_on_templated_floats: bool, args: Sequence[Any], arg_metas: Sequence[ArgMetadata | None] ) -> str | FastcacheSkip: """Return the args hash string, or a HashFailure explaining why hashing failed.""" - global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn + global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn, _hit_recognised_unsupported # pylint: disable=line-too-long _should_warn = False + _hit_recognised_unsupported = False g_num_calls += 1 g_num_args += len(args) hash_l = [] From 55ecf95d55b3d764829802bff5ad29f087f514c9 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 01:07:19 -0700 Subject: [PATCH 12/46] [Fix] Metaclass-safe is_dataclass for walker over user objects MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `dataclasses.is_dataclass(obj)` calls `hasattr(type(obj), '__dataclass_fields__')`, which delegates to the metaclass `__getattr__` for missing names. Pydantic's `ModelMetaclass` (and our `RecursingMeta` regression fixture) recurse infinitely on arbitrary lookups and blow the stack — same class of failure as the previously-fixed `is_data_oriented(obj)` path. Add `is_dataclass_instance` in `lang/util.py` that walks `type(obj).__mro__` probing `klass.__dict__` directly (never via `getattr`), and use it everywhere the kernel pipeline tests user values for dataclass-ness: - `_template_mapper_hotpath._build_struct_nd_paths` - `function_def_transformer._walk_obj` (both branches) - `function_def_transformer` dataclass-vs-`__dict__` walker dispatch - `args_hasher.stringify_obj_type` Annotations/types are untouched (`call_transformer`, `_signature`, `_kernel_impl_dataclass`): those check user-declared dataclass types, not runtime values that can carry pathological metaclasses. Fixes `test_data_oriented_with_pydantic_like_child` (added in b3457a6e5 to pin this exact regression but caught only the `is_data_oriented` half of it). --- .../quadrants/lang/_fast_caching/args_hasher.py | 4 ++-- python/quadrants/lang/_template_mapper_hotpath.py | 6 +++--- .../ast_transformers/function_def_transformer.py | 12 +++++------- python/quadrants/lang/util.py | 15 +++++++++++++++ 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 6dcfe9a8b0..aa8df595eb 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -16,7 +16,7 @@ from ..field import ScalarField from ..kernel_arguments import ArgMetadata from ..matrix import MatrixField, MatrixNdarray, VectorNdarray -from ..util import is_data_oriented +from ..util import is_data_oriented, is_dataclass_instance from .hash_utils import hash_iterable_strings _FIELD_TYPES = (ScalarField, MatrixField) @@ -196,7 +196,7 @@ def stringify_obj_type( _mark_warn_if_not_tensor_annotation(arg_meta) _mark_hit_recognised_unsupported() return None - if dataclasses.is_dataclass(obj): + if is_dataclass_instance(obj): return dataclass_to_repr(raise_on_templated_floats, path, obj) if is_data_oriented(obj): # ``@qd.data_oriented(stable_members=True)``: the class promises its member *set* and *types* don't change diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index d124328f81..e6ba24e46e 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -46,7 +46,7 @@ from quadrants.lang.expr import Expr from quadrants.lang.matrix import MatrixType from quadrants.lang.snode import SNode -from quadrants.lang.util import is_data_oriented, to_quadrants_type +from quadrants.lang.util import is_data_oriented, is_dataclass_instance, to_quadrants_type from quadrants.types import ( buffer_view_type, ndarray_type, @@ -89,7 +89,7 @@ def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: "set[int] # and bounds the walk to a finite depth regardless of the graph shape. if _seen is None: _seen = {id(obj)} - if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + if is_dataclass_instance(obj): children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj)) else: # ``NamedTuple`` (decorated as ``@qd.data_oriented``) has no instance ``__dict__`` — fall back to ``_asdict()`` @@ -107,7 +107,7 @@ def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: "set[int] v_type = type(v) if issubclass(v_type, Ndarray): out.append(chain) - elif is_data_oriented(v) or (dataclasses.is_dataclass(v) and not isinstance(v, type)): + elif is_data_oriented(v) or is_dataclass_instance(v): v_id = id(v) if v_id in _seen: continue diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index a2011ca49b..456d13391a 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -34,7 +34,7 @@ from quadrants.lang.matrix import MatrixType from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import StructType -from quadrants.lang.util import is_data_oriented, to_quadrants_type +from quadrants.lang.util import is_data_oriented, is_dataclass_instance, to_quadrants_type from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types @@ -253,14 +253,14 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: # stable for the duration of this compile and we never need to revisit a node since the ndarray-set rooted at # it doesn't depend on the path we took to reach it. def _walk_obj(obj, arg_idx, path, seen): - if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + if is_dataclass_instance(obj): for field in dataclasses.fields(obj): child = getattr(obj, field.name) if isinstance(child, _TensorClass): child = child._unwrap() if isinstance(child, _ndarray.Ndarray): _register_ndarray(child, arg_idx, (*path, field.name)) - elif (dataclasses.is_dataclass(child) and not isinstance(child, type)) or is_data_oriented(child): + elif is_dataclass_instance(child) or is_data_oriented(child): child_id = id(child) if child_id in seen: continue @@ -272,9 +272,7 @@ def _walk_obj(obj, arg_idx, path, seen): attr_val = attr_val._unwrap() if isinstance(attr_val, _ndarray.Ndarray): _register_ndarray(attr_val, arg_idx, (*path, attr_name)) - elif (dataclasses.is_dataclass(attr_val) and not isinstance(attr_val, type)) or is_data_oriented( - attr_val - ): + elif is_dataclass_instance(attr_val) or is_data_oriented(attr_val): attr_id = id(attr_val) if attr_id in seen: continue @@ -320,7 +318,7 @@ def _register_ndarray(nd, arg_idx, attr_chain): val = val._unwrap() if isinstance(val, _ndarray.Ndarray): continue - if dataclasses.is_dataclass(val) and not isinstance(val, type): + if is_dataclass_instance(val): _walk_obj(val, i, (), {id(val)}) elif hasattr(val, "__dict__"): _walk_obj(val, i, (), {id(val)}) diff --git a/python/quadrants/lang/util.py b/python/quadrants/lang/util.py index e346f4f6d3..0fb153c684 100644 --- a/python/quadrants/lang/util.py +++ b/python/quadrants/lang/util.py @@ -363,6 +363,21 @@ def is_data_oriented(obj: Any) -> bool: return False +def is_dataclass_instance(obj: Any) -> bool: + # Metaclass-safe replacement for ``dataclasses.is_dataclass(obj) and not isinstance(obj, type)``. The stdlib + # implementation calls ``hasattr(type(obj), '__dataclass_fields__')``, which delegates to the metaclass + # ``__getattr__`` for missing names. Pathological metaclasses (Pydantic's ``ModelMetaclass``) recurse infinitely + # on arbitrary attribute lookups and blow the stack. Walking the MRO and probing ``__dict__`` directly avoids + # any descriptor / ``__getattr__`` machinery, mirroring ``is_data_oriented`` above. Also folds in the + # ``not isinstance(obj, type)`` guard since callers always pair the two. + if isinstance(obj, type): + return False + for klass in type(obj).__mro__: + if "__dataclass_fields__" in klass.__dict__: + return True + return False + + def is_qd_template(annotation: Any) -> bool: return annotation is Template or type(annotation) is Template From 6d9c307f2b06627ee26b8cb70100104b0fb1ba20 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 02:19:30 -0700 Subject: [PATCH 13/46] [Style] pre-commit: import formatting Black/ruff reformatted multi-import statements onto multiple lines. --- python/quadrants/lang/_template_mapper_hotpath.py | 6 +++++- .../lang/ast/ast_transformers/function_def_transformer.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index e6ba24e46e..b5621ea0f1 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -46,7 +46,11 @@ from quadrants.lang.expr import Expr from quadrants.lang.matrix import MatrixType from quadrants.lang.snode import SNode -from quadrants.lang.util import is_data_oriented, is_dataclass_instance, to_quadrants_type +from quadrants.lang.util import ( + is_data_oriented, + is_dataclass_instance, + to_quadrants_type, +) from quadrants.types import ( buffer_view_type, ndarray_type, diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 456d13391a..1810c086c4 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -34,7 +34,11 @@ from quadrants.lang.matrix import MatrixType from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import StructType -from quadrants.lang.util import is_data_oriented, is_dataclass_instance, to_quadrants_type +from quadrants.lang.util import ( + is_data_oriented, + is_dataclass_instance, + to_quadrants_type, +) from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types From 49ffb3b446cba44da09155b9292dae619854a464 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 02:33:24 -0700 Subject: [PATCH 14/46] [Fix] Fastcache: skip opaque-typed members silently by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous design used a ``stable_members=True`` opt-in flag (or per-class ``_qd_stable_members`` attribute) to tell the fastcache hasher to silently skip opaque-typed members of ``@qd.data_oriented`` containers. Without the opt-in, any unrecognised member type disabled fastcache for the whole call, which made adding a UUID, Pydantic config object, or back-pointer to ``self`` silently kill fastcache. That contract was brittle: adding any new metadata member to a long-lived ``@qd.data_oriented`` class could disable fastcache without warning, and the opt-in was an "I promise the layout doesn't change" contract that has nothing to do with the actual fastcache invariant. The actual invariant: opaque Python types cannot affect kernel codegen because the kernel cannot read them. Only recognised types — ndarrays, primitives, enums, dataclasses, nested ``@qd.data_oriented`` objects — can be read by kernel code. So *all* container walkers (``dataclass_to_repr`` and the ``data_oriented`` branch in ``stringify_obj_type``) can safely skip opaque members from the hash, no opt-in needed. Recognised-but-unsupported types (``qd.field`` / ``qd.Matrix.field``) are distinct: their shape/dtype affect kernel codegen but fastcache doesn't yet know how to hash them. These still disable fastcache for the whole call — behaviour is unchanged. Top-level kernel-arg opaqueness is also distinct: an opaque top-level arg is a user error (the kernel's argument is uninterpretable to fastcache) and still emits the ``[FASTCACHE][PARAM_INVALID]`` warning. Implementation: ``stringify_obj_type`` now takes a ``nested: bool`` parameter that suppresses the warning for nested opaque types. Container walkers pass ``nested=True``. Removed the global ``_skip_unknown_warn_depth`` counter and ``_hit_recognised_unsupported`` flag — replaced with a clean ``_FAIL_FASTCACHE`` sentinel distinct from ``None`` (opaque/silent-skip). The ``@qd.data_oriented(stable_members=True)`` flag and ``_qd_stable_members`` attribute remain — they still gate the launch-context per-call walker optimization in ``_template_mapper`` and ``Kernel.launch_kernel``. Removed from the fastcache hasher's logic only. Added 3 regression tests pinning the new defaults: - data_oriented with opaque member: silently hashable. - data_oriented with nested field: still FastcacheSkip. - dataclass with opaque field: silently hashable. All 130 fastcache + data_oriented tests pass on x64. --- .../lang/_fast_caching/args_hasher.py | 219 ++++++++---------- .../lang/fast_caching/test_args_hasher.py | 70 ++++++ 2 files changed, 172 insertions(+), 117 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index aa8df595eb..77d3ee2ba6 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -41,18 +41,22 @@ _DC_REPR_NONE = object() -# Set by ``stringify_obj_type`` when it encounters a recognised-but-unsupported tensor-like type (``Field`` / -# ``MatrixField``) anywhere in the traversal — including nested under a dataclass or another data_oriented object. -# The ``stable_members=True`` data_oriented walker uses this to differentiate two reasons a child returned ``None``: -# truly-opaque metadata (``RigidSolver._uid: UID``, etc.) which is inert and can be skipped, vs a tensor-like type -# whose value affects kernel codegen and must invalidate fastcache for the whole call. Reset at the top of each -# ``hash_args``; snapshotted/restored around each nested ``stringify_obj_type`` call inside the data_oriented walker. -_hit_recognised_unsupported = False +# Sentinel returned by ``stringify_obj_type`` when a recognised-but-unsupported tensor-like type (``Field`` / +# ``MatrixField``) is encountered anywhere in the traversal. Containers that see this sentinel (``dataclass_to_repr``, +# the ``data_oriented`` branch, and the top-level ``hash_args`` loop) must propagate it upward — fastcache cannot +# safely hash the call. Distinct from ``None``, which means "opaque type, safe to silently skip at nested levels". +class _FailFastcache: + """Singleton sentinel; identity-compared. See module docstring on ``stringify_obj_type``'s return contract.""" + _instance = None -def _mark_hit_recognised_unsupported() -> None: - global _hit_recognised_unsupported # pylint: disable=global-statement - _hit_recognised_unsupported = True + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +_FAIL_FASTCACHE = _FailFastcache() class FastcacheSkip(enum.Enum): @@ -67,15 +71,6 @@ class FastcacheSkip(enum.Enum): _should_warn = False -# Counter set by the data_oriented walker when entering a ``_qd_stable_members`` object. While nonzero, the -# unknown-type branch of ``stringify_obj_type`` returns ``None`` silently instead of logging -# ``[FASTCACHE][PARAM_INVALID]``. ``stable_members=True`` is the user's promise that the class's member set / types -# don't change after construction — under that promise, opaque members like ``RigidSolver._uid`` (a -# ``genesis.utils.uid.UID``) don't affect kernel codegen so they can be skipped silently rather than killing -# fastcache for the whole call. Single-threaded by construction (the hasher only runs during JIT compile). -_skip_unknown_warn_depth = 0 - - def _mark_warn_if_not_tensor_annotation(arg_meta: ArgMetadata | None) -> None: """Flag that a warning is needed if the Field didn't arrive through a qd.Tensor annotation.""" global _should_warn # pylint: disable=global-statement @@ -88,24 +83,38 @@ def _mark_should_warn() -> None: _should_warn = True -def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | None: +def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | _FailFastcache | None: + """Hash a dataclass instance. + + Returns: + - ``str``: a string representation suitable for the fastcache key. + - ``_FAIL_FASTCACHE``: a recognised-but-unsupported tensor-like field was hit; fastcache must be disabled + for the whole call. + - ``None``: dataclass-level skip. Currently unused (dataclasses always succeed unless they hit Field/MatrixField), + but defined symmetrically with ``stringify_obj_type``. + + Note that opaque-typed fields (UUID, plain Python objects, ...) are *silently skipped* — they cannot affect + kernel codegen because the kernel cannot read non-recognised Python types, so omitting them from the hash is + safe by construction. + """ # PERF: For frozen dataclasses, the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. - # A cached ``None`` is stored as the sentinel ``_DC_REPR_NONE`` to distinguish "not yet computed" from - # "computed but not fast-cacheable". + # A cached ``_DC_REPR_NONE`` is stored to distinguish "not yet computed" from "computed but not fast-cacheable". is_frozen = type(arg).__hash__ is not None if is_frozen: cached = getattr(arg, "_qd_dc_repr", None) if cached is _DC_REPR_NONE: - return None + return _FAIL_FASTCACHE if cached is not None: return cached repr_l = [] for field in dataclasses.fields(arg): child_value = getattr(arg, field.name) _repr = stringify_obj_type(raise_on_templated_floats, path + (field.name,), child_value, arg_meta=None) - if _repr is None: + if _repr is _FAIL_FASTCACHE: + # Recognised-but-unsupported (Field/MatrixField) somewhere in this child's subtree. Mark whether the + # field arrived via a non-Tensor annotation so the top-level decides between WARN and FIELD_VIA_TENSOR. if isinstance(child_value, _FIELD_TYPES) and field.type is not _TensorWrapper: _mark_should_warn() if is_frozen: @@ -113,7 +122,11 @@ def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], ar object.__setattr__(arg, "_qd_dc_repr", _DC_REPR_NONE) except AttributeError: pass - return None + return _FAIL_FASTCACHE + if _repr is None: + # Opaque-typed field; skip silently. Opaque types cannot affect kernel codegen because the kernel + # cannot read non-recognised Python types — they are inert metadata. + continue full_repr = f"{field.name}: ({_repr})" if field.metadata.get(FIELD_METADATA_CACHE_VALUE, False): full_repr += f" = {child_value}" @@ -135,23 +148,32 @@ def _is_template(arg_meta: ArgMetadata | None) -> bool: def stringify_obj_type( - raise_on_templated_floats: bool, path: tuple[str, ...], obj: object, arg_meta: ArgMetadata | None -) -> str | None: + raise_on_templated_floats: bool, + path: tuple[str, ...], + obj: object, + arg_meta: ArgMetadata | None, + nested: bool = False, +) -> str | _FailFastcache | None: """ - Convert an object into a string representation that only depends on its type. - - String should somehow represent the type of obj. Doesnt have to be hashed, nor does it have - to be the actual python type string, just a string that is representative of the type, and won't collide - with different (allowed) types. String should be non-empty. - - Note that fields are not included in fast cache. - - arg_meta should only be non-None for the top level arguments and for data oriented objects. It is - used currently to determine whether a value is added to the cache key, as well as the name. eg - - at the top level, primitive types have their values added to the cache key if their annotation is qd.Template, - since they are baked into the kernel - - in data oriented objects, the values of all primitive types are added to the cache key, since they are baked - into the kernel, and require a kernel recompilation, when they change + Convert an object into a string representation that only depends on its type (and, where relevant, its value). + + Return contract: + - ``str``: the object is hashable for fastcache; the returned string contributes to the cache key. + - ``_FAIL_FASTCACHE``: a recognised-but-unsupported type (``qd.field`` / ``qd.Matrix.field``) was encountered. + Containers must propagate this upward; fastcache will be disabled for the whole call. + - ``None``: the object's type is *opaque* — not recognised by the hasher. Containers (``dataclass_to_repr`` + and the ``data_oriented`` branch below) silently skip opaque members because opaque types cannot affect + kernel codegen (the kernel can only read recognised types: ndarrays, primitives, enums, dataclasses, + nested ``@qd.data_oriented`` objects). At the top level (``nested=False``), opaque is treated as + an error and a ``[FASTCACHE][PARAM_INVALID]`` warning is emitted. + + Parameters: + - ``nested``: ``True`` if this call comes from a container walker (dataclass / data_oriented). Suppresses + the top-level ``[FASTCACHE][PARAM_INVALID]`` warning for opaque types so nested opaque members are + skipped silently. ``False`` at the top of each kernel-arg traversal. + - ``arg_meta``: non-``None`` only for the top-level kernel arguments and for ``@qd.data_oriented`` members. + Used to determine whether to bake values into the cache key (primitives in template positions, and all + primitive members of data-oriented containers). """ # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` strips # wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / data-oriented @@ -159,8 +181,6 @@ def stringify_obj_type( # fields, so a wrapper stored as a struct field arrives here un-stripped. Without this branch the hasher falls # through to the ``[FASTCACHE][PARAM_INVALID]`` warning and disables the fast path for the whole call. See # ``perso_hugh/doc/quadrants-tensor.md`` §8.14. - # ``qd.Tensor`` wrappers: unwrap to the bare impl so the type checks below match. After unwrap, ``_qd_layout`` (if - # any) is on the impl. # # PERF-CRITICAL: The _any_tensor_constructed guard makes this check zero-cost when no qd.Tensor has been created. # ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer comparison (~10 @@ -177,12 +197,11 @@ def stringify_obj_type( if isinstance(obj, VectorNdarray): return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): - # disabled for now, because we need to think about how to handle field offset - # etc + # Recognised-but-unsupported: Field's shape/dtype affect kernel codegen but fastcache doesn't yet know how + # to handle them. Disable fastcache for the whole call. # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - _mark_hit_recognised_unsupported() - return None + return _FAIL_FASTCACHE if isinstance(obj, MatrixNdarray): return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, torch_type): @@ -190,78 +209,42 @@ def stringify_obj_type( if isinstance(obj, np.ndarray): return f"[np-{obj.dtype}-{obj.ndim}]" if isinstance(obj, MatrixField): - # disabled for now, because we need to think about how to handle field offset - # etc + # Recognised-but-unsupported, same as ScalarField above. # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - _mark_hit_recognised_unsupported() - return None + return _FAIL_FASTCACHE if is_dataclass_instance(obj): return dataclass_to_repr(raise_on_templated_floats, path, obj) if is_data_oriented(obj): - # ``@qd.data_oriented(stable_members=True)``: the class promises its member *set* and *types* don't change - # after construction. Under that contract, unrecognised member types (e.g. Genesis's ``RigidSolver._uid`` of - # type ``genesis.utils.uid.UID``, or any other opaque metadata) are treated as inert from fastcache's - # perspective: they don't affect kernel codegen so they can be skipped silently rather than killing fastcache - # for the whole call. Without this, migrating a kernel from a standalone ``@qd.kernel`` function to a method - # on a ``@qd.data_oriented`` class disables fastcache the moment the class holds any opaque metadata, even - # though the kernel's compiled output would be identical. - stable_members = bool(type(obj).__dict__.get("_qd_stable_members")) + # Walk the data_oriented container's members. Recognised members contribute to the cache key; recognised- + # but-unsupported (Field/MatrixField) propagates _FAIL_FASTCACHE; opaque-typed members are skipped silently. + # + # Silently skipping opaque members is safe by construction: the kernel can only read recognised member types + # (ndarrays, primitives, enums, dataclasses, nested data_oriented). Opaque Python objects (UUIDs, Pydantic + # ``BaseModel`` instances, back-pointers up the object graph, etc.) cannot be read by kernel code, so they + # cannot affect kernel codegen and omitting them from the hash is correct. child_repr_l = ["da"] - _dict = {} try: - # pyright is ok with this approach _asdict = getattr(obj, "_asdict") _dict = _asdict() except AttributeError: _dict = obj.__dict__ - global _skip_unknown_warn_depth # pylint: disable=global-statement - if stable_members: - _skip_unknown_warn_depth += 1 - try: - for k, v in _dict.items(): - # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` - # stashes the per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so - # that subsequent ``instance.method`` lookups skip the descriptor allocation; - # those entries are not data and must not invalidate the fastcache key. - v_type = type(v) - if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: - continue - # Snapshot the recognised-but-unsupported flag around the recursive call so we can tell whether - # *this child's* subtree hit a ``Field`` / ``MatrixField`` (in which case we must fail fastcache - # even under ``stable_members``). - global _hit_recognised_unsupported # pylint: disable=global-statement - _hit_recognised_unsupported = False - _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) - child_hit_field = _hit_recognised_unsupported - if _child_repr is None: - # Differentiate two reasons ``stringify_obj_type`` returns None: - # - # (a) RECOGNISED-BUT-UNSUPPORTED: ``Field`` / ``MatrixField`` somewhere in this child's - # subtree. These are *known* tensor-like types whose values affect kernel codegen but - # which fastcache doesn't yet handle. Killing fastcache for the whole call is the - # intended contract — ``test_num_envs[False-...]`` pins this behaviour for the field - # backend. - # (b) TRULY-OPAQUE: anything that falls through to the ``[FASTCACHE][PARAM_INVALID]`` - # warning at the bottom of ``stringify_obj_type`` (``RigidSolver._uid`` of type - # ``UID``, etc.). For ``stable_members=True`` containers, opaque metadata is inert by - # the user's contract and can be skipped without invalidating the hash for the rest - # of the members. - if stable_members and not child_hit_field: - continue - if _should_warn: - _logging.warn( - f"A kernel that has been marked as eligible for fast cache was passed 1 or more " - f"parameters that are not, in fact, eligible for fast cache: one of the parameters was a " - f"@qd.data_oriented object, and one of its children was not eligible. The data oriented " - f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " - f"information, the path of the value was {path}." - ) - return None - child_repr_l.append(f"{k}: {_child_repr}") - finally: - if stable_members: - _skip_unknown_warn_depth -= 1 + for k, v in _dict.items(): + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the per-instance + # ``BoundQuadrantsCallable`` on ``instance.__dict__`` so that subsequent ``instance.method`` lookups skip + # the descriptor allocation; those entries are not data and must not invalidate the fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue + _child_repr = stringify_obj_type( + raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, ""), nested=True + ) + if _child_repr is _FAIL_FASTCACHE: + return _FAIL_FASTCACHE + if _child_repr is None: + # Opaque member; skip silently. + continue + child_repr_l.append(f"{k}: {_child_repr}") return ", ".join(child_repr_l) if issubclass(arg_type, (numbers.Number, np.number)): if _is_template(arg_meta): @@ -278,10 +261,10 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" - if _skip_unknown_warn_depth > 0: - # Inside a ``stable_members=True`` data_oriented walk: opaque members are tolerated by contract, so don't log - # the per-member ``[FASTCACHE][PARAM_INVALID]`` warning. The data_oriented walker reads the returned ``None`` - # and skips this member. + # Opaque (unrecognised) type. At nested levels, container walkers skip these silently — opaque types cannot + # affect kernel codegen because the kernel cannot read non-recognised Python types. At the top level, this is + # a user error (the kernel's argument is uninterpretable to fastcache) and we emit a warning. + if nested: return None _mark_should_warn() # The bit in caps should not be modified without updating corresponding test @@ -296,10 +279,9 @@ def stringify_obj_type( def hash_args( raise_on_templated_floats: bool, args: Sequence[Any], arg_metas: Sequence[ArgMetadata | None] ) -> str | FastcacheSkip: - """Return the args hash string, or a HashFailure explaining why hashing failed.""" - global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn, _hit_recognised_unsupported # pylint: disable=line-too-long + """Return the args hash string, or a FastcacheSkip explaining why hashing failed.""" + global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn # pylint: disable=global-statement _should_warn = False - _hit_recognised_unsupported = False g_num_calls += 1 g_num_args += len(args) hash_l = [] @@ -309,9 +291,12 @@ def hash_args( ) for i_arg, arg in enumerate(args): start = time.time() - _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg]) + _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg], nested=False) g_repr_time += time.time() - start - if not _hash: + # Both ``_FAIL_FASTCACHE`` (recognised-but-unsupported) and ``None`` (opaque at top level) disable + # fastcache. ``_should_warn`` selects between WARN (loud) and FIELD_VIA_TENSOR (silent — Field reached via + # qd.Tensor annotation, which is a normal path). + if _hash is _FAIL_FASTCACHE or _hash is None or not _hash: g_num_ignored_calls += 1 return FastcacheSkip.WARN if _should_warn else FastcacheSkip.FIELD_VIA_TENSOR hash_l.append(_hash) diff --git a/tests/python/quadrants/lang/fast_caching/test_args_hasher.py b/tests/python/quadrants/lang/fast_caching/test_args_hasher.py index bda9adf808..de7cf43d57 100644 --- a/tests/python/quadrants/lang/fast_caching/test_args_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_args_hasher.py @@ -107,6 +107,76 @@ class Foo: ... assert args_hasher.hash_args(False, [foo], [None]) is not None +@test_utils.test() +def test_args_hasher_data_oriented_with_opaque_member_silently_skipped() -> None: + """Default ``@qd.data_oriented`` (no ``stable_members``) tolerates opaque-typed members. + + Opaque members (custom Python types not in the recognised set — UUIDs, plain Python objects, + references to non-data-oriented classes, ...) are silently skipped from the fastcache key, because + the kernel cannot read non-recognised Python types and therefore opaque members cannot affect + kernel codegen. Pinning this here so we don't regress to silently disabling fastcache for any + data_oriented holding metadata. + """ + + class Opaque: + def __init__(self, val: int) -> None: + self.val = val + + @qd.data_oriented + class Container: + def __init__(self, opaque: Opaque) -> None: + self.meta = opaque + self.nd = qd.ndarray(qd.i32, shape=(4,)) + + c1 = Container(Opaque(1)) + c2 = Container(Opaque(2)) + h1 = args_hasher.hash_args(False, [c1], [None]) + h2 = args_hasher.hash_args(False, [c2], [None]) + assert isinstance(h1, str), f"opaque meta member must not disable fastcache, got {h1!r}" + assert isinstance(h2, str) + assert h1 == h2, "opaque-typed members must not affect the fastcache key" + + +@test_utils.test() +def test_args_hasher_data_oriented_nested_field_still_fails() -> None: + """Recognised-but-unsupported (``qd.field``) inside a data_oriented container still disables fastcache. + + The opaque-member-silent-skip default does NOT relax this. Fields are *recognised* tensor-like types + whose shape/dtype would affect kernel codegen, but fastcache doesn't yet support hashing them; the + safe default is to disable fastcache for the whole call when one is encountered. + """ + + @qd.data_oriented + class Container: + def __init__(self) -> None: + self.f = qd.field(qd.i32, shape=(4,)) + + c = Container() + h = args_hasher.hash_args(False, [c], [None]) + assert isinstance(h, FastcacheSkip), f"field inside data_oriented must disable fastcache, got {h!r}" + + +@test_utils.test() +def test_args_hasher_dataclass_with_opaque_field_silently_skipped() -> None: + """A dataclass with an opaque-typed field should not disable fastcache.""" + + class Opaque: + def __init__(self, val: int) -> None: + self.val = val + + @dataclasses.dataclass + class State: + meta: object + nd: qd.types.NDArray[qd.i32, 1] + + s1 = State(meta=Opaque(1), nd=qd.ndarray(qd.i32, (4,))) + s2 = State(meta=Opaque(2), nd=qd.ndarray(qd.i32, (4,))) + h1 = args_hasher.hash_args(False, [s1], [None]) + h2 = args_hasher.hash_args(False, [s2], [None]) + assert isinstance(h1, str), f"opaque dataclass field must not disable fastcache, got {h1!r}" + assert h1 == h2, "opaque-typed fields must not affect the fastcache key" + + @test_utils.test() def test_args_hasher_ndarray() -> None: seen = set() From 775790778cf08bbdf1d5a5264993ed0392604203 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 02:34:49 -0700 Subject: [PATCH 15/46] [Doc] Fastcache: opaque-member silencing is the default; clarify stable_members scope Docs the new behaviour committed in 49ffb3b44: - ``compound_types.md`` ``### Fastcache``: explain the three-bucket type-based classification (recognised+valid / opaque / recognised+unsupported) that applies to every ``@qd.data_oriented`` argument by default. Add a separate ``### stable_members=True`` subsection clarifying that the flag is a per-call launch performance hint (template-mapper + launch-context cache), not a fastcache contract. - ``fastcache.md`` compound-type rules: add the opaque-skip bullet and the opaque-vs-recognised-unsupported note. - ``kernel_impl.data_oriented`` docstring: narrow ``stable_members`` to its actual scope (per-call walker skip) and explicitly note that fastcache silences opaque members regardless of the flag. --- docs/source/user_guide/compound_types.md | 36 ++++++++++++++++++++++++ docs/source/user_guide/fastcache.md | 5 ++++ python/quadrants/lang/kernel_impl.py | 19 ++++++++----- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 7b942e4cab..50da314840 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -198,6 +198,42 @@ state.step() `@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes, but is disabled for fields; see [Advanced — compound-type cache keying](fastcache.md#compound-type-cache-keying) for more information. +The fastcache hasher classifies each member of a `@qd.data_oriented` argument by `type(member)`: + +| Member kind (by `type`) | Examples | Behaviour | +|---|---|---| +| **Recognised + valid** | `qd.ndarray` (or `qd.Tensor`-wrapped), `dataclasses.dataclass` whose fields are recognised-valid, primitives, enums, `numpy.ndarray`, `torch.Tensor`, nested `@qd.data_oriented` whose own walk succeeds | Contributes to the cache key. | +| **Opaque** (`type(member)` is none of the above) | UUID-typed identifiers, Pydantic `BaseModel`, plain Python classes that are *not* `@qd.data_oriented` / `dataclasses.dataclass`, `list`, `dict` | Skipped silently. | +| **Recognised + unsupported** | `qd.field`, `qd.Vector.field`, `qd.Matrix.field` (or a `dataclasses.dataclass` whose fields are fields), nested `@qd.data_oriented` whose own walk hits a field somewhere | Fastcache fails for the whole call. | + +Skipping opaque members from the cache key is safe by construction: the kernel can only read recognised types (ndarrays, primitives, enums, dataclasses, nested `@qd.data_oriented`). Opaque Python objects — `UUID` identifiers, Pydantic config objects, back-pointers up the object graph, etc. — cannot be read by kernel code and therefore cannot affect kernel codegen, so omitting them from the hash is correct. The classification is purely type-based; there is no special-casing for "metadata" or "back-pointer" as semantic roles. + +`qd.field` / `qd.Matrix.field` are NOT opaque — they are recognised tensor-like types whose shape and dtype would affect kernel codegen, but fastcache doesn't yet support hashing them. Fastcache correctly fails when those are present. + +### `stable_members=True` + +`@qd.data_oriented(stable_members=True)` is an opt-in performance hint for large container classes whose `qd.ndarray` member set is fixed after `__init__`. The flag is unrelated to fastcache — it only affects launch-time bookkeeping. + +```python +@qd.data_oriented(stable_members=True) +class Solver: + def __init__(self, parent, options): + self._uid = make_uuid() + self._parent = parent + self.options = options + self.dofs_state = ... # holds ndarrays; never reassigned + self.links_state = ... +``` + +What changes: + +1. **Per-call template-mapper walker fast path.** The mapper skips the per-call walk that re-discovers ndarray attribute paths and reuses the path set captured on the first instance walked. For solver-style classes (one `self` argument, many attributes, many kernels) this drops a per-call cost of ~100 ns/kernel back below the noise floor. +2. **Per-call launch-context cache.** The launch path skips the per-call `_resolve_struct_ndarray` walk that folds live ndarray ids into the launch-context cache key, reusing the cached set instead. + +Use it when the container has many `qd.ndarray` members that are allocated once in `__init__` and never reassigned. Do not use it if you reassign ndarray members after construction — the previously-compiled kernel will be reused even if the new ndarray has different dtype/ndim/layout (undefined behaviour). May also be set as a class-level attribute `_qd_stable_members = True` (equivalent). + +`stable_members=True` does **not** affect fastcache. Opaque-member silencing in the cache key is the default behaviour for all `@qd.data_oriented` classes. + ### Under the hood Like `dataclasses.dataclass`, a `@qd.data_oriented` object is Python-only — the compiler flattens it into individual kernel parameters and the object itself has no kernel-side representation. Unlike `dataclasses.dataclass` it needs no member annotations: the compiler reads the live instance's attributes directly. Primitive members are baked into the kernel as constants, so each distinct primitive value compiles a new specialised kernel. diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 5d4e9381c8..b9cfd8cc5c 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -154,6 +154,11 @@ The args hasher walks compound-type kernel parameters recursively. For each leaf - Nested `@qd.data_oriented` member — recurses. - Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). - `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. +- **Opaque-typed member** (any type not in the recognised list above — UUID identifiers, Pydantic `BaseModel`, plain Python classes, references up the object graph, `list`, `dict`, ...) — **skipped silently** from the cache key. + +Skipping opaque members is safe by construction: the kernel can only read recognised types, so opaque members cannot affect kernel codegen. This is the default for all `@qd.data_oriented` classes; no opt-in flag is required. (`@qd.data_oriented(stable_members=True)` is a separate per-call launch performance hint — see [compound_types.md](compound_types.md#stable_memberstrue) — and does not affect fastcache behaviour.) + +Note the distinction between **opaque** (any unrecognised type, skipped silently) and **recognised-but-unsupported** (`qd.field` / `qd.Matrix.field`, which disable fastcache). Field-like types are *recognised* — their shape/dtype affect kernel codegen — but the hasher does not yet know how to include them in the cache key, so the safe default is to disable fastcache when one is present, rather than silently emit a stale cache key. **`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 9270050e74..0d44343392 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -300,13 +300,18 @@ def data_oriented(cls=None, *, stable_members: bool = False): Args: cls (Class): the class to be decorated. - stable_members (bool): if ``True``, declares that the class's ndarray-typed members are - allocated once and never reassigned between kernel calls. Quadrants will skip a - per-call walk of the instance's attributes (~1-2 us/call savings on Genesis-style - containers with several ndarray attrs). Reassigning a member on a ``stable_members`` - class is undefined behaviour — the previously-compiled kernel will be reused even if - the new ndarray has different dtype/ndim/layout. May also be set as a class-level - attribute ``_qd_stable_members = True`` (equivalent). + stable_members (bool): per-call launch performance hint, unrelated to fastcache. If + ``True``, declares that the class's ndarray-typed members are allocated once and + never reassigned between kernel calls. Quadrants will skip a per-call walk of the + instance's attributes in (a) the template-mapper's spec-key construction and (b) + the launch-context cache's live-ndarray-id fold-in (~1-2 us/call savings on + Genesis-style containers with several ndarray attrs). Reassigning a member on a + ``stable_members`` class is undefined behaviour — the previously-compiled kernel + will be reused even if the new ndarray has different dtype/ndim/layout. May also + be set as a class-level attribute ``_qd_stable_members = True`` (equivalent). + Note: this flag does **not** affect the fastcache argument hasher; opaque-typed + members are silently skipped from the fastcache key for all data_oriented classes + by default. Returns: The decorated class (or, when called with arguments, a decorator). From fb38fec0b0914fe722d0546605b494769a7fc10b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 02:50:38 -0700 Subject: [PATCH 16/46] Revert "[Doc] Fastcache: opaque-member silencing is the default; clarify stable_members scope" This reverts commit 775790778cf08bbdf1d5a5264993ed0392604203. --- docs/source/user_guide/compound_types.md | 36 ------------------------ docs/source/user_guide/fastcache.md | 5 ---- python/quadrants/lang/kernel_impl.py | 19 +++++-------- 3 files changed, 7 insertions(+), 53 deletions(-) diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 50da314840..7b942e4cab 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -198,42 +198,6 @@ state.step() `@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes, but is disabled for fields; see [Advanced — compound-type cache keying](fastcache.md#compound-type-cache-keying) for more information. -The fastcache hasher classifies each member of a `@qd.data_oriented` argument by `type(member)`: - -| Member kind (by `type`) | Examples | Behaviour | -|---|---|---| -| **Recognised + valid** | `qd.ndarray` (or `qd.Tensor`-wrapped), `dataclasses.dataclass` whose fields are recognised-valid, primitives, enums, `numpy.ndarray`, `torch.Tensor`, nested `@qd.data_oriented` whose own walk succeeds | Contributes to the cache key. | -| **Opaque** (`type(member)` is none of the above) | UUID-typed identifiers, Pydantic `BaseModel`, plain Python classes that are *not* `@qd.data_oriented` / `dataclasses.dataclass`, `list`, `dict` | Skipped silently. | -| **Recognised + unsupported** | `qd.field`, `qd.Vector.field`, `qd.Matrix.field` (or a `dataclasses.dataclass` whose fields are fields), nested `@qd.data_oriented` whose own walk hits a field somewhere | Fastcache fails for the whole call. | - -Skipping opaque members from the cache key is safe by construction: the kernel can only read recognised types (ndarrays, primitives, enums, dataclasses, nested `@qd.data_oriented`). Opaque Python objects — `UUID` identifiers, Pydantic config objects, back-pointers up the object graph, etc. — cannot be read by kernel code and therefore cannot affect kernel codegen, so omitting them from the hash is correct. The classification is purely type-based; there is no special-casing for "metadata" or "back-pointer" as semantic roles. - -`qd.field` / `qd.Matrix.field` are NOT opaque — they are recognised tensor-like types whose shape and dtype would affect kernel codegen, but fastcache doesn't yet support hashing them. Fastcache correctly fails when those are present. - -### `stable_members=True` - -`@qd.data_oriented(stable_members=True)` is an opt-in performance hint for large container classes whose `qd.ndarray` member set is fixed after `__init__`. The flag is unrelated to fastcache — it only affects launch-time bookkeeping. - -```python -@qd.data_oriented(stable_members=True) -class Solver: - def __init__(self, parent, options): - self._uid = make_uuid() - self._parent = parent - self.options = options - self.dofs_state = ... # holds ndarrays; never reassigned - self.links_state = ... -``` - -What changes: - -1. **Per-call template-mapper walker fast path.** The mapper skips the per-call walk that re-discovers ndarray attribute paths and reuses the path set captured on the first instance walked. For solver-style classes (one `self` argument, many attributes, many kernels) this drops a per-call cost of ~100 ns/kernel back below the noise floor. -2. **Per-call launch-context cache.** The launch path skips the per-call `_resolve_struct_ndarray` walk that folds live ndarray ids into the launch-context cache key, reusing the cached set instead. - -Use it when the container has many `qd.ndarray` members that are allocated once in `__init__` and never reassigned. Do not use it if you reassign ndarray members after construction — the previously-compiled kernel will be reused even if the new ndarray has different dtype/ndim/layout (undefined behaviour). May also be set as a class-level attribute `_qd_stable_members = True` (equivalent). - -`stable_members=True` does **not** affect fastcache. Opaque-member silencing in the cache key is the default behaviour for all `@qd.data_oriented` classes. - ### Under the hood Like `dataclasses.dataclass`, a `@qd.data_oriented` object is Python-only — the compiler flattens it into individual kernel parameters and the object itself has no kernel-side representation. Unlike `dataclasses.dataclass` it needs no member annotations: the compiler reads the live instance's attributes directly. Primitive members are baked into the kernel as constants, so each distinct primitive value compiles a new specialised kernel. diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index b9cfd8cc5c..5d4e9381c8 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -154,11 +154,6 @@ The args hasher walks compound-type kernel parameters recursively. For each leaf - Nested `@qd.data_oriented` member — recurses. - Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). - `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. -- **Opaque-typed member** (any type not in the recognised list above — UUID identifiers, Pydantic `BaseModel`, plain Python classes, references up the object graph, `list`, `dict`, ...) — **skipped silently** from the cache key. - -Skipping opaque members is safe by construction: the kernel can only read recognised types, so opaque members cannot affect kernel codegen. This is the default for all `@qd.data_oriented` classes; no opt-in flag is required. (`@qd.data_oriented(stable_members=True)` is a separate per-call launch performance hint — see [compound_types.md](compound_types.md#stable_memberstrue) — and does not affect fastcache behaviour.) - -Note the distinction between **opaque** (any unrecognised type, skipped silently) and **recognised-but-unsupported** (`qd.field` / `qd.Matrix.field`, which disable fastcache). Field-like types are *recognised* — their shape/dtype affect kernel codegen — but the hasher does not yet know how to include them in the cache key, so the safe default is to disable fastcache when one is present, rather than silently emit a stale cache key. **`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 0d44343392..9270050e74 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -300,18 +300,13 @@ def data_oriented(cls=None, *, stable_members: bool = False): Args: cls (Class): the class to be decorated. - stable_members (bool): per-call launch performance hint, unrelated to fastcache. If - ``True``, declares that the class's ndarray-typed members are allocated once and - never reassigned between kernel calls. Quadrants will skip a per-call walk of the - instance's attributes in (a) the template-mapper's spec-key construction and (b) - the launch-context cache's live-ndarray-id fold-in (~1-2 us/call savings on - Genesis-style containers with several ndarray attrs). Reassigning a member on a - ``stable_members`` class is undefined behaviour — the previously-compiled kernel - will be reused even if the new ndarray has different dtype/ndim/layout. May also - be set as a class-level attribute ``_qd_stable_members = True`` (equivalent). - Note: this flag does **not** affect the fastcache argument hasher; opaque-typed - members are silently skipped from the fastcache key for all data_oriented classes - by default. + stable_members (bool): if ``True``, declares that the class's ndarray-typed members are + allocated once and never reassigned between kernel calls. Quadrants will skip a + per-call walk of the instance's attributes (~1-2 us/call savings on Genesis-style + containers with several ndarray attrs). Reassigning a member on a ``stable_members`` + class is undefined behaviour — the previously-compiled kernel will be reused even if + the new ndarray has different dtype/ndim/layout. May also be set as a class-level + attribute ``_qd_stable_members = True`` (equivalent). Returns: The decorated class (or, when called with arguments, a decorator). From 7cabaa0b893fc4e13d4c103d230eee3e6cfa852c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 02:50:38 -0700 Subject: [PATCH 17/46] Revert "[Fix] Fastcache: skip opaque-typed members silently by default" This reverts commit 49ffb3b446cba44da09155b9292dae619854a464. --- .../lang/_fast_caching/args_hasher.py | 219 ++++++++++-------- .../lang/fast_caching/test_args_hasher.py | 70 ------ 2 files changed, 117 insertions(+), 172 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 77d3ee2ba6..aa8df595eb 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -41,22 +41,18 @@ _DC_REPR_NONE = object() -# Sentinel returned by ``stringify_obj_type`` when a recognised-but-unsupported tensor-like type (``Field`` / -# ``MatrixField``) is encountered anywhere in the traversal. Containers that see this sentinel (``dataclass_to_repr``, -# the ``data_oriented`` branch, and the top-level ``hash_args`` loop) must propagate it upward — fastcache cannot -# safely hash the call. Distinct from ``None``, which means "opaque type, safe to silently skip at nested levels". -class _FailFastcache: - """Singleton sentinel; identity-compared. See module docstring on ``stringify_obj_type``'s return contract.""" +# Set by ``stringify_obj_type`` when it encounters a recognised-but-unsupported tensor-like type (``Field`` / +# ``MatrixField``) anywhere in the traversal — including nested under a dataclass or another data_oriented object. +# The ``stable_members=True`` data_oriented walker uses this to differentiate two reasons a child returned ``None``: +# truly-opaque metadata (``RigidSolver._uid: UID``, etc.) which is inert and can be skipped, vs a tensor-like type +# whose value affects kernel codegen and must invalidate fastcache for the whole call. Reset at the top of each +# ``hash_args``; snapshotted/restored around each nested ``stringify_obj_type`` call inside the data_oriented walker. +_hit_recognised_unsupported = False - _instance = None - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - -_FAIL_FASTCACHE = _FailFastcache() +def _mark_hit_recognised_unsupported() -> None: + global _hit_recognised_unsupported # pylint: disable=global-statement + _hit_recognised_unsupported = True class FastcacheSkip(enum.Enum): @@ -71,6 +67,15 @@ class FastcacheSkip(enum.Enum): _should_warn = False +# Counter set by the data_oriented walker when entering a ``_qd_stable_members`` object. While nonzero, the +# unknown-type branch of ``stringify_obj_type`` returns ``None`` silently instead of logging +# ``[FASTCACHE][PARAM_INVALID]``. ``stable_members=True`` is the user's promise that the class's member set / types +# don't change after construction — under that promise, opaque members like ``RigidSolver._uid`` (a +# ``genesis.utils.uid.UID``) don't affect kernel codegen so they can be skipped silently rather than killing +# fastcache for the whole call. Single-threaded by construction (the hasher only runs during JIT compile). +_skip_unknown_warn_depth = 0 + + def _mark_warn_if_not_tensor_annotation(arg_meta: ArgMetadata | None) -> None: """Flag that a warning is needed if the Field didn't arrive through a qd.Tensor annotation.""" global _should_warn # pylint: disable=global-statement @@ -83,38 +88,24 @@ def _mark_should_warn() -> None: _should_warn = True -def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | _FailFastcache | None: - """Hash a dataclass instance. - - Returns: - - ``str``: a string representation suitable for the fastcache key. - - ``_FAIL_FASTCACHE``: a recognised-but-unsupported tensor-like field was hit; fastcache must be disabled - for the whole call. - - ``None``: dataclass-level skip. Currently unused (dataclasses always succeed unless they hit Field/MatrixField), - but defined symmetrically with ``stringify_obj_type``. - - Note that opaque-typed fields (UUID, plain Python objects, ...) are *silently skipped* — they cannot affect - kernel codegen because the kernel cannot read non-recognised Python types, so omitting them from the hash is - safe by construction. - """ +def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | None: # PERF: For frozen dataclasses, the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. - # A cached ``_DC_REPR_NONE`` is stored to distinguish "not yet computed" from "computed but not fast-cacheable". + # A cached ``None`` is stored as the sentinel ``_DC_REPR_NONE`` to distinguish "not yet computed" from + # "computed but not fast-cacheable". is_frozen = type(arg).__hash__ is not None if is_frozen: cached = getattr(arg, "_qd_dc_repr", None) if cached is _DC_REPR_NONE: - return _FAIL_FASTCACHE + return None if cached is not None: return cached repr_l = [] for field in dataclasses.fields(arg): child_value = getattr(arg, field.name) _repr = stringify_obj_type(raise_on_templated_floats, path + (field.name,), child_value, arg_meta=None) - if _repr is _FAIL_FASTCACHE: - # Recognised-but-unsupported (Field/MatrixField) somewhere in this child's subtree. Mark whether the - # field arrived via a non-Tensor annotation so the top-level decides between WARN and FIELD_VIA_TENSOR. + if _repr is None: if isinstance(child_value, _FIELD_TYPES) and field.type is not _TensorWrapper: _mark_should_warn() if is_frozen: @@ -122,11 +113,7 @@ def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], ar object.__setattr__(arg, "_qd_dc_repr", _DC_REPR_NONE) except AttributeError: pass - return _FAIL_FASTCACHE - if _repr is None: - # Opaque-typed field; skip silently. Opaque types cannot affect kernel codegen because the kernel - # cannot read non-recognised Python types — they are inert metadata. - continue + return None full_repr = f"{field.name}: ({_repr})" if field.metadata.get(FIELD_METADATA_CACHE_VALUE, False): full_repr += f" = {child_value}" @@ -148,32 +135,23 @@ def _is_template(arg_meta: ArgMetadata | None) -> bool: def stringify_obj_type( - raise_on_templated_floats: bool, - path: tuple[str, ...], - obj: object, - arg_meta: ArgMetadata | None, - nested: bool = False, -) -> str | _FailFastcache | None: + raise_on_templated_floats: bool, path: tuple[str, ...], obj: object, arg_meta: ArgMetadata | None +) -> str | None: """ - Convert an object into a string representation that only depends on its type (and, where relevant, its value). - - Return contract: - - ``str``: the object is hashable for fastcache; the returned string contributes to the cache key. - - ``_FAIL_FASTCACHE``: a recognised-but-unsupported type (``qd.field`` / ``qd.Matrix.field``) was encountered. - Containers must propagate this upward; fastcache will be disabled for the whole call. - - ``None``: the object's type is *opaque* — not recognised by the hasher. Containers (``dataclass_to_repr`` - and the ``data_oriented`` branch below) silently skip opaque members because opaque types cannot affect - kernel codegen (the kernel can only read recognised types: ndarrays, primitives, enums, dataclasses, - nested ``@qd.data_oriented`` objects). At the top level (``nested=False``), opaque is treated as - an error and a ``[FASTCACHE][PARAM_INVALID]`` warning is emitted. - - Parameters: - - ``nested``: ``True`` if this call comes from a container walker (dataclass / data_oriented). Suppresses - the top-level ``[FASTCACHE][PARAM_INVALID]`` warning for opaque types so nested opaque members are - skipped silently. ``False`` at the top of each kernel-arg traversal. - - ``arg_meta``: non-``None`` only for the top-level kernel arguments and for ``@qd.data_oriented`` members. - Used to determine whether to bake values into the cache key (primitives in template positions, and all - primitive members of data-oriented containers). + Convert an object into a string representation that only depends on its type. + + String should somehow represent the type of obj. Doesnt have to be hashed, nor does it have + to be the actual python type string, just a string that is representative of the type, and won't collide + with different (allowed) types. String should be non-empty. + + Note that fields are not included in fast cache. + + arg_meta should only be non-None for the top level arguments and for data oriented objects. It is + used currently to determine whether a value is added to the cache key, as well as the name. eg + - at the top level, primitive types have their values added to the cache key if their annotation is qd.Template, + since they are baked into the kernel + - in data oriented objects, the values of all primitive types are added to the cache key, since they are baked + into the kernel, and require a kernel recompilation, when they change """ # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` strips # wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / data-oriented @@ -181,6 +159,8 @@ def stringify_obj_type( # fields, so a wrapper stored as a struct field arrives here un-stripped. Without this branch the hasher falls # through to the ``[FASTCACHE][PARAM_INVALID]`` warning and disables the fast path for the whole call. See # ``perso_hugh/doc/quadrants-tensor.md`` §8.14. + # ``qd.Tensor`` wrappers: unwrap to the bare impl so the type checks below match. After unwrap, ``_qd_layout`` (if + # any) is on the impl. # # PERF-CRITICAL: The _any_tensor_constructed guard makes this check zero-cost when no qd.Tensor has been created. # ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer comparison (~10 @@ -197,11 +177,12 @@ def stringify_obj_type( if isinstance(obj, VectorNdarray): return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): - # Recognised-but-unsupported: Field's shape/dtype affect kernel codegen but fastcache doesn't yet know how - # to handle them. Disable fastcache for the whole call. + # disabled for now, because we need to think about how to handle field offset + # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - return _FAIL_FASTCACHE + _mark_hit_recognised_unsupported() + return None if isinstance(obj, MatrixNdarray): return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, torch_type): @@ -209,42 +190,78 @@ def stringify_obj_type( if isinstance(obj, np.ndarray): return f"[np-{obj.dtype}-{obj.ndim}]" if isinstance(obj, MatrixField): - # Recognised-but-unsupported, same as ScalarField above. + # disabled for now, because we need to think about how to handle field offset + # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - return _FAIL_FASTCACHE + _mark_hit_recognised_unsupported() + return None if is_dataclass_instance(obj): return dataclass_to_repr(raise_on_templated_floats, path, obj) if is_data_oriented(obj): - # Walk the data_oriented container's members. Recognised members contribute to the cache key; recognised- - # but-unsupported (Field/MatrixField) propagates _FAIL_FASTCACHE; opaque-typed members are skipped silently. - # - # Silently skipping opaque members is safe by construction: the kernel can only read recognised member types - # (ndarrays, primitives, enums, dataclasses, nested data_oriented). Opaque Python objects (UUIDs, Pydantic - # ``BaseModel`` instances, back-pointers up the object graph, etc.) cannot be read by kernel code, so they - # cannot affect kernel codegen and omitting them from the hash is correct. + # ``@qd.data_oriented(stable_members=True)``: the class promises its member *set* and *types* don't change + # after construction. Under that contract, unrecognised member types (e.g. Genesis's ``RigidSolver._uid`` of + # type ``genesis.utils.uid.UID``, or any other opaque metadata) are treated as inert from fastcache's + # perspective: they don't affect kernel codegen so they can be skipped silently rather than killing fastcache + # for the whole call. Without this, migrating a kernel from a standalone ``@qd.kernel`` function to a method + # on a ``@qd.data_oriented`` class disables fastcache the moment the class holds any opaque metadata, even + # though the kernel's compiled output would be identical. + stable_members = bool(type(obj).__dict__.get("_qd_stable_members")) child_repr_l = ["da"] + _dict = {} try: + # pyright is ok with this approach _asdict = getattr(obj, "_asdict") _dict = _asdict() except AttributeError: _dict = obj.__dict__ - for k, v in _dict.items(): - # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the per-instance - # ``BoundQuadrantsCallable`` on ``instance.__dict__`` so that subsequent ``instance.method`` lookups skip - # the descriptor allocation; those entries are not data and must not invalidate the fastcache key. - v_type = type(v) - if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: - continue - _child_repr = stringify_obj_type( - raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, ""), nested=True - ) - if _child_repr is _FAIL_FASTCACHE: - return _FAIL_FASTCACHE - if _child_repr is None: - # Opaque member; skip silently. - continue - child_repr_l.append(f"{k}: {_child_repr}") + global _skip_unknown_warn_depth # pylint: disable=global-statement + if stable_members: + _skip_unknown_warn_depth += 1 + try: + for k, v in _dict.items(): + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` + # stashes the per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so + # that subsequent ``instance.method`` lookups skip the descriptor allocation; + # those entries are not data and must not invalidate the fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue + # Snapshot the recognised-but-unsupported flag around the recursive call so we can tell whether + # *this child's* subtree hit a ``Field`` / ``MatrixField`` (in which case we must fail fastcache + # even under ``stable_members``). + global _hit_recognised_unsupported # pylint: disable=global-statement + _hit_recognised_unsupported = False + _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) + child_hit_field = _hit_recognised_unsupported + if _child_repr is None: + # Differentiate two reasons ``stringify_obj_type`` returns None: + # + # (a) RECOGNISED-BUT-UNSUPPORTED: ``Field`` / ``MatrixField`` somewhere in this child's + # subtree. These are *known* tensor-like types whose values affect kernel codegen but + # which fastcache doesn't yet handle. Killing fastcache for the whole call is the + # intended contract — ``test_num_envs[False-...]`` pins this behaviour for the field + # backend. + # (b) TRULY-OPAQUE: anything that falls through to the ``[FASTCACHE][PARAM_INVALID]`` + # warning at the bottom of ``stringify_obj_type`` (``RigidSolver._uid`` of type + # ``UID``, etc.). For ``stable_members=True`` containers, opaque metadata is inert by + # the user's contract and can be skipped without invalidating the hash for the rest + # of the members. + if stable_members and not child_hit_field: + continue + if _should_warn: + _logging.warn( + f"A kernel that has been marked as eligible for fast cache was passed 1 or more " + f"parameters that are not, in fact, eligible for fast cache: one of the parameters was a " + f"@qd.data_oriented object, and one of its children was not eligible. The data oriented " + f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " + f"information, the path of the value was {path}." + ) + return None + child_repr_l.append(f"{k}: {_child_repr}") + finally: + if stable_members: + _skip_unknown_warn_depth -= 1 return ", ".join(child_repr_l) if issubclass(arg_type, (numbers.Number, np.number)): if _is_template(arg_meta): @@ -261,10 +278,10 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" - # Opaque (unrecognised) type. At nested levels, container walkers skip these silently — opaque types cannot - # affect kernel codegen because the kernel cannot read non-recognised Python types. At the top level, this is - # a user error (the kernel's argument is uninterpretable to fastcache) and we emit a warning. - if nested: + if _skip_unknown_warn_depth > 0: + # Inside a ``stable_members=True`` data_oriented walk: opaque members are tolerated by contract, so don't log + # the per-member ``[FASTCACHE][PARAM_INVALID]`` warning. The data_oriented walker reads the returned ``None`` + # and skips this member. return None _mark_should_warn() # The bit in caps should not be modified without updating corresponding test @@ -279,9 +296,10 @@ def stringify_obj_type( def hash_args( raise_on_templated_floats: bool, args: Sequence[Any], arg_metas: Sequence[ArgMetadata | None] ) -> str | FastcacheSkip: - """Return the args hash string, or a FastcacheSkip explaining why hashing failed.""" - global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn # pylint: disable=global-statement + """Return the args hash string, or a HashFailure explaining why hashing failed.""" + global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn, _hit_recognised_unsupported # pylint: disable=line-too-long _should_warn = False + _hit_recognised_unsupported = False g_num_calls += 1 g_num_args += len(args) hash_l = [] @@ -291,12 +309,9 @@ def hash_args( ) for i_arg, arg in enumerate(args): start = time.time() - _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg], nested=False) + _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg]) g_repr_time += time.time() - start - # Both ``_FAIL_FASTCACHE`` (recognised-but-unsupported) and ``None`` (opaque at top level) disable - # fastcache. ``_should_warn`` selects between WARN (loud) and FIELD_VIA_TENSOR (silent — Field reached via - # qd.Tensor annotation, which is a normal path). - if _hash is _FAIL_FASTCACHE or _hash is None or not _hash: + if not _hash: g_num_ignored_calls += 1 return FastcacheSkip.WARN if _should_warn else FastcacheSkip.FIELD_VIA_TENSOR hash_l.append(_hash) diff --git a/tests/python/quadrants/lang/fast_caching/test_args_hasher.py b/tests/python/quadrants/lang/fast_caching/test_args_hasher.py index de7cf43d57..bda9adf808 100644 --- a/tests/python/quadrants/lang/fast_caching/test_args_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_args_hasher.py @@ -107,76 +107,6 @@ class Foo: ... assert args_hasher.hash_args(False, [foo], [None]) is not None -@test_utils.test() -def test_args_hasher_data_oriented_with_opaque_member_silently_skipped() -> None: - """Default ``@qd.data_oriented`` (no ``stable_members``) tolerates opaque-typed members. - - Opaque members (custom Python types not in the recognised set — UUIDs, plain Python objects, - references to non-data-oriented classes, ...) are silently skipped from the fastcache key, because - the kernel cannot read non-recognised Python types and therefore opaque members cannot affect - kernel codegen. Pinning this here so we don't regress to silently disabling fastcache for any - data_oriented holding metadata. - """ - - class Opaque: - def __init__(self, val: int) -> None: - self.val = val - - @qd.data_oriented - class Container: - def __init__(self, opaque: Opaque) -> None: - self.meta = opaque - self.nd = qd.ndarray(qd.i32, shape=(4,)) - - c1 = Container(Opaque(1)) - c2 = Container(Opaque(2)) - h1 = args_hasher.hash_args(False, [c1], [None]) - h2 = args_hasher.hash_args(False, [c2], [None]) - assert isinstance(h1, str), f"opaque meta member must not disable fastcache, got {h1!r}" - assert isinstance(h2, str) - assert h1 == h2, "opaque-typed members must not affect the fastcache key" - - -@test_utils.test() -def test_args_hasher_data_oriented_nested_field_still_fails() -> None: - """Recognised-but-unsupported (``qd.field``) inside a data_oriented container still disables fastcache. - - The opaque-member-silent-skip default does NOT relax this. Fields are *recognised* tensor-like types - whose shape/dtype would affect kernel codegen, but fastcache doesn't yet support hashing them; the - safe default is to disable fastcache for the whole call when one is encountered. - """ - - @qd.data_oriented - class Container: - def __init__(self) -> None: - self.f = qd.field(qd.i32, shape=(4,)) - - c = Container() - h = args_hasher.hash_args(False, [c], [None]) - assert isinstance(h, FastcacheSkip), f"field inside data_oriented must disable fastcache, got {h!r}" - - -@test_utils.test() -def test_args_hasher_dataclass_with_opaque_field_silently_skipped() -> None: - """A dataclass with an opaque-typed field should not disable fastcache.""" - - class Opaque: - def __init__(self, val: int) -> None: - self.val = val - - @dataclasses.dataclass - class State: - meta: object - nd: qd.types.NDArray[qd.i32, 1] - - s1 = State(meta=Opaque(1), nd=qd.ndarray(qd.i32, (4,))) - s2 = State(meta=Opaque(2), nd=qd.ndarray(qd.i32, (4,))) - h1 = args_hasher.hash_args(False, [s1], [None]) - h2 = args_hasher.hash_args(False, [s2], [None]) - assert isinstance(h1, str), f"opaque dataclass field must not disable fastcache, got {h1!r}" - assert h1 == h2, "opaque-typed fields must not affect the fastcache key" - - @test_utils.test() def test_args_hasher_ndarray() -> None: seen = set() From b5b360a0a74abcaaea6804ed313ffcf51a09b3a5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 03:00:51 -0700 Subject: [PATCH 18/46] [Fix] Fastcache: replace PARAM_INVALID / silent-skip with qualname fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Unrecognised types in fastcache argument hashing previously had two failure modes, both bad: - Top-level: ``[FASTCACHE][PARAM_INVALID]`` warn + return None, disabling fastcache for the whole call. Any solver-like object carrying a single opaque metadata field (Genesis ``UID``, Pydantic config, back-pointer) silently killed the cache. - Nested under ``@qd.data_oriented(stable_members=True)``: silent skip. Worked for the Genesis case but is dangerous: if someone later adds a new tensor-like type (e.g. ``BFloat16Tensor``) whose value affects kernel codegen but forgets to register it in args_hasher's recognised set, the silent skip serves stale cache results without any indication. Both paths are replaced with a single ``type(v).__qualname__``-based fallback (``opaque-.``) that emits a one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning per type. Properties: - Cache key stable across instances of the same opaque class (Genesis UID #1 and UID #2 produce the same key). Kernels cannot read non-recognised Python types so opaque metadata cannot affect codegen, making type-identity-only hashing correct for genuinely opaque members. - Loud diagnostic for the dangerous case: any unrecognised type that ever gets hashed prints a warning pointing at args_hasher.stringify_obj_type so a missed tensor-like registration is impossible to miss. - ``ScalarField`` / ``MatrixField`` (recognised-but-unsupported tensor-like) still disable fastcache via a new ``_FAIL_FASTCACHE`` sentinel — their shape/dtype affect codegen but fastcache doesn't yet handle them. Distinct from the qualname fallback so the field path remains correct. Also adds ``pruning_paths`` and ``parent_flat`` plumbing through ``stringify_obj_type`` / ``dataclass_to_repr`` / ``hash_args`` for the upcoming pruning-driven narrow walk (L1 cache lookup of kernel-accessed flat names); the new parameters default to None so this commit alone is the qualname-fallback baseline. ``test_src_ll_cache_arg_warnings`` updated to assert the new ``[UNKNOWN_TYPE]`` warning (instead of the old ``[PARAM_INVALID]`` + ``[INVALID_FUNC]`` dead-end). The ``_qd_stable_members`` flag is no longer read by args_hasher; its launch-context role (``_mutable_nd_cached_val`` short-circuit) is unchanged in this commit and will be addressed separately. --- .../lang/_fast_caching/args_hasher.py | 388 +++++++++++------- .../lang/fast_caching/test_src_ll_cache.py | 14 +- 2 files changed, 253 insertions(+), 149 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index aa8df595eb..1889a1e164 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -11,6 +11,7 @@ from quadrants._tensor_wrapper import Tensor as _TensorWrapper from quadrants.types.annotations import Template +from .._dataclass_util import create_flat_name from .._ndarray import ScalarNdarray from .._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from ..field import ScalarField @@ -41,18 +42,26 @@ _DC_REPR_NONE = object() -# Set by ``stringify_obj_type`` when it encounters a recognised-but-unsupported tensor-like type (``Field`` / -# ``MatrixField``) anywhere in the traversal — including nested under a dataclass or another data_oriented object. -# The ``stable_members=True`` data_oriented walker uses this to differentiate two reasons a child returned ``None``: -# truly-opaque metadata (``RigidSolver._uid: UID``, etc.) which is inert and can be skipped, vs a tensor-like type -# whose value affects kernel codegen and must invalidate fastcache for the whole call. Reset at the top of each -# ``hash_args``; snapshotted/restored around each nested ``stringify_obj_type`` call inside the data_oriented walker. -_hit_recognised_unsupported = False +# Sentinel returned by ``stringify_obj_type`` when a recognised-but-unsupported tensor-like type (``ScalarField`` / +# ``MatrixField``) is encountered anywhere in the traversal. Containers (``dataclass_to_repr``, ``data_oriented`` +# branch, top-level ``hash_args`` loop) must propagate it upward — fastcache cannot safely hash the call because +# fields have shape/dtype that would affect kernel codegen but fastcache doesn't yet know how to include them. +# +# Distinct from any other return value: an unrecognised opaque type now falls back to a deterministic +# ``type(v).__qualname__`` string (see fallback in ``stringify_obj_type``), so the only way ``stringify_obj_type`` +# disables fastcache is by returning this sentinel. +class _FailFastcache: + """Singleton sentinel; identity-compared.""" + _instance = None -def _mark_hit_recognised_unsupported() -> None: - global _hit_recognised_unsupported # pylint: disable=global-statement - _hit_recognised_unsupported = True + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +_FAIL_FASTCACHE = _FailFastcache() class FastcacheSkip(enum.Enum): @@ -62,18 +71,21 @@ class FastcacheSkip(enum.Enum): WARN = "warn" -# Set when the fastcache skip is something callers should warn about (as opposed to a Field arriving through a -# qd.Tensor annotation, which is a normal silent path). Reset at the start of each hash_args call. +# Set when the fastcache skip is something callers should warn about (as opposed to a ``Field`` arriving through a +# ``qd.Tensor`` annotation, which is a normal silent path). Reset at the start of each ``hash_args`` call. _should_warn = False -# Counter set by the data_oriented walker when entering a ``_qd_stable_members`` object. While nonzero, the -# unknown-type branch of ``stringify_obj_type`` returns ``None`` silently instead of logging -# ``[FASTCACHE][PARAM_INVALID]``. ``stable_members=True`` is the user's promise that the class's member set / types -# don't change after construction — under that promise, opaque members like ``RigidSolver._uid`` (a -# ``genesis.utils.uid.UID``) don't affect kernel codegen so they can be skipped silently rather than killing -# fastcache for the whole call. Single-threaded by construction (the hasher only runs during JIT compile). -_skip_unknown_warn_depth = 0 +# Set of ``type(v).__qualname__`` strings we've already emitted the "unknown type, falling back to qualname hash" +# warning for. Lets the loop run thousands of times without spamming the log while still telling the user once +# that fastcache encountered an unrecognised type at a hashed path. Cleared by ``reset_unknown_type_warn_state`` +# (called from ``qd.init``) so each new test sees a clean log. +_warned_unknown_types: set[str] = set() + + +def reset_unknown_type_warn_state() -> None: + """Clear the once-per-process warned-unknown-types set. Called from test setup / ``qd.init``.""" + _warned_unknown_types.clear() def _mark_warn_if_not_tensor_annotation(arg_meta: ArgMetadata | None) -> None: @@ -88,40 +100,130 @@ def _mark_should_warn() -> None: _should_warn = True -def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | None: - # PERF: For frozen dataclasses, the repr never changes. Cache it on the instance to avoid repeated +def _qualname_fallback(obj: object, path: tuple[str, ...]) -> str: + """Deterministic fallback for unrecognised types. + + Returns a string derived from ``type(obj)``'s module + qualname so the cache key is *stable* across calls + (instances of the same opaque class get the same hash contribution). Warn once per unrecognised type so a + new tensor-like type added to Quadrants without being added to the recognised list here gets noticed in the + logs without spamming the per-call hot path. + + Safety note: this captures type identity only, NOT value or type-parameters (e.g. dtype/shape on a hypothetical + ``BFloat16Tensor``). For genuinely opaque metadata (UUID, Pydantic config, back-pointers) the type-identity + hash is correct because the kernel cannot read non-recognised Python types. For new tensor-like types whose + dtype/shape *would* affect codegen, the warning is the signal that someone needs to add them to the recognised + set in this module. + """ + t = type(obj) + qualname = f"{getattr(t, '__module__', '')}.{getattr(t, '__qualname__', t.__name__)}" + if qualname not in _warned_unknown_types: + _warned_unknown_types.add(qualname) + _logging.warn( + f"[FASTCACHE][UNKNOWN_TYPE] Falling back to type-name hash for path {path} type {qualname}. " + f"The cache key captures the type identity but not type parameters (e.g. dtype/shape). If this " + f"type's value affects kernel codegen, add explicit handling to " + f"``quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type``." + ) + return f"opaque-{qualname}" + + +def _child_flat(parent_flat: str | None, child_name: str) -> str | None: + """Compute the flat name a kernel parameter would have if it pointed at this container's child. + + For a top-level arg ``state`` with child ``x``: ``__qd_state__qd_x``. + For a deeper child ``state.dofs.x``: ``__qd_state__qd_dofs__qd_x`` (built incrementally). + + ``parent_flat`` is the *kernel-side* representation of this container's root: + - top-level arg of a kernel: ``arg_meta.name`` (e.g. ``"state"``, ``"self"``) — no ``__qd_`` prefix. + - any nested level: the already-computed ``__qd_…`` flat name. + + Returns ``None`` when ``parent_flat`` itself is ``None``, indicating "no path info available" — the caller + must walk the child unconditionally (i.e. ignore ``pruning_paths`` for this branch). + """ + if parent_flat is None: + return None + return create_flat_name(parent_flat, child_name) + + +def _is_path_used(pruning_paths: set[str] | None, child_flat: str | None) -> bool: + """Return True if a child at ``child_flat`` should be hashed. + + - ``pruning_paths is None``: pre-pruning-info compile — hash everything. + - ``child_flat is None``: caller could not compute a flat-name path (no parent_flat available) — hash + everything as well, so we never accidentally drop a child we couldn't classify. + - both non-None: only hash children whose flat name is in the set. Pruning's prefix-expansion step in + ``Kernel.materialize`` guarantees that if any descendant of ``__qd_a__qd_b`` is used, ``__qd_a__qd_b`` + itself is also in the set, so this single membership check is sufficient to decide whether to descend. + """ + if pruning_paths is None or child_flat is None: + return True + return child_flat in pruning_paths + + +def dataclass_to_repr( + raise_on_templated_floats: bool, + path: tuple[str, ...], + arg: Any, + pruning_paths: set[str] | None = None, + parent_flat: str | None = None, +) -> str | _FailFastcache: + """Hash a dataclass instance, optionally narrowed by pruning information. + + Returns ``_FAIL_FASTCACHE`` if any field's subtree hits a recognised-but-unsupported tensor type + (``ScalarField`` / ``MatrixField``); otherwise a string. + + Pruning: if ``pruning_paths`` is non-None, only descend into fields whose flat name is in the set. Pruning's + prefix-expansion step ensures the set already contains all ancestors of used leaves, so checking the + immediate child's flat name is sufficient. + """ + # PERF: For frozen dataclasses the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. - # A cached ``None`` is stored as the sentinel ``_DC_REPR_NONE`` to distinguish "not yet computed" from - # "computed but not fast-cacheable". + # A cached ``_DC_REPR_NONE`` sentinel distinguishes "computed but not fast-cacheable" from "not yet computed". + # + # The cache is keyed by ``(is_frozen, pruning_paths is None)`` because a frozen dataclass's pruned repr + # depends on the pruning_paths set — we use separate cache slots for pruned vs unpruned to avoid serving + # the wrong narrowing. + cache_attr = "_qd_dc_repr" if pruning_paths is None else "_qd_dc_repr_narrow" is_frozen = type(arg).__hash__ is not None if is_frozen: - cached = getattr(arg, "_qd_dc_repr", None) + cached = getattr(arg, cache_attr, None) if cached is _DC_REPR_NONE: - return None - if cached is not None: + return _FAIL_FASTCACHE + if cached is not None and pruning_paths is None: + # Narrow cache may be stale if pruning_paths set changed; only reuse the unpruned cache. return cached repr_l = [] for field in dataclasses.fields(arg): + child_flat = _child_flat(parent_flat, field.name) + if not _is_path_used(pruning_paths, child_flat): + continue child_value = getattr(arg, field.name) - _repr = stringify_obj_type(raise_on_templated_floats, path + (field.name,), child_value, arg_meta=None) - if _repr is None: + _repr = stringify_obj_type( + raise_on_templated_floats, + path + (field.name,), + child_value, + arg_meta=None, + pruning_paths=pruning_paths, + parent_flat=child_flat, + ) + if _repr is _FAIL_FASTCACHE: if isinstance(child_value, _FIELD_TYPES) and field.type is not _TensorWrapper: _mark_should_warn() if is_frozen: try: - object.__setattr__(arg, "_qd_dc_repr", _DC_REPR_NONE) + object.__setattr__(arg, cache_attr, _DC_REPR_NONE) except AttributeError: pass - return None + return _FAIL_FASTCACHE full_repr = f"{field.name}: ({_repr})" if field.metadata.get(FIELD_METADATA_CACHE_VALUE, False): full_repr += f" = {child_value}" repr_l.append(full_repr) result = "[" + ",".join(repr_l) + "]" - if is_frozen: + if is_frozen and pruning_paths is None: try: - object.__setattr__(arg, "_qd_dc_repr", result) + object.__setattr__(arg, cache_attr, result) except AttributeError: pass return result @@ -135,36 +237,49 @@ def _is_template(arg_meta: ArgMetadata | None) -> bool: def stringify_obj_type( - raise_on_templated_floats: bool, path: tuple[str, ...], obj: object, arg_meta: ArgMetadata | None -) -> str | None: - """ - Convert an object into a string representation that only depends on its type. - - String should somehow represent the type of obj. Doesnt have to be hashed, nor does it have - to be the actual python type string, just a string that is representative of the type, and won't collide - with different (allowed) types. String should be non-empty. - - Note that fields are not included in fast cache. - - arg_meta should only be non-None for the top level arguments and for data oriented objects. It is - used currently to determine whether a value is added to the cache key, as well as the name. eg - - at the top level, primitive types have their values added to the cache key if their annotation is qd.Template, - since they are baked into the kernel - - in data oriented objects, the values of all primitive types are added to the cache key, since they are baked - into the kernel, and require a kernel recompilation, when they change + raise_on_templated_floats: bool, + path: tuple[str, ...], + obj: object, + arg_meta: ArgMetadata | None, + pruning_paths: set[str] | None = None, + parent_flat: str | None = None, +) -> str | _FailFastcache: + """Convert ``obj`` into a deterministic string that contributes to the fastcache key. + + Return contract: + - ``str``: hashable; the returned string contributes to the cache key. + - ``_FAIL_FASTCACHE``: a recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``) + was encountered. Containers must propagate this upward; fastcache is disabled for the whole call. + + For *every other* unrecognised type, this function falls back to a deterministic + ``type(obj).__qualname__``-based string (see ``_qualname_fallback``). The pre-refactor design returned + ``None`` and disabled fastcache for any unrecognised member type, which made adding a UUID or Pydantic + config object to a ``@qd.data_oriented`` ``self`` silently kill fastcache. The qualname fallback captures + type identity (sufficient for genuinely opaque metadata — kernels cannot read non-recognised Python types + so opaque metadata cannot affect codegen) and warns once per unrecognised type so any future tensor-like + addition that *does* need explicit handling gets noticed. + + Parameters: + - ``arg_meta``: non-``None`` only for top-level kernel args and for ``@qd.data_oriented`` members. + Determines whether primitive values are baked into the cache key (template-position primitives and + all primitive members of data-oriented containers). + - ``pruning_paths``: optional set of kernel-accessed flat names. When provided, ``dataclass_to_repr`` and + the ``data_oriented`` branch below descend only into children whose flat name is in the set. Skipped + children are *guaranteed* not to affect kernel codegen (the kernel never reads them), so omitting them + from the hash is safe by construction. + - ``parent_flat``: the flat-name prefix for ``obj``'s children (e.g. ``__qd_self`` if ``obj`` is the + ``self`` arg of a data_oriented kernel). Used together with ``pruning_paths`` to compute each child's + flat name for the narrow-walk lookup. """ - # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` strips - # wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / data-oriented - # walkers below (``dataclass_to_repr`` and the ``is_data_oriented`` branch) do raw ``getattr`` to fetch struct - # fields, so a wrapper stored as a struct field arrives here un-stripped. Without this branch the hasher falls - # through to the ``[FASTCACHE][PARAM_INVALID]`` warning and disables the fast path for the whole call. See - # ``perso_hugh/doc/quadrants-tensor.md`` §8.14. - # ``qd.Tensor`` wrappers: unwrap to the bare impl so the type checks below match. After unwrap, ``_qd_layout`` (if - # any) is on the impl. + # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` + # strips wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / + # data-oriented walkers below do raw ``getattr`` to fetch struct fields, so a wrapper stored as a struct field + # arrives here un-stripped. Without this branch the hasher would hash the wrapper as an unknown type instead + # of unwrapping to the recognised impl. See ``perso_hugh/doc/quadrants-tensor.md`` §8.14. # - # PERF-CRITICAL: The _any_tensor_constructed guard makes this check zero-cost when no qd.Tensor has been created. - # ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer comparison (~10 - # ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. + # PERF-CRITICAL: the ``_any_tensor_constructed`` guard makes this check zero-cost when no ``qd.Tensor`` has + # been created. ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a + # pointer comparison (~10 ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. if ( _tensor_wrapper._any_tensor_constructed and type(obj) in _TENSOR_WRAPPER_TYPES ): # pyright: ignore[reportOptionalMemberAccess] @@ -177,12 +292,11 @@ def stringify_obj_type( if isinstance(obj, VectorNdarray): return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): - # disabled for now, because we need to think about how to handle field offset - # etc + # Recognised-but-unsupported: shape/dtype affect kernel codegen but fastcache doesn't yet hash them. + # Disable fastcache for the whole call. # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - _mark_hit_recognised_unsupported() - return None + return _FAIL_FASTCACHE if isinstance(obj, MatrixNdarray): return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, torch_type): @@ -190,78 +304,47 @@ def stringify_obj_type( if isinstance(obj, np.ndarray): return f"[np-{obj.dtype}-{obj.ndim}]" if isinstance(obj, MatrixField): - # disabled for now, because we need to think about how to handle field offset - # etc + # Recognised-but-unsupported, same as ScalarField above. # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - _mark_hit_recognised_unsupported() - return None + return _FAIL_FASTCACHE if is_dataclass_instance(obj): - return dataclass_to_repr(raise_on_templated_floats, path, obj) + return dataclass_to_repr( + raise_on_templated_floats, path, obj, pruning_paths=pruning_paths, parent_flat=parent_flat + ) if is_data_oriented(obj): - # ``@qd.data_oriented(stable_members=True)``: the class promises its member *set* and *types* don't change - # after construction. Under that contract, unrecognised member types (e.g. Genesis's ``RigidSolver._uid`` of - # type ``genesis.utils.uid.UID``, or any other opaque metadata) are treated as inert from fastcache's - # perspective: they don't affect kernel codegen so they can be skipped silently rather than killing fastcache - # for the whole call. Without this, migrating a kernel from a standalone ``@qd.kernel`` function to a method - # on a ``@qd.data_oriented`` class disables fastcache the moment the class holds any opaque metadata, even - # though the kernel's compiled output would be identical. - stable_members = bool(type(obj).__dict__.get("_qd_stable_members")) + # Walk the data_oriented container's members. Same narrow-walk semantics as ``dataclass_to_repr``: + # if ``pruning_paths`` is provided, only descend into children whose flat name is in the set; otherwise + # walk every attr. Recognised-but-unsupported (Field/MatrixField) anywhere in a child's subtree + # propagates ``_FAIL_FASTCACHE`` upward. child_repr_l = ["da"] - _dict = {} try: - # pyright is ok with this approach _asdict = getattr(obj, "_asdict") _dict = _asdict() except AttributeError: _dict = obj.__dict__ - global _skip_unknown_warn_depth # pylint: disable=global-statement - if stable_members: - _skip_unknown_warn_depth += 1 - try: - for k, v in _dict.items(): - # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` - # stashes the per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so - # that subsequent ``instance.method`` lookups skip the descriptor allocation; - # those entries are not data and must not invalidate the fastcache key. - v_type = type(v) - if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: - continue - # Snapshot the recognised-but-unsupported flag around the recursive call so we can tell whether - # *this child's* subtree hit a ``Field`` / ``MatrixField`` (in which case we must fail fastcache - # even under ``stable_members``). - global _hit_recognised_unsupported # pylint: disable=global-statement - _hit_recognised_unsupported = False - _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) - child_hit_field = _hit_recognised_unsupported - if _child_repr is None: - # Differentiate two reasons ``stringify_obj_type`` returns None: - # - # (a) RECOGNISED-BUT-UNSUPPORTED: ``Field`` / ``MatrixField`` somewhere in this child's - # subtree. These are *known* tensor-like types whose values affect kernel codegen but - # which fastcache doesn't yet handle. Killing fastcache for the whole call is the - # intended contract — ``test_num_envs[False-...]`` pins this behaviour for the field - # backend. - # (b) TRULY-OPAQUE: anything that falls through to the ``[FASTCACHE][PARAM_INVALID]`` - # warning at the bottom of ``stringify_obj_type`` (``RigidSolver._uid`` of type - # ``UID``, etc.). For ``stable_members=True`` containers, opaque metadata is inert by - # the user's contract and can be skipped without invalidating the hash for the rest - # of the members. - if stable_members and not child_hit_field: - continue - if _should_warn: - _logging.warn( - f"A kernel that has been marked as eligible for fast cache was passed 1 or more " - f"parameters that are not, in fact, eligible for fast cache: one of the parameters was a " - f"@qd.data_oriented object, and one of its children was not eligible. The data oriented " - f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " - f"information, the path of the value was {path}." - ) - return None - child_repr_l.append(f"{k}: {_child_repr}") - finally: - if stable_members: - _skip_unknown_warn_depth -= 1 + for k, v in _dict.items(): + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the + # per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so subsequent ``instance.method`` + # lookups skip the descriptor allocation; those entries are not data and must not invalidate the + # fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue + child_flat = _child_flat(parent_flat, k) + if not _is_path_used(pruning_paths, child_flat): + continue + _child_repr = stringify_obj_type( + raise_on_templated_floats, + (*path, k), + v, + ArgMetadata(Template, ""), + pruning_paths=pruning_paths, + parent_flat=child_flat, + ) + if _child_repr is _FAIL_FASTCACHE: + return _FAIL_FASTCACHE + child_repr_l.append(f"{k}: {_child_repr}") return ", ".join(child_repr_l) if issubclass(arg_type, (numbers.Number, np.number)): if _is_template(arg_meta): @@ -278,28 +361,31 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" - if _skip_unknown_warn_depth > 0: - # Inside a ``stable_members=True`` data_oriented walk: opaque members are tolerated by contract, so don't log - # the per-member ``[FASTCACHE][PARAM_INVALID]`` warning. The data_oriented walker reads the returned ``None`` - # and skips this member. - return None - _mark_should_warn() - # The bit in caps should not be modified without updating corresponding test - # The rest of free text can be freely modified - # (will probably formalize this in more general doc / contributor guidelines at some point) - _logging.warn( - f"[FASTCACHE][PARAM_INVALID] Parameter with path {path} and type {arg_type} not allowed by fast cache." - ) - return None + # Unrecognised type — fall back to deterministic qualname-based hash and warn once. See ``_qualname_fallback`` + # for the safety reasoning. + return _qualname_fallback(obj, path) def hash_args( - raise_on_templated_floats: bool, args: Sequence[Any], arg_metas: Sequence[ArgMetadata | None] + raise_on_templated_floats: bool, + args: Sequence[Any], + arg_metas: Sequence[ArgMetadata | None], + pruning_paths: set[str] | None = None, ) -> str | FastcacheSkip: - """Return the args hash string, or a HashFailure explaining why hashing failed.""" - global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn, _hit_recognised_unsupported # pylint: disable=line-too-long + """Return the args hash string, or a ``FastcacheSkip`` explaining why hashing failed. + + Parameters: + - ``pruning_paths``: optional set of kernel-accessed flat names from the L1 cache (or freshly populated + after a cold compile). When provided, the container walkers skip children whose flat name is not in + the set; this both narrows the cache key (so unrelated metadata changes don't cause cache misses) and + eliminates the brittleness of walking opaque-typed members blindly. + + Fastcache is disabled (``FastcacheSkip`` returned) only when a recognised-but-unsupported tensor-like type + (``ScalarField`` / ``MatrixField``) is encountered. Truly-unrecognised types use a ``type(v).__qualname__`` + fallback so the cache key stays stable. + """ + global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn # pylint: disable=global-statement _should_warn = False - _hit_recognised_unsupported = False g_num_calls += 1 g_num_args += len(args) hash_l = [] @@ -309,11 +395,23 @@ def hash_args( ) for i_arg, arg in enumerate(args): start = time.time() - _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg]) + arg_meta = arg_metas[i_arg] + # Top-level arg flat name: matches the kernel-side ``arg_meta.name`` (no ``__qd_`` prefix at the root). + # Used by the narrow walk to construct child flat names compatible with ``pruning.used_vars_by_func_id``. + top_flat = arg_meta.name if arg_meta is not None else None + _hash = stringify_obj_type( + raise_on_templated_floats, + (str(i_arg),), + arg, + arg_meta, + pruning_paths=pruning_paths, + parent_flat=top_flat, + ) g_repr_time += time.time() - start - if not _hash: + if _hash is _FAIL_FASTCACHE: g_num_ignored_calls += 1 return FastcacheSkip.WARN if _should_warn else FastcacheSkip.FIELD_VIA_TENSOR + # All other return values are valid strings (qualname fallback handles unrecognised types). hash_l.append(_hash) start = time.time() res = hash_iterable_strings(hash_l) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 711839cf5d..bc49b18a4e 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -142,10 +142,14 @@ def k1(foo: qd.Template) -> None: k1(foo=RandomClass()) _out, err = capfd.readouterr() - assert "[FASTCACHE][PARAM_INVALID]" in err + # Unrecognised types now fall back to a deterministic ``type(v).__qualname__`` hash (instead of silently + # disabling fastcache via the old ``[PARAM_INVALID]`` / ``[INVALID_FUNC]`` dead-end), and emit an + # ``[UNKNOWN_TYPE]`` warning once per type so a new tensor-like type added to Quadrants without explicit + # args-hasher handling still gets noticed in the logs. ``[PARAM_INVALID]`` is gone. + assert "[FASTCACHE][UNKNOWN_TYPE]" in err assert RandomClass.__name__ in err - assert "[FASTCACHE][INVALID_FUNC]" in err - assert k1.__name__ in err + assert "[FASTCACHE][PARAM_INVALID]" not in err + assert "[FASTCACHE][INVALID_FUNC]" not in err @qd.kernel def not_pure_k1(foo: qd.Template) -> None: @@ -153,8 +157,10 @@ def not_pure_k1(foo: qd.Template) -> None: not_pure_k1(foo=RandomClass()) _out, err = capfd.readouterr() + # Without ``@qd.pure``, fastcache is not active at all — neither the new UNKNOWN_TYPE nor the old + # PARAM_INVALID / INVALID_FUNC warnings should fire. + assert "[FASTCACHE][UNKNOWN_TYPE]" not in err assert "[FASTCACHE][PARAM_INVALID]" not in err - assert RandomClass.__name__ not in err assert "[FASTCACHE][INVALID_FUNC]" not in err assert k1.__name__ not in err From dce130562a71dd450cc7d858bd0bf315c48dd552 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 03:31:23 -0700 Subject: [PATCH 19/46] [Refactor] Fastcache: two-level cache + pruning-driven narrow args walk MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the pre-refactor single-level cache (one key derived from source + config + a *wide* args walk) with a two-level pruning-driven scheme: - L1 key (``src_hasher.make_source_config_key``): source + config + version, no args dependence. Stores ``PruningInfo`` — the set of kernel-accessed flat names produced during compile (``Pruning``'s ``used_vars_by_func_id[KERNEL_FUNC_ID]``, folded with data_oriented ndarray attribute chains from ``struct_ndarray_launch_info``). Also persists ``graph_do_while_arg`` (source-deterministic). - L2 key (``src_hasher.make_full_cache_key``): L1 key + ``narrow_args_hash``. The narrow hash walks only paths in the L1 pruning set, so unrelated metadata changes on the same kernel-accessed surface no longer invalidate the cache. Lookup flow (warm call): L1 lookup → narrow args walk using L1 pruning info → L2 lookup → load artifact. Cold compile: L1 miss → full compile (pass 0 + pass 1) → store L1 → compute narrow args hash → store L2. Crucially, "L1 hit but L2 miss" still triggers full pass 0+1 (not just pass 1): pass 0 is what populates per-callee-func pruning info, and L1 only stores the kernel-level set, so skipping pass 0 is only safe when the C++ artifact is already loaded (``only_parse_function_def=True``). Pruning narrowing rules in ``args_hasher.stringify_obj_type``: - Dataclass children: flat-name pruning is *complete* (every dataclass field is flattened by ``FlattenAttributeNameTransformer``), so narrow walking by ``child_flat in pruning_paths`` is safe. - Data_oriented children: pruning is only complete for ndarray members (via ``struct_ndarray_launch_info``). Primitive members (template-position values baked into the kernel) are NOT tracked by pruning. To stay correct, the data_oriented branch only narrows *ndarray* children; non-ndarray children are always walked (the recursive call still narrows nested dataclasses). This is why ``test_template_raise_on_data_oriented_floats`` and the dtype-distinct cache-key test both still pass: primitives keep contributing to the hash, only kernel-unused ndarrays get pruned. Behavior change: ``test_src_ll_cache_arg_warnings`` and ``test_fastcache_field_warnings_warn_struct_template_field`` updated to reflect that fastcache no longer fires ``[PARAM_INVALID]`` or ``[INVALID_FUNC]`` for unrecognised types at the *top level* (qualname fallback from previous commit handles them) or for Field-bearing *unused* dataclass members (narrowing skips them). Tests now exercise the genuinely-live cases. ``test_src_hasher_*`` updated to use the new ``make_source_config_key`` / ``make_full_cache_key`` API. stable_members is no longer read by the args hasher (handled by previous commit); its launch-context role in ``Kernel.launch_kernel`` still uses the legacy flag and will be addressed in a follow-up. --- .../lang/_fast_caching/args_hasher.py | 29 ++- .../lang/_fast_caching/src_hasher.py | 207 ++++++++++++++---- python/quadrants/lang/kernel.py | 173 +++++++++++++-- .../test_fastcache_field_warnings.py | 12 +- .../lang/fast_caching/test_src_hasher.py | 23 +- 5 files changed, 375 insertions(+), 69 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 1889a1e164..14b854c967 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -313,10 +313,22 @@ def stringify_obj_type( raise_on_templated_floats, path, obj, pruning_paths=pruning_paths, parent_flat=parent_flat ) if is_data_oriented(obj): - # Walk the data_oriented container's members. Same narrow-walk semantics as ``dataclass_to_repr``: - # if ``pruning_paths`` is provided, only descend into children whose flat name is in the set; otherwise - # walk every attr. Recognised-but-unsupported (Field/MatrixField) anywhere in a child's subtree - # propagates ``_FAIL_FASTCACHE`` upward. + # Walk the data_oriented container's members. Narrowing rules differ from ``dataclass_to_repr``: + # + # Pruning info for data_oriented containers is *only complete for ndarray members*: the kernel-compile + # path records each kernel-accessed ndarray's structural attribute chain in + # ``struct_ndarray_launch_info``, which ``Kernel._fold_struct_nd_paths_into_pruning`` folds into the + # flat-name pruning set. Non-ndarray attribute accesses on data_oriented args (``self.an_int``, + # ``self.a_float`` — values that get baked into the kernel at compile time) are *not* tracked anywhere + # as pruning input (data_oriented args aren't run through ``FlattenAttributeNameTransformer``). + # + # If we naively applied flat-name pruning to *every* child, an unused-but-present opaque member would + # match (silently dropped → safe), a kernel-read primitive member would silently disappear from the hash + # (BAD — its value affects codegen and we'd serve a stale cached compile when the value changes), and + # the templated-float raise-guard would also stop firing. + # + # Conservative fix: only narrow *ndarray* children. For everything else, walk unconditionally. The + # recursive call still applies narrowing to nested dataclasses (where flat-name tracking IS complete). child_repr_l = ["da"] try: _asdict = getattr(obj, "_asdict") @@ -332,7 +344,14 @@ def stringify_obj_type( if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: continue child_flat = _child_flat(parent_flat, k) - if not _is_path_used(pruning_paths, child_flat): + # ndarray-only pruning narrowing — see the comment at the top of this branch for why other types + # cannot be safely narrowed here. + if ( + pruning_paths is not None + and child_flat is not None + and child_flat not in pruning_paths + and isinstance(v, (ScalarNdarray, VectorNdarray, MatrixNdarray)) + ): continue _child_repr = stringify_obj_type( raise_on_templated_floats, diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 1c03bf737b..a14a1a7d1b 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -1,3 +1,47 @@ +"""Two-level fastcache key derivation and persistence. + +Background (pre-refactor) +------------------------- +Fastcache used a single cache key derived from source + config + a *wide* args hash that walked every member +of every container argument. That walk was brittle: + + - Encountering any unrecognised type silently disabled fastcache (``[FASTCACHE][PARAM_INVALID]`` warning + + ``None`` return); a single Genesis ``RigidSolver._uid`` member killed the cache for the whole solver. + + - Working around it via ``@qd.data_oriented(stable_members=True)`` opt-in only swapped one brittleness for + another: a new tensor-like type added later but missed in args_hasher's recognised set would be silently + skipped, serving stale cached results. + +Both fundamentally stem from the wide walk *blindly* visiting paths the kernel never reads. The pre-refactor +design had no way to know which paths actually mattered before compile. + +Two-level cache +--------------- +The fastcache now exposes pruning information (already produced during compile) as a first-class lookup so +the args hash can walk *only* paths the kernel reads: + + - L1 (this module's ``make_source_config_key`` + ``load_pruning_info`` / ``store_pruning_info``): + keyed by source+config only (no args). Stores ``PruningInfo`` — the set of kernel-accessed flat names + (e.g. ``__qd_state__qd_x``) plus the ``graph_do_while_arg`` (also a kernel-source property). + + - L2 (``make_full_cache_key`` + ``load_full`` / ``store_full``): keyed by L1 key + the *narrow* args hash + computed with pruning info from L1. Stores the C++ ``frontend_cache_key`` that names the compiled + artifact. + +Lookup flow on a warm call: L1 lookup → narrow args hash (paths from L1) → L2 lookup → load artifact. + +Cold compile flow: L1 miss → cold compile (pass 0 + pass 1) → store L1 → compute narrow args hash → store +L2. + +Safety implication +------------------ +A kernel-unused path's contents (any type, including unrecognised tensor-likes) is *guaranteed* not to affect +kernel codegen, so dropping it from the hash is correct by construction. Paths the kernel *does* read still go +through ``args_hasher.stringify_obj_type`` which falls back to a ``type(v).__qualname__``-based string for +unrecognised types and emits a one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning, so a missed type registration +is impossible to miss but doesn't silently disable fastcache. +""" + import json import os import warnings @@ -18,47 +62,136 @@ from .python_side_cache import PythonSideCache -def create_cache_key( +# Prefix bytes mixed into L1 / L2 keys so they cannot collide even if the underlying inputs happen to hash to +# the same string. The original single-level cache key (kept for backward-compat reads via ``load`` below) had +# no such prefix; the new two-level scheme uses ``l1:`` and ``l2:`` markers so old single-level entries from +# prior Quadrants installs are simply ignored rather than mis-served. +_L1_MARKER = "l1" +_L2_MARKER = "l2" + + +def make_source_config_key(kernel_source_info: FunctionSourceInfo) -> str: + """Build the L1 cache key: source + config + version, with no dependence on args. + + Used by ``_try_load_fastcache`` before any args walking. The same key drives ``load_pruning_info`` / + ``store_pruning_info``; the matching ``make_full_cache_key`` derives the L2 key from this plus the narrow + args hash. + """ + kernel_hash = function_hasher.hash_kernel(kernel_source_info) + config_hash = config_hasher.hash_compile_config() + return hash_iterable_strings( + ( + _L1_MARKER, + quadrants.__version_str__, + kernel_hash, + config_hash, + kernel_source_info.filepath, + str(kernel_source_info.start_lineno), + "pruned", + "kcov" if os.environ.get("QD_KERNEL_COVERAGE") == "1" else "", + ) + ) + + +def make_full_cache_key(source_config_key: str, narrow_args_hash: str) -> str: + """Build the L2 cache key from the L1 key + narrow args hash. See module docstring.""" + return hash_iterable_strings((_L2_MARKER, source_config_key, narrow_args_hash)) + + +def compute_narrow_args_hash( raise_on_templated_floats: bool, kernel_source_info: FunctionSourceInfo, args: Sequence[Any], arg_metas: Sequence[ArgMetadata], + pruning_paths: set[str] | None, ) -> str | None: + """Compute the args hash narrowed by ``pruning_paths`` (or wide if ``pruning_paths is None``). + + Returns ``None`` if a recognised-but-unsupported tensor-like type forces fastcache off — the caller emits + the appropriate user-visible diagnostic via the ``FastcacheSkip.WARN`` branch. """ - cache key takes into account: - - arg types - - cache value arg values - - kernel function (but not sub functions) - - compilation config (which includes arch, and debug) - """ - args_hash = args_hasher.hash_args(raise_on_templated_floats, args, arg_metas) + args_hash = args_hasher.hash_args( + raise_on_templated_floats, args, arg_metas, pruning_paths=pruning_paths + ) if isinstance(args_hash, FastcacheSkip): if args_hash is FastcacheSkip.WARN: - # the bit in caps at start should not be modified without modifying corresponding text - # freetext bit can be freely modified _logging.warn( f"[FASTCACHE][INVALID_FUNC] The pure function {kernel_source_info.function_name} could not be " "fast cached, because one or more parameter types were invalid" ) return None - kernel_hash = function_hasher.hash_kernel(kernel_source_info) - config_hash = config_hasher.hash_compile_config() - cache_key = hash_iterable_strings( - ( - quadrants.__version_str__, - kernel_hash, - args_hash, - config_hash, - kernel_source_info.filepath, - str(kernel_source_info.start_lineno), - "pruned", - "kcov" if os.environ.get("QD_KERNEL_COVERAGE") == "1" else "", - ) + return args_hash + + +class L1CacheValue(BaseModel): + """Persisted L1 entry — pruning info that's source-and-config-deterministic (not args-dependent). + + Pruning info is the set of *flat names* (``__qd___qd___qd_…``) that the kernel actually reads. + Computed during compile (``Pruning.used_vars_by_func_id``); persisted here so subsequent calls can build + a narrow args hash without having to recompile. + + ``graph_do_while_arg`` is also stored here because it's a property of the kernel source (not of any + particular arg value). + + ``hashed_function_source_infos`` is the same content-hash list used for L2 validation; an L1 hit is + rejected if any helper source has changed since the L1 entry was written, even if the kernel source + itself hasn't (kernel_hash only covers the entry point). + """ + + used_py_dataclass_parameters: set[str] + hashed_function_source_infos: list[HashedFunctionSourceInfo] + graph_do_while_arg: str | None = None + + +def store_pruning_info( + source_config_key: str, + function_source_infos: Iterable[FunctionSourceInfo], + used_py_dataclass_parameters: set[str], + graph_do_while_arg: str | None = None, +) -> None: + """Persist the L1 entry after a cold compile. See ``L1CacheValue`` for what's stored / why.""" + if not source_config_key: + return + cache = PythonSideCache() + hashed_function_source_infos = function_hasher.hash_functions(function_source_infos) + cache_value = L1CacheValue( + used_py_dataclass_parameters=used_py_dataclass_parameters, + hashed_function_source_infos=list(hashed_function_source_infos), + graph_do_while_arg=graph_do_while_arg, ) - return cache_key + cache.store(source_config_key, cache_value.model_dump_json()) + + +def load_pruning_info( + source_config_key: str, +) -> tuple[set[str], str | None] | tuple[None, None]: + """Look up L1 cache. Returns (pruning_paths, graph_do_while_arg) on hit, (None, None) on miss / invalid. + + Validates ``hashed_function_source_infos`` against the current on-disk source; if any helper has changed + since the entry was written, the entry is invalid and we treat the lookup as a miss so the caller does a + cold compile (which will overwrite the stale L1 entry). + """ + cache = PythonSideCache() + maybe_value_json = cache.try_load(source_config_key) + if maybe_value_json is None: + return None, None + try: + cache_value = L1CacheValue.model_validate_json(maybe_value_json) + except (pydantic.ValidationError, json.JSONDecodeError, UnicodeDecodeError) as e: + warnings.warn(f"Failed to parse L1 cache entry: {e}") + return None, None + if not function_hasher.validate_hashed_function_infos(cache_value.hashed_function_source_infos): + return None, None + return cache_value.used_py_dataclass_parameters, cache_value.graph_do_while_arg class CacheValue(BaseModel): + """Persisted L2 entry — frontend cache key for the compiled artifact + source-validation metadata. + + The full pruning info is duplicated here for backward-compat with existing on-disk caches; it's the same + set that L1 also stores. The L1 set is the source of truth for narrowing the args hash on warm calls. + """ + frontend_cache_key: str hashed_function_source_infos: list[HashedFunctionSourceInfo] used_py_dataclass_parameters: set[str] @@ -72,22 +205,10 @@ def store( used_py_dataclass_parameters: set[str], graph_do_while_arg: str | None = None, ) -> None: - """ - Note that unlike other caches, this cache is not going to store the actual value we want. - This cache is only used for verification that our cache key is valid. Big picture: - - we have a cache key, based on args and top level kernel function - - we want to use this to look up LLVM IR, in C++ side cache - - however, before doing that, we first want to validate that the source code didn't change - - i.e. is our cache key still valid? - - the python side cache contains information we will use to verify that our cache key is valid - - ie the list of function source infos - - Update! We are now going to store parameter pruning infomation, which is: - - used_py_dataclass_parameters: set[str] - - Update 2: we are going to store the cache key used by the c++ kernel cache, so that we can use that - to retrieve the immutable cached c++ kernel later, rather than, before, we were storing the c++ - cached kernel using the fast cache key, leading to bugs, when cached kernel file then had to be mutable. + """Persist the L2 entry — the C++ frontend cache key that names the compiled artifact for this call. + + ``fast_cache_key`` is the L2 key from ``make_full_cache_key``. The L1 entry has typically been stored + earlier by ``store_pruning_info`` during the same materialize. """ if not fast_cache_key: return @@ -117,9 +238,9 @@ def _try_load(cache_key: str) -> CacheValue | None: def load(cache_key: str) -> tuple[set[str], str, str | None] | tuple[None, None, None]: - """ - loads function source infos from cache, if available - checks the hashes against the current source code + """Look up L2 cache. Returns (used_pruning_paths, frontend_cache_key, graph_do_while_arg) on hit. + + Validates helper-source hashes against the live source; an L2 entry is invalidated if any helper changed. """ cache_value = _try_load(cache_key) if cache_value is None: diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index e5865da677..92b2c311df 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -60,6 +60,7 @@ def _kernel_coverage_enabled() -> bool: from quadrants.types.enums import AutodiffMode from quadrants.types.utils import is_signed +from ._dataclass_util import create_flat_name from ._func_base import FuncBase from ._kernel_types import ( ArgsHash, @@ -340,23 +341,60 @@ def reset(self) -> None: self.fe_ll_cache_observations = FeLlCacheObservations() def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType") -> set[str] | None: - frontend_cache_key: str | None = None + """Two-phase fastcache lookup. + + Phase 1 — L1 lookup keyed by source+config only (no args). Returns the set of kernel-accessed flat + names (pruning info). Hit OR miss, this only determines whether we have pruning info for the narrow + args walk; it never on its own justifies skipping pass 0 — that requires the C++ artifact to load. + + Phase 2 — narrow args walk + L2 lookup + artifact load. Only when *all three* succeed do we return + non-None and let ``materialize`` skip pass 0. The reason: pass 0 is what populates pruning info for + *every called ``@qd.func``* (not just the kernel itself). Skipping pass 0 is only safe when pass 1 + runs in ``only_parse_function_def`` mode (i.e. the C++ artifact is already loaded so the AST walker + never enters any callee body); otherwise callee variables can't be found in their func's empty + ``used_vars_by_func_id`` set and the build fails with "Name __qd_… is not defined". + + Side effects: populates ``self._l1_key`` (always when fastcache is active), ``self._pruning_paths_from_l1`` + (the L1 pruning info, or None if L1 miss — used by ``materialize`` for L1-store skipping and for + post-compile narrow-hash construction), and ``self.fast_checksum`` (the L2 key, when phase 2 computed + the narrow args hash). All three are read by the post-compile path in ``_maybe_persist_l1_and_set_l2_key``. + """ + self._l1_key = None # type: ignore[attr-defined] + self._pruning_paths_from_l1 = None # type: ignore[attr-defined] + self.fast_checksum = None if self.runtime.src_ll_cache and self.quadrants_callable and self.quadrants_callable.is_pure: kernel_source_info, _src = get_source_info_and_src(self.func) - self.fast_checksum = src_hasher.create_cache_key( - self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas + self._kernel_source_info_cached = kernel_source_info # reused by materialize / launch_kernel + self._l1_key = src_hasher.make_source_config_key(kernel_source_info) + + # Phase 1: L1 lookup — pruning info only, no args walk yet. + pruning_paths, cached_graph_do_while_arg = src_hasher.load_pruning_info(self._l1_key) + if pruning_paths is None: + # Cold L1. ``materialize`` will compile pass 0 + pass 1 to populate pruning info, then we + # store L1 + L2 after compile. ``cache_key_generated`` is intentionally NOT flipped to True + # here: it tracks "fastcache produced a valid L2 args hash" (the pre-refactor semantic), and + # we don't know yet whether the narrow args walk will succeed. + return None + self._pruning_paths_from_l1 = pruning_paths + + # Phase 2: narrow args hash + L2 lookup. + narrow_args_hash = src_hasher.compute_narrow_args_hash( + self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas, pruning_paths + ) + if narrow_args_hash is None: + # Recognised-but-unsupported tensor-like (Field / MatrixField) — fastcache off for this call. + # ``self.fast_checksum`` stays None so no L2 entry is written; ``cache_key_generated`` stays + # False to match the pre-refactor "Field disables fastcache key generation" contract. + return None + self.fast_checksum = src_hasher.make_full_cache_key(self._l1_key, narrow_args_hash) + self.src_ll_cache_observations.cache_key_generated = True + + used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_arg_l2 = src_hasher.load( + self.fast_checksum ) - used_py_dataclass_parameters = None - cached_graph_do_while_arg: str | None = None - if self.fast_checksum: - self.src_ll_cache_observations.cache_key_generated = True - used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_arg = src_hasher.load( # type: ignore[reportAssignmentType] - self.fast_checksum - ) if used_py_dataclass_parameters is not None and frontend_cache_key is not None: self.src_ll_cache_observations.cache_validated = True prog = impl.get_runtime().prog - assert self.fast_checksum is not None self.compiled_kernel_data_by_key[key] = prog.load_fast_cache( frontend_cache_key, self.func.__name__, @@ -366,9 +404,13 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType if self.compiled_kernel_data_by_key[key]: self.src_ll_cache_observations.cache_loaded = True self.used_py_dataclass_parameters_by_key_enforcing[key] = used_py_dataclass_parameters - if cached_graph_do_while_arg is not None: - self.graph_do_while_arg = cached_graph_do_while_arg + self.graph_do_while_arg = cached_graph_do_while_arg_l2 or cached_graph_do_while_arg return used_py_dataclass_parameters + # L2 miss or artifact load failed: report cold so ``materialize`` does pass 0 + pass 1 (needed + # to populate per-callee pruning info). ``self.fast_checksum`` is still set so the post-compile + # ``src_hasher.store`` will write a fresh L2 entry under the narrow-args key. + self.graph_do_while_arg = cached_graph_do_while_arg or self.graph_do_while_arg + return None elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: # The bit in caps should not be modified without updating corresponding test @@ -381,7 +423,6 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, ...], arg_features=None): if key is None: key = (self.func, 0, self.autodiff_mode) - self.fast_checksum = None if key in self.materialized_kernels: return @@ -438,6 +479,16 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . self._struct_ndarray_launch_info_by_key[key] = getattr( ctx.global_context, "struct_ndarray_launch_info", [] ) + # Fold data_oriented ndarray attribute chains into the kernel's used-flat-names set so + # ``args_hasher.hash_args`` can narrow data_oriented walks too. ``used_vars_by_func_id`` + # only contains flat names from dataclass-arg expansion in + # ``extract_struct_locals_from_context``; data_oriented args don't go through that + # expansion, so accesses like ``self.x`` on an ndarray member are only tracked via + # ``struct_ndarray_launch_info``. Without this fold, narrow hashing for data_oriented + # args walks nothing — every (arg_idx, attr_chain) pair gets the same hash regardless + # of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the + # ``test_data_oriented_ndarray_fastcache_dtype_key_distinct`` pin caught this). + self._fold_struct_nd_paths_into_pruning(key, pruning) else: for used_parameters in pruning.used_vars_by_func_id.values(): new_used_parameters = set() @@ -455,6 +506,100 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . ] runtime._current_global_context = None + # Post-compile fastcache bookkeeping. See ``_maybe_persist_l1_and_set_l2_key`` docstring. + self._maybe_persist_l1_and_set_l2_key(key, py_args) + + def _fold_struct_nd_paths_into_pruning( + self, key: "CompiledKernelKeyType", pruning: Pruning + ) -> None: + """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat + name set so ``args_hasher.hash_args`` narrow-walks them correctly. + + Background: pruning's ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat + names produced by ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* + args. ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` in + the AST and don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths *are* + recorded — in ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — but only + for ndarray members. + + Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd___qd___qd_…`` + and union all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the + same convention used for dataclass args. + + Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at + compile, opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk + cannot distinguish "kernel reads this primitive" from "kernel does not read this primitive". The + ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking *all* + attrs of a data_oriented container — narrowing only suppresses subtrees explicitly absent from the + pruning set. So for a data_oriented arg with mostly-ndarray members, the cache key correctly + depends on the ndarray paths it uses; for one with primitive members whose values matter, those + members are still folded into the hash (qualname-fallback / value paths). + """ + nd_info = self._struct_ndarray_launch_info_by_key.get(key) + if not nd_info: + return + kernel_used: set[str] = pruning.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID] + for _arg_id_cpp, arg_idx, attr_chain in nd_info: + if arg_idx < 0 or arg_idx >= len(self.arg_metas): + continue + arg_name = self.arg_metas[arg_idx].name + if not arg_name: + continue + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + kernel_used.add(flat) + + def _maybe_persist_l1_and_set_l2_key( + self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...] + ) -> None: + """After a successful materialize, persist L1 (if missing) and set ``fast_checksum`` to the L2 key. + + Called at the end of ``materialize`` once both passes have completed (or once pass 1 has completed + with a loaded artifact). Two responsibilities: + + 1. If L1 was missing (``self._pruning_paths_from_l1 is None``), write the freshly-computed + pruning info so the next call from a new process can skip the args-walk warm-up. + + 2. If ``fast_checksum`` is still None (which means either L1 was missing, or L1 hit but phase 2 + of ``_try_load_fastcache`` saw a FIELD-related FastcacheSkip — in which case we keep ``None`` + and the post-compile ``src_hasher.store`` is skipped), compute the narrow args hash *now* + using the just-populated pruning info and derive the L2 key. The post-launch ``src_hasher.store`` + call uses ``self.fast_checksum`` as the L2 key. + + Side-effect helper; split out from ``materialize`` to keep the compile loop readable. + """ + l1_key = getattr(self, "_l1_key", None) + if not l1_key: + return # fastcache inactive for this kernel (not pure / no runtime.src_ll_cache) + kernel_source_info = getattr(self, "_kernel_source_info_cached", None) + if kernel_source_info is None: + return + used_params = self.used_py_dataclass_parameters_by_key_enforcing.get(key) + if used_params is None: + return + if getattr(self, "_pruning_paths_from_l1", None) is None: + src_hasher.store_pruning_info( + l1_key, + self.visited_functions, + used_params, + graph_do_while_arg=self.graph_do_while_arg, + ) + # If phase 2 didn't run (L1 cold) or returned None (FIELD encountered earlier — but in that case + # post-compile narrow hashing would also see the FIELD and produce None, which is fine: we want + # fast_checksum to stay None so no L2 entry is stored), compute the narrow args hash now. + if self.fast_checksum is None: + narrow_args_hash = src_hasher.compute_narrow_args_hash( + self.raise_on_templated_floats, + kernel_source_info, + py_args, + self.arg_metas, + used_params, + ) + if narrow_args_hash is not None: + self.fast_checksum = src_hasher.make_full_cache_key(l1_key, narrow_args_hash) + self.src_ll_cache_observations.cache_key_generated = True + def launch_kernel( self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args, qd_stream=None ) -> Any: diff --git a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py index a4bda2a1b1..b7ff707b60 100644 --- a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py +++ b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py @@ -160,7 +160,15 @@ def k(x: qd.Template): @test_utils.test(arch=qd.cpu) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows stderr not working with capfd") def test_fastcache_field_warnings_warn_struct_template_field(tmp_path, capfd): - """Struct with qd.Template-annotated field containing a Field — warning should fire.""" + """Struct with qd.Template-annotated field containing a Field — warning should fire when the field is + actually read by the kernel. + + Pruning-driven narrowing of args hashing only walks members the kernel reads; an unused dataclass field + cannot affect kernel codegen so it's correctly omitted from the hash (and from the + Field-disables-fastcache check). For the warning path to fire, the kernel must reference the Field — that + matches the user-visible contract that fastcache fails iff a "live" Field argument prevents safe + parametrisation. + """ qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) @dataclasses.dataclass(frozen=True) @@ -173,7 +181,7 @@ class S: @qd.pure @qd.kernel def k(x: S): - pass + x.a[0] = 1 capfd.readouterr() k(s) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index e7a2d9952b..bd0d81176e 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -23,6 +23,13 @@ @test_utils.test() def test_src_hasher_create_cache_key_vary_config() -> None: + """Source+config key (L1) is stable across re-init with identical config, changes when the config changes. + + Updated from the pre-refactor ``create_cache_key`` API (single-level, args-dependent) to the two-level + ``make_source_config_key`` (L1 — source+config only, no args). The L1 key is the right level to test + because config changes only affect the L1 layer; L2 adds the args-narrow hash on top. + """ + @qd.kernel def f1() -> None: pass @@ -31,15 +38,15 @@ def f1() -> None: # so we are forcing it to false each initialization for now qd_init_same_arch(print_ir_dbg_info=False) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_base = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_base = src_hasher.make_source_config_key(kernel_info) qd_init_same_arch(print_ir_dbg_info=False) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_same = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_same = src_hasher.make_source_config_key(kernel_info) qd_init_same_arch(print_ir_dbg_info=False, random_seed=123) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_diff = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_diff = src_hasher.make_source_config_key(kernel_info) assert cache_key_base == cache_key_same assert cache_key_same != cache_key_diff @@ -103,7 +110,11 @@ def get_fileinfos(functions: list[Callable]) -> list[_wrap_inspect.FunctionSourc mod = temporary_module("child_diff_test_src_hasher_store_validate") kernel_info = get_fileinfos([mod.f1.fn])[0] fileinfos = get_fileinfos([mod.f1.fn, mod.f2.fn]) - fast_cache_key = src_hasher.create_cache_key(False, kernel_info, [], []) + # L2 key: source+config (L1) + narrow-args-hash. Use an empty narrow-args-hash since the test isn't + # exercising args at all — it tests the helper-source-change invalidation logic, which lives in L2. + fast_cache_key = src_hasher.make_full_cache_key( + src_hasher.make_source_config_key(kernel_info), narrow_args_hash="" + ) assert fast_cache_key is not None @@ -202,7 +213,9 @@ def src_hasher_vary_kernel_func_child(args: list[str]) -> None: sys.path.append(args_obj.module_file_path) mod = importlib.import_module(args_obj.module_name) info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) - cache_key = src_hasher.create_cache_key(False, info, [], []) + # Source+config key (L1) — varies with the *kernel source* (the property this test exercises) and is + # the same level as the pre-refactor ``create_cache_key`` call site, just without the args-dependent tail. + cache_key = src_hasher.make_source_config_key(info) print(f"CACHE_KEY={cache_key}") print(TEST_RAN) From 984ac40a181bdcfbf8958eede81953eb7de90f1e Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 03:40:46 -0700 Subject: [PATCH 20/46] [Doc] Fastcache: pruning-driven semantics; stable_members is launch-perf-only After the previous two commits, fastcache is no longer brittle wrt opaque members: the cache key is derived from kernel pruning info, and unrecognised types at kernel-read paths fall back to a deterministic type(v).__qualname__ hash with a one-shot [UNKNOWN_TYPE] warning. This commit aligns the user-visible docs (fastcache.md, compound_types.md) and the data_oriented(stable_members=...) docstring with that semantic. stable_members is documented as *purely* a launch-time perf hint with no fastcache role; the launch-context comments in kernel.py and _template_mapper.py are updated to call this out explicitly. Also fixes a pylint no-else-return warning introduced by the refactor. --- docs/source/user_guide/compound_types.md | 4 ++- docs/source/user_guide/fastcache.md | 36 +++++++++++++++++------ python/quadrants/lang/_template_mapper.py | 5 +++- python/quadrants/lang/kernel.py | 14 ++++----- python/quadrants/lang/kernel_impl.py | 21 ++++++++----- 5 files changed, 54 insertions(+), 26 deletions(-) diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 7b942e4cab..b24f0de524 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -196,7 +196,9 @@ state.step() ### Fastcache -`@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes, but is disabled for fields; see [Advanced — compound-type cache keying](fastcache.md#compound-type-cache-keying) for more information. +`@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes. Cache keying is *pruning-driven*: only the members the kernel actually reads contribute to the cache key. Opaque metadata members (e.g. UUIDs, Pydantic config objects, back-pointers to parent solvers) do **not** disable fastcache and do **not** force cache misses — the kernel cannot read them so they cannot affect compiled code. See [Pruning-driven argument hashing](fastcache.md#pruning-driven-argument-hashing) for the full keying rules. + +`qd.field` / `MatrixField` members reached at a kernel-read path do disable fastcache for the call (recognised-but-unsupported tensor-like types). ### Under the hood diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 5d4e9381c8..5bbefed5f3 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -100,9 +100,12 @@ Fastcache supports the following parameter types: | `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) | | Non-template primitives (int, float, bool) | Yes | type only | | `enum.Enum` | Yes | name and value | -| `qd.field` / `ScalarField` / `MatrixField` | **No** | — | +| Anything else (Pydantic models, UUIDs, back-pointers, etc.) | Yes (degraded — type identity only) | `type(v).__qualname__` (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing)) | +| `qd.field` / `ScalarField` / `MatrixField` at a kernel-read path | **No** | — | -If any parameter is of an unsupported type, fastcache is disabled for that call and the kernel falls back to normal compilation. For `qd.field` / `ScalarField` / `MatrixField` arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted. For other unsupported types, a warning is logged at the `warn` level identifying the offending parameter. +For *recognised-but-unsupported* tensor-like types (`qd.field` / `ScalarField` / `MatrixField`) reached at a path the kernel actually reads, fastcache is disabled for that call and the kernel falls back to normal compilation. For these arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted; for other annotations a warning identifies the offending parameter. + +For *unrecognised* types reached at a kernel-read path, the hasher uses `type(v).__qualname__` as a stable type-identity hash and emits a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning per type. Kernel-unused members do not need to be recognised at all — the pruning narrowing skips them. ### 3. Source code must be available @@ -114,12 +117,25 @@ Each compiled artifact is stored under a key derived from all of the following: - The **Quadrants version** (`quadrants.__version__`). - The **source code** of the kernel function or any `@qd.func` it calls. -- The **argument types** (e.g. switching from `f32` to `f64`, or changing ndarray dimensionality). +- The **argument types at paths the kernel actually reads** (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing) below). - The **compiler configuration** (e.g. `arch`, `debug`, `opt_level`, `fast_math`). - **Template parameter values** (since they are baked into the compiled kernel). When any of these change, the resulting key is different, so a new compilation occurs and a new entry is stored. Previous entries remain on disk — multiple cached versions coexist. You do not need to manually clear the cache when making code changes — the hash mismatch causes a transparent recompilation. +### Pruning-driven argument hashing + +Fastcache uses a **two-level cache**: + +- **L1** (source + config only): stores the set of *flat names* the kernel actually reads — e.g. `__qd_state__qd_x` for a kernel that reads `state.x`. This is the kernel's pruning info, computed at compile time. +- **L2** (L1 + narrow argument hash): stores the compiled artifact under a key that only hashes the arg paths in the L1 pruning set. + +The practical implication: **kernel-unused members do not affect the cache key**. If a `@qd.data_oriented` container has an opaque metadata member (e.g. a UUID, a Pydantic config, a back-pointer to a parent solver), that member is *not* hashed because the kernel cannot read it — including it in the hash would only cause spurious cache misses without any safety benefit. The kernel reading `self.dofs_state.x` does not have its cache key disturbed by changes to `self._uid` or `self.cfg`. + +For paths the kernel *does* read but that contain an unrecognised type (a type fastcache has no explicit handling for in `quadrants/lang/_fast_caching/args_hasher.py`), the hasher falls back to a deterministic `type(v).__qualname__`-based string and emits a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning per type. This keeps the cache key stable across instances of the same opaque class while making any missed tensor-like registration impossible to overlook — if a future tensor type is added to Quadrants but not registered with the args hasher, the warning fires on the first call so the gap can be closed before stale cache results occur. + +`qd.field` / `ScalarField` / `MatrixField` are *recognised-but-unsupported*: their shape/dtype would affect codegen but fastcache doesn't yet know how to safely include them, so encountering one at a kernel-read path disables fastcache for the call (with a warn-level diagnostic). + ## Advanced ### Diagnostics @@ -147,13 +163,15 @@ On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules: -**`@qd.data_oriented`:** the walker descends into `vars(obj)`. For each child: +**`@qd.data_oriented`:** the walker descends into `vars(obj)`. Pruning narrowing applies to *ndarray children only* (the kernel-compile path tracks every kernel-accessed ndarray's structural path in `struct_ndarray_launch_info`); other children are always walked. For each walked child: -- `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not. -- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. -- Nested `@qd.data_oriented` member — recurses. -- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). -- `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. +- `qd.ndarray` member, kernel-read — `(dtype, ndim, layout)` is included in the cache key. Element values are not. +- `qd.ndarray` member, kernel-unused — *skipped*. Changes to its dtype/ndim/layout don't invalidate the cache. +- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. (Quadrants does not currently track per-primitive kernel-access info on data_oriented args, so primitive members are always hashed — see the safety note in [Pruning-driven argument hashing](#pruning-driven-argument-hashing).) +- Opaque member (Pydantic config, UUID, parent back-pointer, anything else fastcache doesn't recognise) — hashed by `type(v).__qualname__` only. The one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning fires on first encounter so a missing tensor-like registration in `args_hasher` is impossible to overlook. +- Nested `@qd.data_oriented` member — recurses (with these same rules). +- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below, including full flat-name pruning). +- `qd.field` member reached at a kernel-read path — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. A `qd.field` member at a kernel-*unused* path is simply skipped (no diagnostic). **`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index 1a3db47ebd..ffb384cc18 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -28,7 +28,10 @@ def _classify_for_args_hash(arg: Any) -> "list[tuple] | None": """First-sighting classification for ``type(arg)`` in the args_hash walk. Returns the path list to walk (when the arg is a data_oriented container without ``_qd_stable_members`` that actually contains ndarrays), or ``None`` to - skip subsequent per-call work for this type.""" + skip subsequent per-call work for this type. + + ``_qd_stable_members`` here is a *launch-time perf hint only* (see ``@qd.data_oriented(stable_members=...)``). + It does not affect fastcache key derivation.""" if not is_data_oriented(arg): return None if type(arg).__dict__.get("_qd_stable_members"): diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 92b2c311df..f332e9d8f7 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -412,7 +412,7 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType self.graph_do_while_arg = cached_graph_do_while_arg or self.graph_do_while_arg return None - elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: + if self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: # The bit in caps should not be modified without updating corresponding test # freetext can be freely modified. # As for why we are using `print` rather than eg logger.info, it is because this is only printed when @@ -509,9 +509,7 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . # Post-compile fastcache bookkeeping. See ``_maybe_persist_l1_and_set_l2_key`` docstring. self._maybe_persist_l1_and_set_l2_key(key, py_args) - def _fold_struct_nd_paths_into_pruning( - self, key: "CompiledKernelKeyType", pruning: Pruning - ) -> None: + def _fold_struct_nd_paths_into_pruning(self, key: "CompiledKernelKeyType", pruning: Pruning) -> None: """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat name set so ``args_hasher.hash_args`` narrow-walks them correctly. @@ -550,9 +548,7 @@ def _fold_struct_nd_paths_into_pruning( flat = create_flat_name(flat, attr) kernel_used.add(flat) - def _maybe_persist_l1_and_set_l2_key( - self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...] - ) -> None: + def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...]) -> None: """After a successful materialize, persist L1 (if missing) and set ``fast_checksum`` to the L2 key. Called at the end of ``materialize`` once both passes have completed (or once pass 1 has completed @@ -626,7 +622,9 @@ def launch_kernel( # Data_oriented containers marked ``_qd_stable_members = True`` (or decorated # with ``@qd.data_oriented(stable_members=True)``) promise their ndarray # members are never reassigned, so we exclude them from the per-call - # ``_resolve_struct_ndarray`` walk that builds ``args_hash``. + # ``_resolve_struct_ndarray`` walk that builds ``args_hash``. This is a + # *launch-time perf hint only* and has no fastcache role — fastcache derives + # its key from kernel-pruning info regardless of this flag. self._mutable_nd_cached_val = [ (idx, chain) for _, idx, chain in struct_nd_info diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 9270050e74..319a264980 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -300,13 +300,20 @@ def data_oriented(cls=None, *, stable_members: bool = False): Args: cls (Class): the class to be decorated. - stable_members (bool): if ``True``, declares that the class's ndarray-typed members are - allocated once and never reassigned between kernel calls. Quadrants will skip a - per-call walk of the instance's attributes (~1-2 us/call savings on Genesis-style - containers with several ndarray attrs). Reassigning a member on a ``stable_members`` - class is undefined behaviour — the previously-compiled kernel will be reused even if - the new ndarray has different dtype/ndim/layout. May also be set as a class-level - attribute ``_qd_stable_members = True`` (equivalent). + stable_members (bool): launch-context perf hint — if ``True``, declares that the class's + ndarray-typed members are allocated once and never reassigned between kernel calls. + Quadrants will skip the per-call ndarray-reference walk that ``Kernel.launch_kernel`` + uses to detect ndarray reassignment on mutable containers (~1-2 us/call savings on + Genesis-style containers with dozens of ndarray attrs). Reassigning a member on a + ``stable_members`` class is undefined behaviour — the previously-compiled kernel will + be reused even if the new ndarray has different dtype/ndim/layout. May also be set + as a class-level attribute ``_qd_stable_members = True`` (equivalent). + + Note: this flag is *purely* a launch-time perf hint. It no longer affects fastcache + argument hashing — the cache key is derived from pruning info (the set of flat names + the kernel actually reads), and unrecognised types at kernel-accessed paths fall back + to a deterministic ``type(v).__qualname__`` hash with a one-shot ``[UNKNOWN_TYPE]`` + warning. See ``docs/source/user_guide/fastcache.md``. Returns: The decorated class (or, when called with arguments, a decorator). From 12fb215dbcd686c3b5a230ff42c122211c1cba42 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 05:12:18 -0700 Subject: [PATCH 21/46] [Fix] Fastcache: full pruning coverage for data_oriented; remove qualname fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three rules now strictly enforced by the args hasher: 1. The cache key may only include contributions from kernel-pruned paths. Never a qualname-based hash for unrecognised types — that captures type identity without type parameters (dtype/shape) and would silently mask value-affecting changes. 2. Unrecognised types at kernel-read paths must not be silently dropped. Fastcache is disabled loudly with a one-shot [UNKNOWN_TYPE] warning plus [INVALID_FUNC] log line. 3. Fastcache works for data_oriented containers — pruning info now covers every attribute chain rooted at a kernel arg, not just ndarrays. Compiler-side: ASTTransformer.build_Name annotates non-flattened kernel-arg Names with ``_qd_arg_chain``; build_Attribute propagates the annotation through ``self.dofs.x`` chains and records them via the new ``Pruning.mark_kernel_arg_chain_used`` (separate set so they don't poison ``struct_locals`` and break codegen). ``Pruning.record_after_call`` was extended to propagate chain-path entries across @qd.func calls including Attribute args (``f(self.dofs)``). After both compile passes, ``Kernel._fold_kernel_arg_chain_paths_into_pruning`` merges the kernel's chain-paths into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (same set as ``used_py_dataclass_parameters_by_key_enforcing[key]`` by reference) so the fastcache args-hash narrow walk picks them up. Args-hasher side: removed the data_oriented ndarray-only carveout — ``_is_path_used(pruning_paths, child_flat)`` now applies to every member. Removed ``_qualname_fallback``; replaced with ``_fail_unknown_type`` which returns _FAIL_FASTCACHE and emits the [UNKNOWN_TYPE] warning. Tests + docs updated to match. Full x64 suite: 4063 passed. --- docs/source/user_guide/compound_types.md | 4 +- docs/source/user_guide/fastcache.md | 44 ++++-- .../lang/_fast_caching/args_hasher.py | 135 +++++++++--------- python/quadrants/lang/_pruning.py | 82 ++++++++++- python/quadrants/lang/ast/ast_transformer.py | 32 +++++ python/quadrants/lang/kernel.py | 30 ++++ .../lang/fast_caching/test_src_ll_cache.py | 12 +- tests/python/test_template_typing.py | 19 ++- 8 files changed, 263 insertions(+), 95 deletions(-) diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index b24f0de524..9c6a8aed88 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -196,9 +196,9 @@ state.step() ### Fastcache -`@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes. Cache keying is *pruning-driven*: only the members the kernel actually reads contribute to the cache key. Opaque metadata members (e.g. UUIDs, Pydantic config objects, back-pointers to parent solvers) do **not** disable fastcache and do **not** force cache misses — the kernel cannot read them so they cannot affect compiled code. See [Pruning-driven argument hashing](fastcache.md#pruning-driven-argument-hashing) for the full keying rules. +`@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes. Cache keying is *pruning-driven*: only the members the kernel actually reads contribute to the cache key. Opaque metadata members (e.g. UUIDs, Pydantic config objects, back-pointers to parent solvers) are skipped by the args-hasher's narrow walk as long as the kernel doesn't read them — they cannot affect compiled code and cannot cause spurious cache misses. See [Pruning-driven argument hashing](fastcache.md#pruning-driven-argument-hashing) for the full keying rules. -`qd.field` / `MatrixField` members reached at a kernel-read path do disable fastcache for the call (recognised-but-unsupported tensor-like types). +If the kernel *does* read a member of an unrecognised type, fastcache is disabled for the call with a `[FASTCACHE][UNKNOWN_TYPE]` diagnostic — there is no qualname-fallback. `qd.field` / `MatrixField` members at a kernel-read path likewise disable fastcache (recognised-but-unsupported tensor-like types). ### Under the hood diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 5bbefed5f3..4453bd9e47 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -96,16 +96,19 @@ Fastcache supports the following parameter types: | `torch.Tensor` | Yes | dtype, ndim | | `numpy.ndarray` | Yes | dtype, ndim | | `dataclasses.dataclass` | Yes | member types recursively; member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | -| `@qd.data_oriented` objects | Yes | member types recursively; primitive member types and values baked into kernel (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | +| `@qd.data_oriented` objects | Yes | member types recursively, narrowed by pruning (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing)); primitive member values baked into kernel | | `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) | | Non-template primitives (int, float, bool) | Yes | type only | | `enum.Enum` | Yes | name and value | -| Anything else (Pydantic models, UUIDs, back-pointers, etc.) | Yes (degraded — type identity only) | `type(v).__qualname__` (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing)) | | `qd.field` / `ScalarField` / `MatrixField` at a kernel-read path | **No** | — | +| Anything else at a kernel-read path | **No** | — | -For *recognised-but-unsupported* tensor-like types (`qd.field` / `ScalarField` / `MatrixField`) reached at a path the kernel actually reads, fastcache is disabled for that call and the kernel falls back to normal compilation. For these arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted; for other annotations a warning identifies the offending parameter. +Two failure modes — both loud, never silent: -For *unrecognised* types reached at a kernel-read path, the hasher uses `type(v).__qualname__` as a stable type-identity hash and emits a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning per type. Kernel-unused members do not need to be recognised at all — the pruning narrowing skips them. +- **Recognised-but-unsupported** tensor-like types (`qd.field` / `ScalarField` / `MatrixField`) reached at a path the kernel actually reads → fastcache disabled for the call, kernel falls back to normal compilation. For these arriving through a `qd.Tensor`-annotated parameter the diagnostic is silent (normal usage); for other annotations a `[FASTCACHE][INVALID_FUNC]` log line identifies the offending parameter. +- **Unrecognised** types at a kernel-read path → fastcache disabled for the call, with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning per type identifying the offending class plus an `[INVALID_FUNC]` log line confirming the cache is off. To make a type cache-eligible, add explicit handling for it to `quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type`, or refactor the kernel so it does not read this member (pruning will then skip it). + +Kernel-unused members of any type — including unrecognised ones — do **not** disable fastcache. The pruning narrowing in the args hasher skips them entirely, so opaque metadata (UUIDs, Pydantic configs, parent back-pointers) attached to a `@qd.data_oriented` instance is harmless as long as the kernel doesn't read it. ### 3. Source code must be available @@ -127,12 +130,27 @@ When any of these change, the resulting key is different, so a new compilation o Fastcache uses a **two-level cache**: -- **L1** (source + config only): stores the set of *flat names* the kernel actually reads — e.g. `__qd_state__qd_x` for a kernel that reads `state.x`. This is the kernel's pruning info, computed at compile time. +- **L1** (source + config only): stores the set of *flat names* the kernel actually reads — e.g. `__qd_state__qd_x` for a kernel that reads `state.x`. This is the kernel's pruning info, computed at compile time by the AST builder. - **L2** (L1 + narrow argument hash): stores the compiled artifact under a key that only hashes the arg paths in the L1 pruning set. -The practical implication: **kernel-unused members do not affect the cache key**. If a `@qd.data_oriented` container has an opaque metadata member (e.g. a UUID, a Pydantic config, a back-pointer to a parent solver), that member is *not* hashed because the kernel cannot read it — including it in the hash would only cause spurious cache misses without any safety benefit. The kernel reading `self.dofs_state.x` does not have its cache key disturbed by changes to `self._uid` or `self.cfg`. +#### Two rules + +The args hasher enforces two strict invariants: + +1. **The cache key may only include contributions from kernel-pruned paths.** A path is "pruned" (in the pruning-info sense) if Quadrants's compiler recorded the kernel reading it. Pruning info covers: + - Dataclass-flattened param accesses (`__qd_some_dc__qd_field` Names produced by `FlattenAttributeNameTransformer`, marked via `build_Name`). + - Ndarray accesses on data_oriented / template args (`struct_ndarray_launch_info`, folded into the pruning set by `Kernel._fold_struct_nd_paths_into_pruning`). + - Any other attribute-chain access on a kernel arg (`self.dofs.x`, `cfg.n`, …) recorded by `ASTTransformer.build_Attribute`'s `_qd_arg_chain` tracking and folded by `Kernel._fold_kernel_arg_chain_paths_into_pruning`. This covers primitive members baked into the kernel, nested struct paths, and accesses through `@qd.func` callees (propagated by `Pruning.record_after_call`). + + Paths *not* in the pruning set are skipped by the args hasher — they are guaranteed not to affect kernel codegen because the kernel cannot read them. + +2. **Unrecognised types at kernel-read paths must not be silently dropped or hashed by type-name.** If pruning says the kernel reads a path and the value at that path is a type the args hasher doesn't explicitly handle (Pydantic models, UUIDs, third-party tensor wrappers, …), fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning identifying the offending type plus an `[INVALID_FUNC]` log line confirming the cache is off. There is no qualname-fallback — capturing type identity without type parameters (dtype/shape on a hypothetical tensor type) would silently mask a value-affecting change. + +#### Practical implications -For paths the kernel *does* read but that contain an unrecognised type (a type fastcache has no explicit handling for in `quadrants/lang/_fast_caching/args_hasher.py`), the hasher falls back to a deterministic `type(v).__qualname__`-based string and emits a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning per type. This keeps the cache key stable across instances of the same opaque class while making any missed tensor-like registration impossible to overlook — if a future tensor type is added to Quadrants but not registered with the args hasher, the warning fires on the first call so the gap can be closed before stale cache results occur. +- **Kernel-unused members do not affect the cache key.** If a `@qd.data_oriented` container has an opaque metadata member (UUID, Pydantic config, parent-solver back-pointer), and the kernel never reads it, the member is *not* hashed. Changes to `self._uid` or `self.cfg` don't disturb the cache key of a kernel that reads `self.dofs_state.x`. +- **Kernel-unused members of unrecognised types are also fine.** Pruning narrowing skips them before the type-recognition check runs. +- **Kernel-read members of unrecognised types fail fastcache loudly.** Either add explicit handling in `quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type` (for new tensor-like types whose dtype/shape matter), or move the access out of the kernel-read path (for opaque metadata that shouldn't be there in the first place). `qd.field` / `ScalarField` / `MatrixField` are *recognised-but-unsupported*: their shape/dtype would affect codegen but fastcache doesn't yet know how to safely include them, so encountering one at a kernel-read path disables fastcache for the call (with a warn-level diagnostic). @@ -163,15 +181,17 @@ On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules: -**`@qd.data_oriented`:** the walker descends into `vars(obj)`. Pruning narrowing applies to *ndarray children only* (the kernel-compile path tracks every kernel-accessed ndarray's structural path in `struct_ndarray_launch_info`); other children are always walked. For each walked child: +**`@qd.data_oriented`:** the walker descends into `vars(obj)`, narrowed by pruning info — *every* child (ndarray, primitive, opaque, nested struct) is subject to the pruning check. For each walked child: - `qd.ndarray` member, kernel-read — `(dtype, ndim, layout)` is included in the cache key. Element values are not. - `qd.ndarray` member, kernel-unused — *skipped*. Changes to its dtype/ndim/layout don't invalidate the cache. -- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. (Quadrants does not currently track per-primitive kernel-access info on data_oriented args, so primitive members are always hashed — see the safety note in [Pruning-driven argument hashing](#pruning-driven-argument-hashing).) -- Opaque member (Pydantic config, UUID, parent back-pointer, anything else fastcache doesn't recognise) — hashed by `type(v).__qualname__` only. The one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning fires on first encounter so a missing tensor-like registration in `args_hasher` is impossible to overlook. +- Primitive (`int` / `float` / `bool` / `enum.Enum`) member, kernel-read — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values that the kernel reads get different cache entries. +- Primitive member, kernel-unused — *skipped* (the kernel cannot read it so its value cannot affect codegen). - Nested `@qd.data_oriented` member — recurses (with these same rules). -- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below, including full flat-name pruning). -- `qd.field` member reached at a kernel-read path — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. A `qd.field` member at a kernel-*unused* path is simply skipped (no diagnostic). +- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). +- Opaque member (anything fastcache doesn't recognise), kernel-unused — *skipped*. +- Opaque member, kernel-read — fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning. To make the type cacheable, add explicit handling to `args_hasher.py::stringify_obj_type`. +- `qd.field` member, kernel-read — fastcache is disabled for the call (recognised-but-unsupported). A `qd.field` member at a kernel-*unused* path is simply skipped (no diagnostic). **`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 14b854c967..77a2366d4c 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -42,14 +42,12 @@ _DC_REPR_NONE = object() -# Sentinel returned by ``stringify_obj_type`` when a recognised-but-unsupported tensor-like type (``ScalarField`` / -# ``MatrixField``) is encountered anywhere in the traversal. Containers (``dataclass_to_repr``, ``data_oriented`` -# branch, top-level ``hash_args`` loop) must propagate it upward — fastcache cannot safely hash the call because -# fields have shape/dtype that would affect kernel codegen but fastcache doesn't yet know how to include them. +# Sentinel returned by ``stringify_obj_type`` whenever fastcache cannot safely hash a value: +# - Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). +# - Unrecognised type at a kernel-read path (no qualname fallback — see rules in fastcache.md). # -# Distinct from any other return value: an unrecognised opaque type now falls back to a deterministic -# ``type(v).__qualname__`` string (see fallback in ``stringify_obj_type``), so the only way ``stringify_obj_type`` -# disables fastcache is by returning this sentinel. +# Containers (``dataclass_to_repr``, ``data_oriented`` branch, top-level ``hash_args`` loop) must propagate it +# upward — fastcache is disabled for the whole call and the caller writes the appropriate diagnostic. class _FailFastcache: """Singleton sentinel; identity-compared.""" @@ -76,10 +74,10 @@ class FastcacheSkip(enum.Enum): _should_warn = False -# Set of ``type(v).__qualname__`` strings we've already emitted the "unknown type, falling back to qualname hash" +# Set of ``type(v).__qualname__`` strings we've already emitted the "unknown type at a kernel-read path" # warning for. Lets the loop run thousands of times without spamming the log while still telling the user once -# that fastcache encountered an unrecognised type at a hashed path. Cleared by ``reset_unknown_type_warn_state`` -# (called from ``qd.init``) so each new test sees a clean log. +# that fastcache encountered an unrecognised type. Cleared by ``reset_unknown_type_warn_state`` (called from +# ``qd.init``) so each new test sees a clean log. _warned_unknown_types: set[str] = set() @@ -100,31 +98,33 @@ def _mark_should_warn() -> None: _should_warn = True -def _qualname_fallback(obj: object, path: tuple[str, ...]) -> str: - """Deterministic fallback for unrecognised types. +def _fail_unknown_type(obj: object, path: tuple[str, ...]) -> _FailFastcache: + """Disable fastcache for the call when an unrecognised type appears at a kernel-read path. - Returns a string derived from ``type(obj)``'s module + qualname so the cache key is *stable* across calls - (instances of the same opaque class get the same hash contribution). Warn once per unrecognised type so a - new tensor-like type added to Quadrants without being added to the recognised list here gets noticed in the - logs without spamming the per-call hot path. + Two rules at work here (see ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing"): - Safety note: this captures type identity only, NOT value or type-parameters (e.g. dtype/shape on a hypothetical - ``BFloat16Tensor``). For genuinely opaque metadata (UUID, Pydantic config, back-pointers) the type-identity - hash is correct because the kernel cannot read non-recognised Python types. For new tensor-like types whose - dtype/shape *would* affect codegen, the warning is the signal that someone needs to add them to the recognised - set in this module. + 1. The fastcache key may *only* contain contributions from kernel-pruned paths — never a + ``type(v).__qualname__`` fallback for an unrecognised type, because that hash captures type identity + only and would silently mask a value-affecting change (e.g. a new tensor-like type whose dtype matters). + + 2. We may not silently *discard* something at a kernel-read path on the basis that it's unrecognised — + that would let unrecognised but codegen-affecting values escape the cache key and serve stale results. + + The only way to honour both rules is to fail the call's fastcache loudly, with a one-shot warning per type + so the user can add explicit handling in ``stringify_obj_type``. """ t = type(obj) qualname = f"{getattr(t, '__module__', '')}.{getattr(t, '__qualname__', t.__name__)}" if qualname not in _warned_unknown_types: _warned_unknown_types.add(qualname) _logging.warn( - f"[FASTCACHE][UNKNOWN_TYPE] Falling back to type-name hash for path {path} type {qualname}. " - f"The cache key captures the type identity but not type parameters (e.g. dtype/shape). If this " - f"type's value affects kernel codegen, add explicit handling to " - f"``quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type``." + f"[FASTCACHE][UNKNOWN_TYPE] Unrecognised type {qualname} reached at kernel-read path {path}. " + f"Fastcache is disabled for this call. Add explicit handling for this type to " + f"``quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type``, or refactor the kernel " + f"so it does not read this member." ) - return f"opaque-{qualname}" + _mark_should_warn() + return _FAIL_FASTCACHE def _child_flat(parent_flat: str | None, child_name: str) -> str | None: @@ -248,25 +248,31 @@ def stringify_obj_type( Return contract: - ``str``: hashable; the returned string contributes to the cache key. - - ``_FAIL_FASTCACHE``: a recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``) - was encountered. Containers must propagate this upward; fastcache is disabled for the whole call. + - ``_FAIL_FASTCACHE``: fastcache cannot safely hash this value — caller must propagate upward and + disable fastcache for the whole call. Triggered by: + * Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). + * Unrecognised type at this kernel-read path (see ``_fail_unknown_type``). + + Two rules from ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing" govern this + function: - For *every other* unrecognised type, this function falls back to a deterministic - ``type(obj).__qualname__``-based string (see ``_qualname_fallback``). The pre-refactor design returned - ``None`` and disabled fastcache for any unrecognised member type, which made adding a UUID or Pydantic - config object to a ``@qd.data_oriented`` ``self`` silently kill fastcache. The qualname fallback captures - type identity (sufficient for genuinely opaque metadata — kernels cannot read non-recognised Python types - so opaque metadata cannot affect codegen) and warns once per unrecognised type so any future tensor-like - addition that *does* need explicit handling gets noticed. + 1. The cache key may *only* include contributions from paths that pruning has marked kernel-accessed + (``pruning_paths``). Container walkers (dataclass + data_oriented) check ``_is_path_used`` per + child and skip non-pruned subtrees — kernel-unread paths are *guaranteed* not to affect codegen so + this is safe by construction. + + 2. At paths the kernel *does* read, unrecognised types must not be silently dropped or hashed by + type-name — fastcache fails the call (loudly, with a one-shot warning) so the gap can be closed. Parameters: - ``arg_meta``: non-``None`` only for top-level kernel args and for ``@qd.data_oriented`` members. Determines whether primitive values are baked into the cache key (template-position primitives and all primitive members of data-oriented containers). - - ``pruning_paths``: optional set of kernel-accessed flat names. When provided, ``dataclass_to_repr`` and - the ``data_oriented`` branch below descend only into children whose flat name is in the set. Skipped - children are *guaranteed* not to affect kernel codegen (the kernel never reads them), so omitting them - from the hash is safe by construction. + - ``pruning_paths``: optional set of kernel-accessed flat names from L1 cache. When provided, + ``dataclass_to_repr`` and the ``data_oriented`` branch below descend only into children whose flat + name is in the set. Pruning info is populated by ``ASTTransformer.build_Name`` / + ``build_Attribute`` (kernel-arg-rooted chains) plus ``Kernel._fold_struct_nd_paths_into_pruning`` + (ndarray accesses through data_oriented containers). - ``parent_flat``: the flat-name prefix for ``obj``'s children (e.g. ``__qd_self`` if ``obj`` is the ``self`` arg of a data_oriented kernel). Used together with ``pruning_paths`` to compute each child's flat name for the narrow-walk lookup. @@ -313,22 +319,13 @@ def stringify_obj_type( raise_on_templated_floats, path, obj, pruning_paths=pruning_paths, parent_flat=parent_flat ) if is_data_oriented(obj): - # Walk the data_oriented container's members. Narrowing rules differ from ``dataclass_to_repr``: - # - # Pruning info for data_oriented containers is *only complete for ndarray members*: the kernel-compile - # path records each kernel-accessed ndarray's structural attribute chain in - # ``struct_ndarray_launch_info``, which ``Kernel._fold_struct_nd_paths_into_pruning`` folds into the - # flat-name pruning set. Non-ndarray attribute accesses on data_oriented args (``self.an_int``, - # ``self.a_float`` — values that get baked into the kernel at compile time) are *not* tracked anywhere - # as pruning input (data_oriented args aren't run through ``FlattenAttributeNameTransformer``). - # - # If we naively applied flat-name pruning to *every* child, an unused-but-present opaque member would - # match (silently dropped → safe), a kernel-read primitive member would silently disappear from the hash - # (BAD — its value affects codegen and we'd serve a stale cached compile when the value changes), and - # the templated-float raise-guard would also stop firing. - # - # Conservative fix: only narrow *ndarray* children. For everything else, walk unconditionally. The - # recursive call still applies narrowing to nested dataclasses (where flat-name tracking IS complete). + # Walk the data_oriented container's members, narrowed by pruning info — the kernel-compile path + # records every kernel-accessed attribute chain (ndarrays via ``_promote_ndarray_if_declared`` + + # ``_fold_struct_nd_paths_into_pruning``; primitives, opaque members, nested structs via + # ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` propagation calling + # ``pruning.mark_used``). Members not in ``pruning_paths`` are *guaranteed* not to affect kernel + # codegen because the kernel cannot read them. Dropping them from the hash satisfies rule 1 + # (cache only pruned paths). child_repr_l = ["da"] try: _asdict = getattr(obj, "_asdict") @@ -344,14 +341,7 @@ def stringify_obj_type( if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: continue child_flat = _child_flat(parent_flat, k) - # ndarray-only pruning narrowing — see the comment at the top of this branch for why other types - # cannot be safely narrowed here. - if ( - pruning_paths is not None - and child_flat is not None - and child_flat not in pruning_paths - and isinstance(v, (ScalarNdarray, VectorNdarray, MatrixNdarray)) - ): + if not _is_path_used(pruning_paths, child_flat): continue _child_repr = stringify_obj_type( raise_on_templated_floats, @@ -380,9 +370,8 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" - # Unrecognised type — fall back to deterministic qualname-based hash and warn once. See ``_qualname_fallback`` - # for the safety reasoning. - return _qualname_fallback(obj, path) + # Unrecognised type at a kernel-read path — fail fastcache loudly. See ``_fail_unknown_type``. + return _fail_unknown_type(obj, path) def hash_args( @@ -396,12 +385,16 @@ def hash_args( Parameters: - ``pruning_paths``: optional set of kernel-accessed flat names from the L1 cache (or freshly populated after a cold compile). When provided, the container walkers skip children whose flat name is not in - the set; this both narrows the cache key (so unrelated metadata changes don't cause cache misses) and - eliminates the brittleness of walking opaque-typed members blindly. + the set; this is what keeps the cache key narrow and brittleness-free (no opaque-typed member can + affect the key unless the kernel actually reads it). + + Fastcache is disabled (``FastcacheSkip`` returned) when either: + - a recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``) is encountered at a + kernel-read path, OR + - an unrecognised type is encountered at a kernel-read path (see ``_fail_unknown_type``). - Fastcache is disabled (``FastcacheSkip`` returned) only when a recognised-but-unsupported tensor-like type - (``ScalarField`` / ``MatrixField``) is encountered. Truly-unrecognised types use a ``type(v).__qualname__`` - fallback so the cache key stays stable. + Both cases are loud: ``FastcacheSkip.WARN`` triggers an ``[INVALID_FUNC]`` log line and the unknown-type + branch additionally emits a one-shot ``[UNKNOWN_TYPE]`` warning identifying the offending type. """ global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn # pylint: disable=global-statement _should_warn = False diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index b5b1f97a27..088459f9c8 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -1,13 +1,32 @@ -from ast import Name, Starred, expr, keyword +from ast import Attribute, Name, Starred, expr, keyword from collections import defaultdict from typing import TYPE_CHECKING, Any +from ._dataclass_util import create_flat_name from ._exceptions import raise_exception from ._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from .exception import QuadrantsSyntaxError from .func import Func from .kernel_arguments import ArgMetadata + +def _flatten_arg_node(node: expr) -> str | None: + """Flatten an AST arg node into the corresponding kernel-arg-rooted flat name (or ``None`` if the + node isn't a recognisable name/attribute chain rooted at a plain Name). + + Mirrors ``FlattenAttributeNameTransformer._flatten_attribute_name`` but on the raw call-arg AST. + Used by ``record_after_call`` to handle ``f(self.dofs)`` etc. — without this the callee's pruning + info for attribute-chain args is dropped at the call boundary.""" + if isinstance(node, Name): + return node.id + if isinstance(node, Attribute): + parent = _flatten_arg_node(node.value) + if parent is None: + return None + return create_flat_name(parent, node.attr) + return None + + if TYPE_CHECKING: import ast @@ -50,11 +69,54 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: # therefore unreliable — in that case ``_predeclare_struct_ndarrays`` falls back to # registering every reachable ndarray (same as the historical behavior). self.pass_0_ran: bool = False + # Kernel-arg-rooted attribute chains used by each func, in flat-name form + # (``__qd_self__qd_dofs__qd_x``). Populated by ``ASTTransformer.build_Attribute`` + # for non-flattened kernel args (data_oriented / qd.template). Kept *separate* from + # ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on the enforcing + # pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite + # ``s.x`` → ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a + # ``QuadrantsNameError: Name "__qd_s__qd_x" is not defined``. ``record_after_call`` + # propagates entries from callee to caller (so ``f(self.dofs)`` where ``f`` reads + # ``s.x`` ends up with ``__qd_self__qd_dofs__qd_x`` in the kernel's set). After both + # compile passes, ``Kernel._fold_kernel_arg_chain_paths_into_pruning`` merges the + # kernel's set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` so fastcache stores them + # in L1 and the args_hasher narrow walk picks them up. + self.kernel_arg_chain_paths_by_func_id: dict[int, set[str]] = defaultdict(set) def mark_used(self, func_id: int, parameter_flat_name: str) -> None: assert not self.enforcing self.used_vars_by_func_id[func_id].add(parameter_flat_name) + def mark_kernel_arg_chain_used(self, func_id: int, chain_flat_name: str) -> None: + """Record a kernel-arg-rooted attribute chain (e.g. ``__qd_self__qd_dofs__qd_x``). + + Stored separately from ``used_vars_by_func_id`` — see the docstring on + ``kernel_arg_chain_paths_by_func_id`` for why.""" + assert not self.enforcing + self.kernel_arg_chain_paths_by_func_id[func_id].add(chain_flat_name) + + @staticmethod + def _propagate_chain_paths( + callee_chain_paths: set[str], + callee_param_name: str, + caller_flat: str, + chain_paths_to_propagate: set[str], + ) -> None: + """When ``f(self.dofs)`` is called and ``f``'s body reads ``s.x`` (callee param ``s`` bound to caller + attribute chain ``self.dofs``), the callee's chain-paths set contains ``__qd_s__qd_x`` but the + caller's chain-paths set must record ``__qd_self__qd_dofs__qd_x``. This helper does that + prefix substitution. Only chain paths starting with ``__qd___qd_`` are propagated + (chains rooted in unrelated callee args don't apply to this caller arg).""" + prefix = f"__qd_{callee_param_name}__qd_" + for sub in callee_chain_paths: + if sub.startswith(prefix): + rest = sub[len(prefix) :] + if caller_flat.startswith("__qd_"): + new_flat = f"{caller_flat}__qd_{rest}" + else: + new_flat = f"__qd_{caller_flat}__qd_{rest}" + chain_paths_to_propagate.add(new_flat) + def enforce(self) -> None: self.enforcing = True @@ -81,7 +143,9 @@ def record_after_call( callee_func_id = func.wrapper.func_id # type: ignore # Copy the used parameters from the child function into our own function. callee_used_vars = self.used_vars_by_func_id[callee_func_id] + callee_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(callee_func_id, set()) vars_to_unprune: set[str] = set() + chain_paths_to_propagate: set[str] = set() arg_id = 0 # node.args ordering will match that of the called function's metas_expanded, # because of the way calling with sequential args works. @@ -99,6 +163,15 @@ def record_after_call( callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) + # NEW: propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) + # AND through plain-Name args of non-flattened types (``f(self)``). These flow into the + # caller's separate chain-paths set, not ``used_vars`` — see the field-level docstring. + caller_flat = _flatten_arg_node(arg) + if caller_flat is not None and not caller_flat.startswith("__qd_"): + callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 # Note that our own arg_metas ordering will in general NOT match that of the child's. That's # because our ordering is based on the order in which we pass arguments to the function, but the @@ -112,8 +185,15 @@ def record_after_call( callee_param_name = kwarg.arg if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) + caller_flat = _flatten_arg_node(kwarg.value) + if caller_flat is not None and not caller_flat.startswith("__qd_"): + callee_param_name = kwarg.arg + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 self.used_vars_by_func_id[my_func_id].update(vars_to_unprune) + self.kernel_arg_chain_paths_by_func_id[my_func_id].update(chain_paths_to_propagate) used_callee_vars = self.used_vars_by_func_id[callee_func_id] child_arg_id = 0 diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 977eec7fca..b911bd7653 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -16,6 +16,7 @@ from quadrants._lib import core as _qd_core from quadrants.lang import exception, expr, impl, matrix, mesh from quadrants.lang import ops as qd_ops +from quadrants.lang._dataclass_util import create_flat_name from quadrants.lang._ndrange import _Ndrange from quadrants.lang.ast.ast_transformer_utils import ( ASTTransformerFuncContext, @@ -85,6 +86,15 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name): pruning = ctx.global_context.pruning if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters and node.id.startswith("__qd_"): ctx.global_context.pruning.mark_used(ctx.func.func_id, node.id) + # Track chains rooted at non-flattened kernel args (``@qd.data_oriented`` / ``qd.template`` params, which + # appear in the AST with bare names like ``self``). ``build_Attribute`` propagates this annotation through + # ``state.dofs.x`` chains and ``mark_used``s the flat name, so the fastcache narrow walk can include them + # in pruning (dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as + # already-flat ``__qd_…`` Names, handled by the block above). + if node.id in ctx.kernel_args and not node.id.startswith("__qd_"): + node._qd_arg_chain = node.id # type: ignore[attr-defined] + else: + node._qd_arg_chain = None # type: ignore[attr-defined] node.violates_pure, node.ptr, node.violates_pure_reason = ctx.get_var_by_name(node.id) # Flattened struct fields (``__qd_foo__qd_bar``) injected by ``populate_global_vars_from_dataclass`` are raw # ``Ndarray`` instances. ``build_Attribute`` already promotes these via ``_promote_ndarray_if_declared`` but @@ -798,6 +808,28 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): warnings.warn(message) else: raise exception.QuadrantsCompilationError(message) + # Propagate the kernel-arg-rooted chain annotation and record this access in pruning's *separate* + # chain-paths set. ``build_Name`` sets ``_qd_arg_chain`` on non-flattened kernel args (e.g. + # data_oriented ``self``); each Attribute access in the chain extends it + # (``self`` → ``__qd_self__qd_x`` → ``__qd_self__qd_x__qd_y``). + # + # Why not ``mark_used``? On the enforcing pass, ``Kernel.materialize`` uses + # ``pruning.used_vars_by_func_id`` as ``struct_locals``, which drives + # ``FlattenAttributeNameTransformer`` — adding ``__qd_self__qd_x`` there would make the transformer + # rewrite ``self.x`` into ``Name('__qd_self__qd_x')``, and ``build_Name`` would then fail to find + # such a variable. ``mark_kernel_arg_chain_used`` puts the chain into a *separate* per-func set + # that's merged into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` only *after* both compile passes, + # by ``Kernel._fold_kernel_arg_chain_paths_into_pruning`` — so the fastcache args-hash narrow walk + # picks them up without breaking codegen. + parent_chain = getattr(node.value, "_qd_arg_chain", None) + if parent_chain is not None: + flat = create_flat_name(parent_chain, node.attr) + node._qd_arg_chain = flat # type: ignore[attr-defined] + pruning = ctx.global_context.pruning + if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters: + pruning.mark_kernel_arg_chain_used(ctx.func.func_id, flat) + else: + node._qd_arg_chain = None # type: ignore[attr-defined] return node.ptr @staticmethod diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index f332e9d8f7..23e679a315 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -489,6 +489,14 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . # of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the # ``test_data_oriented_ndarray_fastcache_dtype_key_distinct`` pin caught this). self._fold_struct_nd_paths_into_pruning(key, pruning) + # Fold non-ndarray kernel-arg-rooted chain paths (primitives, opaque members, nested + # struct paths) collected by ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` + # tracking. Kept separate from ``used_vars_by_func_id`` during compile (would otherwise + # poison ``struct_locals`` and break codegen) — see the field-level docstring on + # ``Pruning.kernel_arg_chain_paths_by_func_id``. This fold + the existing ``used_vars`` + # assignment to ``used_py_dataclass_parameters_by_key_enforcing`` share the same set + # by reference, so the final fastcache L1 entry sees all kernel-accessed paths. + self._fold_kernel_arg_chain_paths_into_pruning(pruning) else: for used_parameters in pruning.used_vars_by_func_id.values(): new_used_parameters = set() @@ -548,6 +556,28 @@ def _fold_struct_nd_paths_into_pruning(self, key: "CompiledKernelKeyType", pruni flat = create_flat_name(flat, attr) kernel_used.add(flat) + @staticmethod + def _fold_kernel_arg_chain_paths_into_pruning(pruning: Pruning) -> None: + """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both + compile passes have completed. + + Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain + (e.g. ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into + ``pruning.kernel_arg_chain_paths_by_func_id`` rather than ``used_vars_by_func_id``, because the + latter is read on the enforcing pass to build ``struct_locals`` for + ``FlattenAttributeNameTransformer``. If chain names appeared there, the transformer would rewrite + ``self.n`` into ``Name('__qd_self__qd_n')`` and ``build_Name`` would fail to find such a variable. + + Doing the merge here — after pass 1, just like ``_fold_struct_nd_paths_into_pruning`` — + avoids that interaction while still making the chain paths available to the fastcache args-hash + narrow walk. The set on ``used_py_dataclass_parameters_by_key_enforcing[key]`` is the *same* + object as ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so + updating one updates both.""" + kernel_chain_paths = pruning.kernel_arg_chain_paths_by_func_id.get(Pruning.KERNEL_FUNC_ID) + if not kernel_chain_paths: + return + pruning.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_chain_paths) + def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...]) -> None: """After a successful materialize, persist L1 (if missing) and set ``fast_checksum`` to the L2 key. diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index bc49b18a4e..5736a6b96e 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -142,14 +142,16 @@ def k1(foo: qd.Template) -> None: k1(foo=RandomClass()) _out, err = capfd.readouterr() - # Unrecognised types now fall back to a deterministic ``type(v).__qualname__`` hash (instead of silently - # disabling fastcache via the old ``[PARAM_INVALID]`` / ``[INVALID_FUNC]`` dead-end), and emit an - # ``[UNKNOWN_TYPE]`` warning once per type so a new tensor-like type added to Quadrants without explicit - # args-hasher handling still gets noticed in the logs. ``[PARAM_INVALID]`` is gone. + # Unrecognised types at a (top-level) kernel-read path now fail fastcache loudly: a one-shot + # ``[UNKNOWN_TYPE]`` warning identifies the offending type, and ``[INVALID_FUNC]`` then reports the + # disabled cache. The old silent ``[PARAM_INVALID]`` dead-end is gone — the two rules driving this + # are documented in ``args_hasher.py::_fail_unknown_type`` and ``fastcache.md`` "Pruning-driven + # argument hashing": (1) only pruned paths may contribute to the cache key (so no qualname fallback), + # (2) unrecognised types at pruned paths must not be silently dropped. assert "[FASTCACHE][UNKNOWN_TYPE]" in err assert RandomClass.__name__ in err + assert "[FASTCACHE][INVALID_FUNC]" in err assert "[FASTCACHE][PARAM_INVALID]" not in err - assert "[FASTCACHE][INVALID_FUNC]" not in err @qd.kernel def not_pure_k1(foo: qd.Template) -> None: diff --git a/tests/python/test_template_typing.py b/tests/python/test_template_typing.py index 69e9ee990b..c4f00a081e 100644 --- a/tests/python/test_template_typing.py +++ b/tests/python/test_template_typing.py @@ -57,24 +57,35 @@ class DataOrientedWithoutFloat: def __init__(self) -> None: self.an_int = 123 self.a_bool = True + self.scratch = qd.ndarray(qd.i32, shape=(1,)) @qd.data_oriented class DataOrientedWithFloat: def __init__(self) -> None: self.an_int = 123 self.a_float = 1.23 + self.scratch = qd.ndarray(qd.i32, shape=(1,)) + # Read the primitive members so the fastcache narrow walk includes them in the hash. Pre-pruning + # the args_hasher walked every member of every container arg blindly; with pruning the kernel must + # actually access ``a.a_float`` for the raise-on-templated-floats guard to fire (the value being + # baked-in only matters when the kernel reads it). @qd.kernel(fastcache=True) - def k1(a: qd.Template) -> None: ... + def k1f(a: qd.Template) -> None: + a.scratch[0] = qd.cast(a.a_float, qd.i32) + + @qd.kernel(fastcache=True) + def k1i(a: qd.Template) -> None: + a.scratch[0] = a.an_int my_do1 = DataOrientedWithoutFloat() - k1(my_do1) + k1i(my_do1) my_do2 = DataOrientedWithFloat() if raise_on_templated_floats: with pytest.raises(ValueError): - k1(my_do2) + k1f(my_do2) else: - k1(my_do2) + k1f(my_do2) @pytest.mark.parametrize("raise_on_templated_floats", [False, True]) From 356394ef5e38ef6454ef5b3f9022c264a1882b97 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 05:28:52 -0700 Subject: [PATCH 22/46] [Test] Pin pruning-driven fastcache behaviour for @qd.data_oriented args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five new tests in test_data_oriented_ndarray.py covering the three rules the args hasher now enforces (see fastcache.md "Pruning-driven argument hashing"): - test_data_oriented_kernel_unused_opaque_member_does_not_affect_cache: rule 1 — two State instances differing only in an uuid member that the kernel never reads share the same compiled artifact across processes. - test_data_oriented_kernel_read_opaque_member_fails_fastcache: rule 2 — when the kernel actually reads an unrecognised-type member, fastcache fails loudly with [UNKNOWN_TYPE]+[INVALID_FUNC]. Kernel still runs. - test_data_oriented_kernel_read_primitive_distinguishes_cache_key: rule 3 — kernel reading a primitive member (s.n baked in) cold-compiles per value and both values load distinct artifacts on warm start. - test_data_oriented_kernel_unread_primitive_does_not_affect_cache: rule 1 mirror for primitives — unused_n differences don't perturb the cache key. - test_data_oriented_qd_func_chain_propagation_distinguishes_cache_key: Pruning.record_after_call propagation through @qd.func(s.dofs) — the inner dofs.x dtype must reach the kernel's pruning set so changes invalidate the cache. --- tests/python/test_data_oriented_ndarray.py | 241 +++++++++++++++++++++ 1 file changed, 241 insertions(+) diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index 7c73a714b4..eb7ff591b3 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -1022,3 +1022,244 @@ def run(s: qd.template()): run(p) np.testing.assert_array_equal(x.to_numpy(), np.arange(10, 10 + N)) + + +# --------------------------------------------------------------------------- +# Pruning-driven fastcache behaviour for @qd.data_oriented containers. +# +# These pin the three rules enforced by the args hasher (see fastcache.md +# "Pruning-driven argument hashing"): +# 1. The cache key may only include contributions from kernel-pruned paths. +# 2. Unrecognised types at kernel-read paths must not be silently dropped. +# 3. Fastcache works for @qd.data_oriented kernel args end-to-end. +# --------------------------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_unused_opaque_member_does_not_affect_cache(tmp_path, monkeypatch): + """Rule 1: kernel-unused opaque members do not affect the fastcache key. + + Two ``State`` instances differ only in an opaque ``uuid`` member that the kernel never reads. + Both must hit the same compiled artifact on the second process — proof that the args hasher's + pruning narrow walk skips the opaque attribute (no qualname-fallback, no spurious miss).""" + import uuid + + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x): + self.x = x + self.uuid = uuid.uuid4() # opaque member, kernel does not read it + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(4): + s.x[i] = s.x[i] + 1 + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,))) + b = State(x=qd.ndarray(qd.i32, shape=(4,))) + run(a) + run(b) + + # Second process: cold-start, must load from disk. If the uuid had leaked into the cache key, + # different uuid → different L2 key → no artifact would load. + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,))) + b = State(x=qd.ndarray(qd.i32, shape=(4,))) + run(a) + run(b) + assert captured[-2] is not None, "first instance should load from disk" + assert captured[-1] is not None, "second instance (different uuid) should ALSO load from disk" + assert run._primal.src_ll_cache_observations.cache_loaded + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_read_opaque_member_fails_fastcache(tmp_path, capfd) -> None: + """Rule 2: when the kernel actually reads an unrecognised-type member, fastcache fails loudly + with [UNKNOWN_TYPE] + [INVALID_FUNC] — no silent drop, no qualname fallback. The kernel still + runs via normal compilation.""" + from quadrants._test_tools import qd_init_same_arch + from quadrants.lang._fast_caching.args_hasher import reset_unknown_type_warn_state + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + reset_unknown_type_warn_state() + + class CustomConfig: + def __init__(self, scale: int) -> None: + self.scale = scale + + @qd.data_oriented + class State: + def __init__(self, x, cfg): + self.x = x + self.cfg = cfg + + x = qd.ndarray(qd.i32, shape=(4,)) + state = State(x=x, cfg=CustomConfig(scale=3)) + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + scale = s.cfg.scale # makes ``__qd_s__qd_cfg`` and ``__qd_s__qd_cfg__qd_scale`` live + for i in range(4): + s.x[i] = i * scale + + run(state) + _out, err = capfd.readouterr() + np.testing.assert_array_equal(x.to_numpy(), np.arange(4) * 3) + + obs = run._primal.src_ll_cache_observations + assert obs.cache_key_generated is False, "unrecognised type at kernel-read path must disable fastcache" + assert "[FASTCACHE][UNKNOWN_TYPE]" in err + assert CustomConfig.__name__ in err + assert "[FASTCACHE][INVALID_FUNC]" in err + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_read_primitive_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Rule 3 (data_oriented works) + pruning correctness: when the kernel reads a primitive member, + its value is baked into the kernel and must drive a distinct cache entry per value. Two State + instances differing only in ``n`` (read by the kernel) cold-compile separately and both load + from disk on the second process.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x, n): + self.x = x + self.n = n # primitive, baked into kernel via ``for i in range(s.n)`` + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(s.n): + s.x[i] = i + s.n + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), n=3) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "different ``n`` → both cold-compile" + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), n=3) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both ``n`` values should load distinct artifacts" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_unread_primitive_does_not_affect_cache(tmp_path, monkeypatch) -> None: + """Rule 1: kernel-unused primitive members do not affect the cache key. Mirror of the opaque + case for primitives. Two State instances differing only in ``unused_n`` must share the cache.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x, unused_n): + self.x = x + self.unused_n = unused_n # kernel never reads this + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(4): + s.x[i] = s.x[i] + 1 + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=99) + run(a) + run(b) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=99) + run(a) + run(b) + assert captured[-2] is not None, "first instance should load from disk" + assert captured[-1] is not None, "second instance (different unused_n) should ALSO load from disk" + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_qd_func_chain_propagation_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Pruning chain propagation through ``@qd.func`` calls (``record_after_call`` extension): + when the kernel calls ``f(self.dofs)`` and ``f`` reads ``s.x``, the kernel's pruning set + must include ``__qd_self__qd_dofs__qd_x`` so that changes to the inner ndarray's dtype + invalidate the cache. Two States differing in ``dofs.x``'s dtype must cold-compile separately.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class Dofs: + def __init__(self, x): + self.x = x + + @qd.data_oriented + class State: + def __init__(self, dofs): + self.dofs = dofs + + @qd.func + def write_dofs(d: qd.template(), v: qd.i32): + d.x[0] = v + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + write_dofs(s.dofs, 7) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(dofs=Dofs(x=qd.ndarray(qd.i32, shape=(4,)))) + b = State(dofs=Dofs(x=qd.ndarray(qd.f32, shape=(4,)))) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "differing dofs.x dtype → both cold-compile" + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(dofs=Dofs(x=qd.ndarray(qd.i32, shape=(4,)))) + b = State(dofs=Dofs(x=qd.ndarray(qd.f32, shape=(4,)))) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both dtypes load distinct artifacts" From 45129bc389e6abfed02859c62980766787a8666d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 05:43:41 -0700 Subject: [PATCH 23/46] [Doc] data_oriented(stable_members=...) docstring: correct the failure-mode note The earlier docstring mentioned a qualname fallback for unrecognised types, which was true at the time but was subsequently removed in the strict-rules refactor. Update the note to match the actual current behaviour: unrecognised types at kernel-read paths fail fastcache loudly with [UNKNOWN_TYPE] + [INVALID_FUNC]. --- python/quadrants/lang/kernel_impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 319a264980..278c451362 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -310,10 +310,10 @@ def data_oriented(cls=None, *, stable_members: bool = False): as a class-level attribute ``_qd_stable_members = True`` (equivalent). Note: this flag is *purely* a launch-time perf hint. It no longer affects fastcache - argument hashing — the cache key is derived from pruning info (the set of flat names - the kernel actually reads), and unrecognised types at kernel-accessed paths fall back - to a deterministic ``type(v).__qualname__`` hash with a one-shot ``[UNKNOWN_TYPE]`` - warning. See ``docs/source/user_guide/fastcache.md``. + argument hashing — the fastcache key is derived from pruning info (the set of flat + names the kernel actually reads), and unrecognised types at kernel-read paths fail + fastcache loudly with a one-shot ``[UNKNOWN_TYPE]`` + ``[INVALID_FUNC]`` diagnostic + (no qualname fallback). See ``docs/source/user_guide/fastcache.md``. Returns: The decorated class (or, when called with arguments, a decorator). From 1f25d9c2e93bff1173bf59c731e7c05ac82271fb Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 07:51:04 -0700 Subject: [PATCH 24/46] [Fix] record_after_call: propagate chain paths through Attribute args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``f(self.dofs)`` was silently dropping the callee's chain-path info at the call boundary because the propagation gate (``not caller_flat.startswith("__qd_")``) excluded any flat path that started with ``__qd_``. But every Attribute chain flattens to a ``__qd_*`` name even when its root is a bare kernel arg (``self.dofs`` → ``__qd_self__qd_dofs``), so the gate was effectively "only propagate through bare-Name args" — exactly the opposite of what ``f(self.dofs)`` needs. Fix: ``_flatten_arg_node`` now returns ``(flat_name, root_name_id)`` and ``record_after_call`` gates on the *root* Name's id. Bare kernel-arg chains (root ``self``) propagate; already-flattened dataclass refs (root ``__qd_self__qd_x``) still skip — those are handled by the existing ``vars_to_unprune`` path and dataclass expansion. Genesis impact: ``RigidSolver.step_1`` calls ``func_update_cartesian_space(static_rigid_sim_config=self._static_rigid_sim_config, ...)`` and the callee reads ``static_rigid_sim_config.para_level`` etc. via ``qd.static``. Without propagation those paths never reached the kernel's pruning set, so the args-hasher walked ``self._static_rigid_sim_config`` as data_oriented and skipped every child (none pruned). Different configs hashed identically, fastcache hit served a stale kernel, and ``test_ndarray_no_compile`` failed with ``z == 0.5`` (physics step ran on an iter-N-compiled kernel against an iter-(N+1) scene). Adds ``test_data_oriented_nested_primitive_via_qd_func_distinguishes_cache_key`` that pins the case explicitly without depending on ndarrays (the existing chain-propagation test passed despite this bug because ndarray paths are tracked via the separate ``_fold_struct_nd_paths_into_pruning`` mechanism). --- python/quadrants/lang/_pruning.py | 54 ++++++++++------- tests/python/test_data_oriented_ndarray.py | 70 ++++++++++++++++++++++ 2 files changed, 104 insertions(+), 20 deletions(-) diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index 088459f9c8..5245741c71 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -10,20 +10,27 @@ from .kernel_arguments import ArgMetadata -def _flatten_arg_node(node: expr) -> str | None: - """Flatten an AST arg node into the corresponding kernel-arg-rooted flat name (or ``None`` if the - node isn't a recognisable name/attribute chain rooted at a plain Name). +def _flatten_arg_node(node: expr) -> tuple[str, str] | None: + """Flatten an AST arg node into ``(flat_name, root_name_id)`` (or ``None`` if the node isn't a + recognisable name/attribute chain rooted at a plain Name). + + Returns both the full flat name (e.g. ``__qd_self__qd_dofs`` for ``self.dofs``) and the root + Name's id (``self``). Callers use the root id to distinguish kernel-arg-rooted chains + (``self.dofs`` → root ``self``) from already-flattened dataclass-arg references + (``__qd_self__qd_dofs`` → root ``__qd_self__qd_dofs``). The flat path alone is ambiguous because + ``__qd_self__qd_dofs`` could be either an attribute chain *or* a single flattened Name. Mirrors ``FlattenAttributeNameTransformer._flatten_attribute_name`` but on the raw call-arg AST. Used by ``record_after_call`` to handle ``f(self.dofs)`` etc. — without this the callee's pruning info for attribute-chain args is dropped at the call boundary.""" if isinstance(node, Name): - return node.id + return node.id, node.id if isinstance(node, Attribute): parent = _flatten_arg_node(node.value) if parent is None: return None - return create_flat_name(parent, node.attr) + parent_flat, root_id = parent + return create_flat_name(parent_flat, node.attr), root_id return None @@ -163,15 +170,20 @@ def record_after_call( callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) - # NEW: propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) - # AND through plain-Name args of non-flattened types (``f(self)``). These flow into the - # caller's separate chain-paths set, not ``used_vars`` — see the field-level docstring. - caller_flat = _flatten_arg_node(arg) - if caller_flat is not None and not caller_flat.startswith("__qd_"): - callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore - self._propagate_chain_paths( - callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate - ) + # Propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) + # AND through plain-Name args of non-flattened types (``f(self)``). Gate on the *root* + # Name id, not the resulting flat string: ``self.dofs`` flattens to ``__qd_self__qd_dofs`` + # (which starts with ``__qd_``) but its root is the bare kernel arg ``self`` — we still + # need to propagate. Already-flattened dataclass refs like ``Name('__qd_self__qd_dofs')`` + # have a ``__qd_*`` root and are handled by the ``vars_to_unprune`` path above. + flat = _flatten_arg_node(arg) + if flat is not None: + caller_flat, root_id = flat + if not root_id.startswith("__qd_"): + callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 # Note that our own arg_metas ordering will in general NOT match that of the child's. That's # because our ordering is based on the order in which we pass arguments to the function, but the @@ -185,12 +197,14 @@ def record_after_call( callee_param_name = kwarg.arg if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) - caller_flat = _flatten_arg_node(kwarg.value) - if caller_flat is not None and not caller_flat.startswith("__qd_"): - callee_param_name = kwarg.arg - self._propagate_chain_paths( - callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate - ) + flat = _flatten_arg_node(kwarg.value) + if flat is not None: + caller_flat, root_id = flat + if not root_id.startswith("__qd_"): + callee_param_name = kwarg.arg + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 self.used_vars_by_func_id[my_func_id].update(vars_to_unprune) self.kernel_arg_chain_paths_by_func_id[my_func_id].update(chain_paths_to_propagate) diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index eb7ff591b3..aec55ddc71 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -1263,3 +1263,73 @@ def run(s: qd.template()): run(a) run(b) assert captured[-2] is not None and captured[-1] is not None, "both dtypes load distinct artifacts" + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_nested_primitive_via_qd_func_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested + data_oriented containers. + + Regression test for a bug where ``record_after_call`` skipped chain-path propagation whenever the + caller-side arg flattened to a ``__qd_*``-prefixed name (which Attribute chains always do — + ``self.cfg`` → ``__qd_self__qd_cfg``). When that happened, primitive members read inside the + callee (``cfg.n`` → ``__qd_cfg__qd_n`` in the callee's chain set) never made it into the kernel's + pruning set, so the args-hasher walked ``self.cfg`` as data_oriented and found no pruned children, + yielding an identical hash for *any* value of ``cfg.n``. Two configs that should produce + different kernels (different ``range(s.cfg.n)`` trip counts baked into codegen) would then share + a fastcache entry — leading to stale-kernel hits and silent miscompiles (e.g. Genesis' + ``test_ndarray_no_compile`` was failing with iter-N kernels reused for iter-N+1 scenes that have + a different ``RigidSimStaticConfig.para_level`` baked into their ``qd.static`` branches). + + The fix in ``_pruning.py`` gates propagation on the *root Name* of the chain (``self``, not the + flat result), so both ``f(self)`` and ``f(self.cfg)`` propagate, while already-flattened + dataclass refs (``Name('__qd_state__qd_x')``) are still skipped.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class Cfg: + def __init__(self, n): + self.n = n # primitive read by ``write_x`` — drives codegen via ``range(c.n)`` + + @qd.data_oriented + class State: + def __init__(self, x, cfg): + self.x = x + self.cfg = cfg + + @qd.func + def write_x(x: qd.template(), c: qd.template()): + for i in range(c.n): + x[i] = i + c.n + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + write_x(s.x, s.cfg) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=2)) + b = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=3)) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "different cfg.n → both cold-compile" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=2)) + b = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=3)) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both cfg.n values load distinct artifacts" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) From 5fc9b4c4cfaf1bad5384784ccf75dc9e927b71d5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 09:37:14 -0700 Subject: [PATCH 25/46] [Fix] Track @qd.func params in fn_param_names for chain-path seeding Without this, attribute chains rooted at a @qd.func param (e.g. static_rigid_sim_config.para_level inside a qd.func) were not recorded in pruning, causing args-hasher to skip kernel-read primitive members of nested data_oriented containers and producing stale fastcache hits. - Add ASTTransformerFuncContext.fn_param_names - _transform_func_arg adds bare argument name to fn_param_names - build_Name seeds _qd_arg_chain for nodes in fn_param_names too --- python/quadrants/lang/ast/ast_transformer.py | 19 +++++++++++++------ .../lang/ast/ast_transformer_utils.py | 12 ++++++++++++ .../function_def_transformer.py | 10 ++++++++++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index b911bd7653..8836433714 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -86,12 +86,19 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name): pruning = ctx.global_context.pruning if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters and node.id.startswith("__qd_"): ctx.global_context.pruning.mark_used(ctx.func.func_id, node.id) - # Track chains rooted at non-flattened kernel args (``@qd.data_oriented`` / ``qd.template`` params, which - # appear in the AST with bare names like ``self``). ``build_Attribute`` propagates this annotation through - # ``state.dofs.x`` chains and ``mark_used``s the flat name, so the fastcache narrow walk can include them - # in pruning (dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as - # already-flat ``__qd_…`` Names, handled by the block above). - if node.id in ctx.kernel_args and not node.id.startswith("__qd_"): + # Track chains rooted at non-flattened parameter names: top-level ``@qd.kernel`` args + # (``ctx.kernel_args``) and ``@qd.func`` params (``ctx.fn_param_names``). Both appear in the + # AST as bare names (``self`` for a data_oriented kernel arg; ``static_rigid_sim_config`` for + # a ``qd.template()`` func arg bound to a ``@qd.data_oriented`` instance). + # ``build_Attribute`` propagates this annotation through ``state.dofs.x`` chains and + # ``mark_kernel_arg_chain_used``-s the flat name. The kernel's pruning narrow walk picks them + # up directly (kernel case) or after ``record_after_call`` propagates the callee's func-arg + # chains back through the call boundary (func case): e.g. ``func(s=self._sub)`` where ``func`` + # reads ``s.x`` ends up with ``__qd_self__qd__sub__qd_x`` recorded in the kernel's pruning, + # so the args-hasher hashes that primitive value into the fastcache key. + # Dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as + # already-flat ``__qd_…`` Names, handled by the block above via ``mark_used``. + if not node.id.startswith("__qd_") and (node.id in ctx.kernel_args or node.id in ctx.fn_param_names): node._qd_arg_chain = node.id # type: ignore[attr-defined] else: node._qd_arg_chain = None # type: ignore[attr-defined] diff --git a/python/quadrants/lang/ast/ast_transformer_utils.py b/python/quadrants/lang/ast/ast_transformer_utils.py index 506778c683..5e0af0ea55 100644 --- a/python/quadrants/lang/ast/ast_transformer_utils.py +++ b/python/quadrants/lang/ast/ast_transformer_utils.py @@ -247,6 +247,18 @@ def __init__( self.visited_funcdef = False self.is_real_function = is_real_function self.kernel_args: list = [] + # Names of the bare (non-flattened) parameters of a ``@qd.func`` being processed. Used by + # ``build_Name`` to seed ``_qd_arg_chain`` for attribute accesses rooted at a func param + # (e.g. ``static_rigid_sim_config.para_level`` where ``static_rigid_sim_config`` is a + # ``qd.template()`` arg bound to a ``@qd.data_oriented`` instance). Without this, chains + # rooted at func params would not be recorded in pruning, and the args-hasher would skip + # over kernel-read primitive members of nested data_oriented containers — leading to stale + # fastcache hits when those members change between calls. + # ``kernel_args`` only tracks top-level ``@qd.kernel`` args; ``_transform_func_arg`` for a + # ``@qd.func`` does not append to it (see function_def_transformer.py). This separate set + # avoids piggy-backing on ``kernel_args`` so the existing "kernel arg is immutable" + # diagnostic in ``build_assign_annotated`` doesn't start firing for func params. + self.fn_param_names: set[str] = set() self.only_parse_function_def: bool = False self.autodiff_mode = autodiff_mode self.loop_depth: int = 0 diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 1810c086c4..8a88dcd2e5 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -341,6 +341,16 @@ def _transform_func_arg( argument_type: Any, data: Any, ) -> None: + # Record the bare (non-flattened) func param name so ``build_Name`` can seed ``_qd_arg_chain`` + # for attribute accesses rooted at this param. Critical for ``qd.template()`` args bound to + # ``@qd.data_oriented`` instances (e.g. ``static_rigid_sim_config.para_level`` inside a + # ``@qd.func``): without this, the kernel's pruning set never learns about ``.para_level``, + # the args-hasher skips the value, and different ``para_level`` configurations collide in the + # fastcache key. Flat names starting with ``__qd_`` arrive here too via the dataclass-flatten + # recursion below; they're harmless to add (``build_Name``'s chain branch gates on + # ``not node.id.startswith("__qd_")``) but the bare-name entries are what enables propagation. + ctx.fn_param_names.add(argument_name) + # Template arguments are passed by reference. if isinstance(argument_type, annotations.template): ctx.create_variable(argument_name, data) From 710ee4705adf32383b68c21e083f237d82d5f6c6 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 13:46:03 -0700 Subject: [PATCH 26/46] [Fix] Fastcache: prune _predeclare_struct_ndarrays by flat-name on cache hit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pass 0 is skipped on a fastcache hit, leaving the `id(nd)`-keyed `pruning.used_struct_ndarray_ids` set empty. The previous fallback ("register every reachable ndarray") quietly broke physics on cache hits — the kernel was compiled with N ndarray slots but launch re-registered every reachable ndarray (~4x for Genesis's rigid solver), scrambling arg-slot bindings. Use the *flat-name* form of the kernel-accessed paths (`pruning.used_vars_by_func_id[KERNEL_FUNC_ID]`, reseeded from the cached `used_py_dataclass_parameters` which already contains every leaf folded in by `_fold_struct_nd_paths_into_pruning`) to gate registration when pass 0 didn't run. This reproduces the exact ndarray set the originating compile produced. Test pin: `tests/test_quadrants.py::test_ndarray_no_compile` (cpu + gpu parameterizations) was silently failing on the second iteration with wrong physics (`z=[0.5,0.5]` instead of `0.2`); now passes. --- .../function_def_transformer.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 8a88dcd2e5..dda962d5f9 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -243,14 +243,28 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: """ from quadrants.lang.util import cook_dtype # pylint: disable=C0415 + from quadrants.lang._pruning import Pruning # pylint: disable=C0415 + cache = ctx.global_context.ndarray_to_any_array launch_info = ctx.global_context.struct_ndarray_launch_info pruning = ctx.global_context.pruning used_ids = getattr(pruning, "used_struct_ndarray_ids", None) # Only prune on the enforcing pass when we actually ran pass 0 to populate the - # used-ndarray set. On a fastcache hit pass 0 is skipped and the set is empty — - # fall back to registering every reachable ndarray. + # used-ndarray set. On a fastcache hit pass 0 is skipped and the set is empty. prune = pruning.enforcing and used_ids is not None and getattr(pruning, "pass_0_ran", False) + # On a fastcache hit (enforcing without a pass-0 run), the `id(nd)` set is empty, but the + # *flat-name* set on ``used_vars_by_func_id[KERNEL_FUNC_ID]`` was loaded from cache and + # already contains every kernel-accessed leaf path (folded in by + # ``_fold_struct_nd_paths_into_pruning`` during the compile that produced the cache entry). + # Use that to prune the walk so we register the exact same ndarray set as the originating + # compile produced — without this, every reachable ndarray gets registered, the kernel's + # arg slots get rebound to the wrong ndarrays at launch, and physics silently breaks. + prune_from_flat_names = pruning.enforcing and not getattr(pruning, "pass_0_ran", False) + kernel_used_flat_names = ( + pruning.used_vars_by_func_id.get(Pruning.KERNEL_FUNC_ID, set()) + if prune_from_flat_names + else None + ) # Cycle-safe walker: Genesis object graphs have cross-references (e.g. solver <-> scene <-> sim) so we must # avoid re-entering the same node. ``seen`` is shared across the whole arg's traversal — ``id(obj)`` is @@ -289,6 +303,19 @@ def _register_ndarray(nd, arg_idx, attr_chain): return if prune and key not in used_ids: return + if prune_from_flat_names: + # Build the leaf flat name (e.g. ``__qd_self__qd__collider_state__qd_active_buffer``) + # and skip registration when the kernel's cached pruning set doesn't contain it. + if arg_idx < 0 or arg_idx >= len(ctx.func.arg_metas): + return + arg_name = ctx.func.arg_metas[arg_idx].name + if not arg_name: + return + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + if flat not in kernel_used_flat_names: + return from quadrants._lib import core as _qd_core # pylint: disable=C0415 element_type = cook_dtype(nd.element_type) From 090f1a8d81530e70c3d3cde798c5de58ded09247 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 16:21:55 -0700 Subject: [PATCH 27/46] [CI] Fix linters, pyright, MockContext test, deleted-comment, line-wrap CI fixes for PR #705: - Linters (black + ruff): collapse 3 wrapped call lines (autoformat) and sort imports in ``src_hasher.py`` and ``function_def_transformer.py``. - ``test_process_func_arg``: add ``fn_param_names: set[str]`` to ``MockContext`` mirroring the real ``ASTTransformerFuncContext`` so ``_transform_func_arg`` no longer ``AttributeError``s when recording bare param names. This was crashing 18 test parameterisations across Linux/Mac/CUDA/Vulkan/AMD CI. - Pyright: guard ``kwarg.arg`` ``None`` (double-star unpack) in ``_pruning.record_after_call``; narrow ``_arg_nd_paths_or_none.get``'s sentinel-default in ``_template_mapper`` via explicit ``cast``; handle ``Ndarray.shape is None`` in ``_template_mapper_hotpath`` by skipping the entry (uninitialised ``_physical_shape`` has no spec contribution). - Restore the maintenance-constraint comment above the ``[FASTCACHE][INVALID_FUNC]`` warning so it stays in sync with the ``test_src_ll_cache`` assertion (Cursor deleted-comment check). - Reflow 3 comment/docstring lines to use the project's 120c width. --- python/quadrants/lang/_fast_caching/src_hasher.py | 7 +++---- python/quadrants/lang/_pruning.py | 13 ++++++++----- python/quadrants/lang/_template_mapper.py | 9 ++++++--- python/quadrants/lang/_template_mapper_hotpath.py | 7 ++++++- .../ast_transformers/function_def_transformer.py | 7 ++----- .../lang/ast/test_function_def_transformer.py | 3 +++ .../quadrants/lang/fast_caching/test_src_hasher.py | 4 +--- tests/python/test_data_oriented_ndarray.py | 14 +++++++------- 8 files changed, 36 insertions(+), 28 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index a14a1a7d1b..2d93245397 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -61,7 +61,6 @@ from .hash_utils import hash_iterable_strings from .python_side_cache import PythonSideCache - # Prefix bytes mixed into L1 / L2 keys so they cannot collide even if the underlying inputs happen to hash to # the same string. The original single-level cache key (kept for backward-compat reads via ``load`` below) had # no such prefix; the new two-level scheme uses ``l1:`` and ``l2:`` markers so old single-level entries from @@ -110,11 +109,11 @@ def compute_narrow_args_hash( Returns ``None`` if a recognised-but-unsupported tensor-like type forces fastcache off — the caller emits the appropriate user-visible diagnostic via the ``FastcacheSkip.WARN`` branch. """ - args_hash = args_hasher.hash_args( - raise_on_templated_floats, args, arg_metas, pruning_paths=pruning_paths - ) + args_hash = args_hasher.hash_args(raise_on_templated_floats, args, arg_metas, pruning_paths=pruning_paths) if isinstance(args_hash, FastcacheSkip): if args_hash is FastcacheSkip.WARN: + # the bit in caps at start should not be modified without modifying corresponding text + # freetext bit can be freely modified _logging.warn( f"[FASTCACHE][INVALID_FUNC] The pure function {kernel_source_info.function_name} could not be " "fast cached, because one or more parameter types were invalid" diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index 5245741c71..c17fe64d2d 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -76,8 +76,8 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: # therefore unreliable — in that case ``_predeclare_struct_ndarrays`` falls back to # registering every reachable ndarray (same as the historical behavior). self.pass_0_ran: bool = False - # Kernel-arg-rooted attribute chains used by each func, in flat-name form - # (``__qd_self__qd_dofs__qd_x``). Populated by ``ASTTransformer.build_Attribute`` + # Kernel-arg-rooted attribute chains used by each func, in flat-name form (``__qd_self__qd_dofs__qd_x``). + # Populated by ``ASTTransformer.build_Attribute`` # for non-flattened kernel args (data_oriented / qd.template). Kept *separate* from # ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on the enforcing # pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite @@ -202,9 +202,12 @@ def record_after_call( caller_flat, root_id = flat if not root_id.startswith("__qd_"): callee_param_name = kwarg.arg - self._propagate_chain_paths( - callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate - ) + # ``kwarg.arg`` is ``None`` for double-star unpacking (``**kwargs``); + # chain propagation requires a concrete parameter name so just skip. + if callee_param_name is not None: + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 self.used_vars_by_func_id[my_func_id].update(vars_to_unprune) self.kernel_arg_chain_paths_by_func_id[my_func_id].update(chain_paths_to_propagate) diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index ffb384cc18..80fc320a6d 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, TypeAlias +from typing import Any, TypeAlias, cast from weakref import ReferenceType from quadrants.lang import impl @@ -121,10 +121,13 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl for i in self.template_slot_locations: arg = args[i] cls = type(arg) - paths = _arg_nd_paths_or_none.get(cls, _UNCLASSIFIED) - if paths is _UNCLASSIFIED: + cached = _arg_nd_paths_or_none.get(cls, _UNCLASSIFIED) + if cached is _UNCLASSIFIED: paths = _classify_for_args_hash(arg) _arg_nd_paths_or_none[cls] = paths + else: + # Narrow the ``object`` sentinel union back to the actual cached value type. + paths = cast("list[tuple] | None", cached) if paths is None: continue for chain in paths: diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index b5621ea0f1..e00da3fa7a 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -159,9 +159,14 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: v = v._unwrap() if not isinstance(v, Ndarray): continue + # ``Ndarray.shape`` can legitimately be ``None`` (uninitialised ``_physical_shape``); such an instance + # has no meaningful spec contribution, so skip it rather than crashing on ``len(None)``. + shape = v.shape + if shape is None: + continue type_id = id(v.element_type) element_type = type_id if type_id in primitive_types.type_ids else v.element_type - out.append((".".join(chain), element_type, len(v.shape), v.grad is not None, v._qd_layout)) + out.append((".".join(chain), element_type, len(shape), v.grad is not None, v._qd_layout)) def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: AnnotationType, arg_name: str) -> Any: diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index dda962d5f9..d1412f4d00 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -241,9 +241,8 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: pass needs every reachable ndarray in the cache for ``build_Attribute`` to resolve the accesses that *will* populate the set). """ - from quadrants.lang.util import cook_dtype # pylint: disable=C0415 - from quadrants.lang._pruning import Pruning # pylint: disable=C0415 + from quadrants.lang.util import cook_dtype # pylint: disable=C0415 cache = ctx.global_context.ndarray_to_any_array launch_info = ctx.global_context.struct_ndarray_launch_info @@ -261,9 +260,7 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: # arg slots get rebound to the wrong ndarrays at launch, and physics silently breaks. prune_from_flat_names = pruning.enforcing and not getattr(pruning, "pass_0_ran", False) kernel_used_flat_names = ( - pruning.used_vars_by_func_id.get(Pruning.KERNEL_FUNC_ID, set()) - if prune_from_flat_names - else None + pruning.used_vars_by_func_id.get(Pruning.KERNEL_FUNC_ID, set()) if prune_from_flat_names else None ) # Cycle-safe walker: Genesis object graphs have cross-references (e.g. solver <-> scene <-> sim) so we must diff --git a/tests/python/quadrants/lang/ast/test_function_def_transformer.py b/tests/python/quadrants/lang/ast/test_function_def_transformer.py index a46d5e2cbc..1408c99014 100644 --- a/tests/python/quadrants/lang/ast/test_function_def_transformer.py +++ b/tests/python/quadrants/lang/ast/test_function_def_transformer.py @@ -81,6 +81,9 @@ def test_process_func_arg(argument_name: str, argument_type: Any, expected_varia class MockContext: def __init__(self) -> None: self.variables: dict[str, Any] = {} + # Mirror the real ``ASTTransformerFuncContext.fn_param_names`` so + # ``_transform_func_arg`` can record bare param names without crashing. + self.fn_param_names: set[str] = set() def create_variable(self, name: str, data: Any) -> None: assert name not in self.variables diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index bd0d81176e..9a0fcfd271 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -112,9 +112,7 @@ def get_fileinfos(functions: list[Callable]) -> list[_wrap_inspect.FunctionSourc fileinfos = get_fileinfos([mod.f1.fn, mod.f2.fn]) # L2 key: source+config (L1) + narrow-args-hash. Use an empty narrow-args-hash since the test isn't # exercising args at all — it tests the helper-source-change invalidation logic, which lives in L2. - fast_cache_key = src_hasher.make_full_cache_key( - src_hasher.make_source_config_key(kernel_info), narrow_args_hash="" - ) + fast_cache_key = src_hasher.make_full_cache_key(src_hasher.make_source_config_key(kernel_info), narrow_args_hash="") assert fast_cache_key is not None diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index aec55ddc71..fb5717d924 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -902,10 +902,10 @@ def run(s: qd.template()): def test_is_data_oriented_safe_on_pydantic_like_metaclass(): - """``is_data_oriented`` must not invoke ``__getattr__`` on the class (or metaclass), - so it stays safe in the presence of pathological metaclasses whose ``__getattr__`` - blows the Python recursion limit on arbitrary attribute lookups (e.g. Pydantic's - ``ModelMetaclass`` when probed for a name not in its private-attrs cache).""" + """``is_data_oriented`` must not invoke ``__getattr__`` on the class (or metaclass), so it stays safe in the + presence of pathological metaclasses whose ``__getattr__`` blows the Python recursion limit on arbitrary + attribute lookups (e.g. Pydantic's ``ModelMetaclass`` when probed for a name not in its private-attrs cache). + """ from quadrants.lang.util import is_data_oriented @@ -922,9 +922,9 @@ class Pathological(metaclass=RecursingMeta): @test_utils.test(arch=qd.cpu) def test_data_oriented_with_pydantic_like_child(): - """A ``@qd.data_oriented`` class holding a child whose metaclass has the recursing - ``__getattr__`` (Pydantic-style). Walker must classify the child as non-data-oriented - and continue without blowing the stack.""" + """A ``@qd.data_oriented`` class holding a child whose metaclass has the recursing ``__getattr__`` + (Pydantic-style). Walker must classify the child as non-data-oriented and continue without blowing the stack. + """ N = 4 class RecursingMeta(type): From be4b030fe8ae1e33a8ba151068944fb1a981a81e Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 17:59:46 -0700 Subject: [PATCH 28/46] [Refactor] Move fold_*_into_pruning from Kernel to Pruning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moves ``_fold_struct_nd_paths_into_pruning`` and ``_fold_kernel_arg_chain_paths_into_pruning`` off the ``Kernel`` class onto the ``Pruning`` class in ``_pruning.py`` (as ``fold_struct_nd_paths`` and ``fold_kernel_arg_chain_paths``). Both methods operate exclusively on ``Pruning`` internals — the first only used ``self._struct_ndarray_launch_info_by_key`` and ``self.arg_metas`` (both trivially passed in), the second was already a ``@staticmethod`` whose only parameter was a ``Pruning``. Putting the merge step next to the accumulation primitives (``mark_used``, ``mark_kernel_arg_chain_used``, ``_propagate_chain_paths``) keeps the pruning pipeline in one module. Pure refactor, no behaviour change. --- python/quadrants/lang/_pruning.py | 62 +++++++++++++++++++++++++++++ python/quadrants/lang/kernel.py | 66 +------------------------------ 2 files changed, 64 insertions(+), 64 deletions(-) diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index c17fe64d2d..465a768d08 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -102,6 +102,68 @@ def mark_kernel_arg_chain_used(self, func_id: int, chain_flat_name: str) -> None assert not self.enforcing self.kernel_arg_chain_paths_by_func_id[func_id].add(chain_flat_name) + def fold_struct_nd_paths( + self, struct_ndarray_launch_info: list[tuple[Any, int, tuple[str, ...]]], arg_metas: list[ArgMetadata] + ) -> None: + """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat + name set so ``args_hasher.hash_args`` narrow-walks them correctly. + + Background: ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat names + produced by ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* + args. ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` + in the AST and don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths + *are* recorded — in ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — + but only for ndarray members. + + Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd___qd___qd_…`` + and union all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the + same convention used for dataclass args. + + Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at + compile, opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk + cannot distinguish "kernel reads this primitive" from "kernel does not read this primitive". The + ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking + *all* attrs of a data_oriented container — narrowing only suppresses subtrees explicitly absent + from the pruning set. So for a data_oriented arg with mostly-ndarray members, the cache key + correctly depends on the ndarray paths it uses; for one with primitive members whose values + matter, those members are still folded into the hash (qualname-fallback / value paths). + """ + if not struct_ndarray_launch_info: + return + kernel_used: set[str] = self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID] + for _arg_id_cpp, arg_idx, attr_chain in struct_ndarray_launch_info: + if arg_idx < 0 or arg_idx >= len(arg_metas): + continue + arg_name = arg_metas[arg_idx].name + if not arg_name: + continue + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + kernel_used.add(flat) + + def fold_kernel_arg_chain_paths(self) -> None: + """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both + compile passes have completed. + + Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain + (e.g. ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into ``kernel_arg_chain_paths_by_func_id`` + rather than ``used_vars_by_func_id``, because the latter is read on the enforcing pass to build + ``struct_locals`` for ``FlattenAttributeNameTransformer``. If chain names appeared there, the + transformer would rewrite ``self.n`` into ``Name('__qd_self__qd_n')`` and ``build_Name`` would + fail to find such a variable. + + Doing the merge here — after pass 1, just like ``fold_struct_nd_paths`` — avoids that interaction + while still making the chain paths available to the fastcache args-hash narrow walk. The set on + ``used_py_dataclass_parameters_by_key_enforcing[key]`` is the *same* object as + ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so updating + one updates both. + """ + kernel_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(Pruning.KERNEL_FUNC_ID) + if not kernel_chain_paths: + return + self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_chain_paths) + @staticmethod def _propagate_chain_paths( callee_chain_paths: set[str], diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 23e679a315..880796d905 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -60,7 +60,6 @@ def _kernel_coverage_enabled() -> bool: from quadrants.types.enums import AutodiffMode from quadrants.types.utils import is_signed -from ._dataclass_util import create_flat_name from ._func_base import FuncBase from ._kernel_types import ( ArgsHash, @@ -488,7 +487,7 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . # args walks nothing — every (arg_idx, attr_chain) pair gets the same hash regardless # of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the # ``test_data_oriented_ndarray_fastcache_dtype_key_distinct`` pin caught this). - self._fold_struct_nd_paths_into_pruning(key, pruning) + pruning.fold_struct_nd_paths(self._struct_ndarray_launch_info_by_key.get(key, []), self.arg_metas) # Fold non-ndarray kernel-arg-rooted chain paths (primitives, opaque members, nested # struct paths) collected by ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` # tracking. Kept separate from ``used_vars_by_func_id`` during compile (would otherwise @@ -496,7 +495,7 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . # ``Pruning.kernel_arg_chain_paths_by_func_id``. This fold + the existing ``used_vars`` # assignment to ``used_py_dataclass_parameters_by_key_enforcing`` share the same set # by reference, so the final fastcache L1 entry sees all kernel-accessed paths. - self._fold_kernel_arg_chain_paths_into_pruning(pruning) + pruning.fold_kernel_arg_chain_paths() else: for used_parameters in pruning.used_vars_by_func_id.values(): new_used_parameters = set() @@ -517,67 +516,6 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . # Post-compile fastcache bookkeeping. See ``_maybe_persist_l1_and_set_l2_key`` docstring. self._maybe_persist_l1_and_set_l2_key(key, py_args) - def _fold_struct_nd_paths_into_pruning(self, key: "CompiledKernelKeyType", pruning: Pruning) -> None: - """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat - name set so ``args_hasher.hash_args`` narrow-walks them correctly. - - Background: pruning's ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat - names produced by ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* - args. ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` in - the AST and don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths *are* - recorded — in ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — but only - for ndarray members. - - Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd___qd___qd_…`` - and union all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the - same convention used for dataclass args. - - Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at - compile, opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk - cannot distinguish "kernel reads this primitive" from "kernel does not read this primitive". The - ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking *all* - attrs of a data_oriented container — narrowing only suppresses subtrees explicitly absent from the - pruning set. So for a data_oriented arg with mostly-ndarray members, the cache key correctly - depends on the ndarray paths it uses; for one with primitive members whose values matter, those - members are still folded into the hash (qualname-fallback / value paths). - """ - nd_info = self._struct_ndarray_launch_info_by_key.get(key) - if not nd_info: - return - kernel_used: set[str] = pruning.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID] - for _arg_id_cpp, arg_idx, attr_chain in nd_info: - if arg_idx < 0 or arg_idx >= len(self.arg_metas): - continue - arg_name = self.arg_metas[arg_idx].name - if not arg_name: - continue - flat = arg_name - for attr in attr_chain: - flat = create_flat_name(flat, attr) - kernel_used.add(flat) - - @staticmethod - def _fold_kernel_arg_chain_paths_into_pruning(pruning: Pruning) -> None: - """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both - compile passes have completed. - - Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain - (e.g. ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into - ``pruning.kernel_arg_chain_paths_by_func_id`` rather than ``used_vars_by_func_id``, because the - latter is read on the enforcing pass to build ``struct_locals`` for - ``FlattenAttributeNameTransformer``. If chain names appeared there, the transformer would rewrite - ``self.n`` into ``Name('__qd_self__qd_n')`` and ``build_Name`` would fail to find such a variable. - - Doing the merge here — after pass 1, just like ``_fold_struct_nd_paths_into_pruning`` — - avoids that interaction while still making the chain paths available to the fastcache args-hash - narrow walk. The set on ``used_py_dataclass_parameters_by_key_enforcing[key]`` is the *same* - object as ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so - updating one updates both.""" - kernel_chain_paths = pruning.kernel_arg_chain_paths_by_func_id.get(Pruning.KERNEL_FUNC_ID) - if not kernel_chain_paths: - return - pruning.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_chain_paths) - def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...]) -> None: """After a successful materialize, persist L1 (if missing) and set ``fast_checksum`` to the L2 key. From 75c08f61b4805f8ea94674df1952d6ef155649b2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 19:05:34 -0700 Subject: [PATCH 29/46] [Style] Reflow 3 docstring paragraphs to 120c (Check line wrapping) Cursor's line-wrap check flagged three new docstring paragraphs wrapping at ~85-95c instead of the project's 120c width: - ``ast_transformer.py:_promote_ndarray_if_declared`` - ``ast_transformers/function_def_transformer.py:_predeclare_struct_ndarrays`` - ``kernel_impl.py:data_oriented`` ``stable_members`` docstring Pure prose reflow, no semantic change. --- python/quadrants/lang/ast/ast_transformer.py | 12 ++++----- .../function_def_transformer.py | 14 +++++----- python/quadrants/lang/kernel_impl.py | 26 +++++++++---------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 8836433714..469c0725c2 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -675,13 +675,11 @@ def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> """If *value* is a bare ``Ndarray`` that was pre-declared as a kernel arg (in ``_predeclare_struct_ndarrays``), return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged. - Also records the source ndarray id in ``pruning.used_struct_ndarray_ids`` on the - non-enforcing first pass, so that the enforcing second-pass - ``_predeclare_struct_ndarrays`` can skip ndarrays that the kernel never actually - accesses. Both ``Ndarray`` instances and pre-existing ``AnyArray`` proxies (tagged - with ``_qd_source_ndarray_id``) are handled — the latter is the case for accesses - in inlined ``@qd.func`` bodies whose params were bound to already-promoted proxies - by Option A in ``call_transformer``. + Also records the source ndarray id in ``pruning.used_struct_ndarray_ids`` on the non-enforcing first pass, so + that the enforcing second-pass ``_predeclare_struct_ndarrays`` can skip ndarrays that the kernel never actually + accesses. Both ``Ndarray`` instances and pre-existing ``AnyArray`` proxies (tagged with + ``_qd_source_ndarray_id``) are handled — the latter is the case for accesses in inlined ``@qd.func`` bodies + whose params were bound to already-promoted proxies by Option A in ``call_transformer``. """ from quadrants.lang._ndarray import Ndarray # pylint: disable=C0415 diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index d1412f4d00..2f37ef8c28 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -232,14 +232,12 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the launch context. - Pruning: in the enforcing (second) compile pass, ``pruning.used_struct_ndarray_ids`` - contains the set of ``id(ndarray)`` values that ``_promote_ndarray_if_declared`` - observed being accessed during the first pass (directly in the kernel body, or - transitively through ``@qd.func`` inlining). We register only those, dropping every - unused ndarray from the kernel's parameter list. On the first pass the set is empty - / not yet populated, so we register everything as today (correctness: the first - pass needs every reachable ndarray in the cache for ``build_Attribute`` to resolve - the accesses that *will* populate the set). + Pruning: in the enforcing (second) compile pass, ``pruning.used_struct_ndarray_ids`` contains the set of + ``id(ndarray)`` values that ``_promote_ndarray_if_declared`` observed being accessed during the first pass + (directly in the kernel body, or transitively through ``@qd.func`` inlining). We register only those, dropping + every unused ndarray from the kernel's parameter list. On the first pass the set is empty / not yet populated, + so we register everything as today (correctness: the first pass needs every reachable ndarray in the cache for + ``build_Attribute`` to resolve the accesses that *will* populate the set). """ from quadrants.lang._pruning import Pruning # pylint: disable=C0415 from quadrants.lang.util import cook_dtype # pylint: disable=C0415 diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 278c451362..ae31f51c75 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -300,20 +300,18 @@ def data_oriented(cls=None, *, stable_members: bool = False): Args: cls (Class): the class to be decorated. - stable_members (bool): launch-context perf hint — if ``True``, declares that the class's - ndarray-typed members are allocated once and never reassigned between kernel calls. - Quadrants will skip the per-call ndarray-reference walk that ``Kernel.launch_kernel`` - uses to detect ndarray reassignment on mutable containers (~1-2 us/call savings on - Genesis-style containers with dozens of ndarray attrs). Reassigning a member on a - ``stable_members`` class is undefined behaviour — the previously-compiled kernel will - be reused even if the new ndarray has different dtype/ndim/layout. May also be set - as a class-level attribute ``_qd_stable_members = True`` (equivalent). - - Note: this flag is *purely* a launch-time perf hint. It no longer affects fastcache - argument hashing — the fastcache key is derived from pruning info (the set of flat - names the kernel actually reads), and unrecognised types at kernel-read paths fail - fastcache loudly with a one-shot ``[UNKNOWN_TYPE]`` + ``[INVALID_FUNC]`` diagnostic - (no qualname fallback). See ``docs/source/user_guide/fastcache.md``. + stable_members (bool): launch-context perf hint — if ``True``, declares that the class's ndarray-typed members + are allocated once and never reassigned between kernel calls. Quadrants will skip the per-call ndarray- + reference walk that ``Kernel.launch_kernel`` uses to detect ndarray reassignment on mutable containers + (~1-2 us/call savings on Genesis-style containers with dozens of ndarray attrs). Reassigning a member on a + ``stable_members`` class is undefined behaviour — the previously-compiled kernel will be reused even if + the new ndarray has different dtype/ndim/layout. May also be set as a class-level attribute + ``_qd_stable_members = True`` (equivalent). + + Note: this flag is *purely* a launch-time perf hint. It no longer affects fastcache argument hashing — the + fastcache key is derived from pruning info (the set of flat names the kernel actually reads), and + unrecognised types at kernel-read paths fail fastcache loudly with a one-shot ``[UNKNOWN_TYPE]`` + + ``[INVALID_FUNC]`` diagnostic (no qualname fallback). See ``docs/source/user_guide/fastcache.md``. Returns: The decorated class (or, when called with arguments, a decorator). From 29dd841cd6313a4bc4560f929fef26bbaedb9747 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 19:47:54 -0700 Subject: [PATCH 30/46] [Style] Reflow 3 more comment/docstring lines to 120c Cursor line-wrap check flagged 3 lines wrapped at 58-78c instead of 120c: - ``_pruning.py:80`` ``kernel_arg_chain_paths_by_func_id`` comment block - ``_pruning.py:100`` ``mark_kernel_arg_chain_used`` docstring - ``kernel_impl.py:323`` ``make_kernel_indirect`` primal-capture comment Pure prose reflow. --- python/quadrants/lang/_pruning.py | 24 +++++++++++------------- python/quadrants/lang/kernel_impl.py | 7 +++---- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index 465a768d08..a8fadd7f9b 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -77,17 +77,15 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: # registering every reachable ndarray (same as the historical behavior). self.pass_0_ran: bool = False # Kernel-arg-rooted attribute chains used by each func, in flat-name form (``__qd_self__qd_dofs__qd_x``). - # Populated by ``ASTTransformer.build_Attribute`` - # for non-flattened kernel args (data_oriented / qd.template). Kept *separate* from - # ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on the enforcing - # pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite - # ``s.x`` → ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a - # ``QuadrantsNameError: Name "__qd_s__qd_x" is not defined``. ``record_after_call`` - # propagates entries from callee to caller (so ``f(self.dofs)`` where ``f`` reads - # ``s.x`` ends up with ``__qd_self__qd_dofs__qd_x`` in the kernel's set). After both - # compile passes, ``Kernel._fold_kernel_arg_chain_paths_into_pruning`` merges the - # kernel's set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` so fastcache stores them - # in L1 and the args_hasher narrow walk picks them up. + # Populated by ``ASTTransformer.build_Attribute`` for non-flattened kernel args (data_oriented / qd.template). + # Kept *separate* from ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on the enforcing + # pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite ``s.x`` → + # ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a ``QuadrantsNameError: Name + # "__qd_s__qd_x" is not defined``. ``record_after_call`` propagates entries from callee to caller (so + # ``f(self.dofs)`` where ``f`` reads ``s.x`` ends up with ``__qd_self__qd_dofs__qd_x`` in the kernel's set). + # After both compile passes, ``Pruning.fold_kernel_arg_chain_paths`` merges the kernel's set into + # ``used_vars_by_func_id[KERNEL_FUNC_ID]`` so fastcache stores them in L1 and the args_hasher narrow walk + # picks them up. self.kernel_arg_chain_paths_by_func_id: dict[int, set[str]] = defaultdict(set) def mark_used(self, func_id: int, parameter_flat_name: str) -> None: @@ -97,8 +95,8 @@ def mark_used(self, func_id: int, parameter_flat_name: str) -> None: def mark_kernel_arg_chain_used(self, func_id: int, chain_flat_name: str) -> None: """Record a kernel-arg-rooted attribute chain (e.g. ``__qd_self__qd_dofs__qd_x``). - Stored separately from ``used_vars_by_func_id`` — see the docstring on - ``kernel_arg_chain_paths_by_func_id`` for why.""" + Stored separately from ``used_vars_by_func_id`` — see the docstring on ``kernel_arg_chain_paths_by_func_id`` + for why.""" assert not self.enforcing self.kernel_arg_chain_paths_by_func_id[func_id].add(chain_flat_name) diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index ae31f51c75..b9e4cd2980 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -320,10 +320,9 @@ def data_oriented(cls=None, *, stable_members: bool = False): return lambda c: data_oriented(c, stable_members=stable_members) def make_kernel_indirect(fun, is_property, attr_name): - # Capture the primal at decoration time so the per-call path skips the - # ``_BoundedDifferentiableMethod`` allocation. The class itself is validated when - # ``_BoundedDifferentiableMethod`` is invoked via the `.grad()` path; for the common - # primal call here we replicate the check inline. + # Capture the primal at decoration time so the per-call path skips the ``_BoundedDifferentiableMethod`` + # allocation. The class itself is validated when ``_BoundedDifferentiableMethod`` is invoked via the + # ``.grad()`` path; for the common primal call here we replicate the check inline. primal = fun._primal @wraps(fun) From 4bd2d1074c99b757891976115fdc478663e015a3 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 20:39:47 -0700 Subject: [PATCH 31/46] [Style] Reflow 3 more comment lines to 120c Cursor line-wrap check flagged: - ``ast_transformer.py:821`` ``build_Attribute`` ``Why not mark_used?`` block - ``_quadrants_callable.py:97`` ``__set_name__`` comment - ``_quadrants_callable.py:109`` ``__get__`` non-data-descriptor caching comment Pure prose reflow. --- python/quadrants/lang/_quadrants_callable.py | 13 ++++++------- python/quadrants/lang/ast/ast_transformer.py | 15 +++++++-------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/python/quadrants/lang/_quadrants_callable.py b/python/quadrants/lang/_quadrants_callable.py index 1bd50efe04..0c071c6919 100644 --- a/python/quadrants/lang/_quadrants_callable.py +++ b/python/quadrants/lang/_quadrants_callable.py @@ -94,9 +94,8 @@ def __init__(self, fn: Callable, wrapper: Callable) -> None: update_wrapper(self, fn) def __set_name__(self, owner: type, name: str) -> None: - # Captured at class-body time. ``data_oriented.make_kernel_indirect`` sets this - # explicitly on its replacement callable since setattr-after-class doesn't trigger - # __set_name__. + # Captured at class-body time. ``data_oriented.make_kernel_indirect`` sets this explicitly on its replacement + # callable since setattr-after-class doesn't trigger __set_name__. self._attr_name = name def __call__(self, *args, **kwargs): @@ -106,10 +105,10 @@ def __get__(self, instance, owner): if instance is None: return self bound = BoundQuadrantsCallable(instance, self) - # Non-data descriptor (no __set__): a __dict__ entry on the instance wins over the - # descriptor on subsequent attribute lookups. Stash the bound callable there so future - # ``instance.method`` accesses skip __get__ allocation entirely (~0.6-1.2 us/call). - # Skip if the class uses __slots__ (no __dict__) or the attribute name wasn't captured. + # Non-data descriptor (no __set__): a __dict__ entry on the instance wins over the descriptor on subsequent + # attribute lookups. Stash the bound callable there so future ``instance.method`` accesses skip __get__ + # allocation entirely (~0.6-1.2 us/call). Skip if the class uses __slots__ (no __dict__) or the attribute name + # wasn't captured. name = self._attr_name if name is not None: inst_dict = getattr(instance, "__dict__", None) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 469c0725c2..0b22b066e8 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -818,14 +818,13 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): # data_oriented ``self``); each Attribute access in the chain extends it # (``self`` → ``__qd_self__qd_x`` → ``__qd_self__qd_x__qd_y``). # - # Why not ``mark_used``? On the enforcing pass, ``Kernel.materialize`` uses - # ``pruning.used_vars_by_func_id`` as ``struct_locals``, which drives - # ``FlattenAttributeNameTransformer`` — adding ``__qd_self__qd_x`` there would make the transformer - # rewrite ``self.x`` into ``Name('__qd_self__qd_x')``, and ``build_Name`` would then fail to find - # such a variable. ``mark_kernel_arg_chain_used`` puts the chain into a *separate* per-func set - # that's merged into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` only *after* both compile passes, - # by ``Kernel._fold_kernel_arg_chain_paths_into_pruning`` — so the fastcache args-hash narrow walk - # picks them up without breaking codegen. + # Why not ``mark_used``? On the enforcing pass, ``Kernel.materialize`` uses ``pruning.used_vars_by_func_id`` as + # ``struct_locals``, which drives ``FlattenAttributeNameTransformer`` — adding ``__qd_self__qd_x`` there would + # make the transformer rewrite ``self.x`` into ``Name('__qd_self__qd_x')``, and ``build_Name`` would then fail + # to find such a variable. ``mark_kernel_arg_chain_used`` puts the chain into a *separate* per-func set that's + # merged into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` only *after* both compile passes, by + # ``Pruning.fold_kernel_arg_chain_paths`` — so the fastcache args-hash narrow walk picks them up without + # breaking codegen. parent_chain = getattr(node.value, "_qd_arg_chain", None) if parent_chain is not None: flat = create_flat_name(parent_chain, node.attr) From aef1a264a2d601146b878d06b2191425ea0d0f7a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 22:38:46 -0700 Subject: [PATCH 32/46] [Style] Manually reflow underwrapped prose to 120c Hand-inspected runs surfaced by ``find_underwrapped.py --diff --target 120`` across PR #705. Reflowed multi-line comments and docstrings that wrapped at ~80-100c, packing them at the project's 120c target. Left runs alone where long inline-code tokens, error-message formatting, or two-distinct-sentence splits make a tighter layout worse. --- .../lang/_fast_caching/args_hasher.py | 8 +- python/quadrants/lang/_pruning.py | 35 ++-- .../lang/_template_mapper_hotpath.py | 6 +- python/quadrants/lang/ast/ast_transformer.py | 11 +- .../lang/ast/ast_transformer_utils.py | 21 ++- .../ast/ast_transformers/call_transformer.py | 42 ++--- .../function_def_transformer.py | 23 ++- python/quadrants/lang/kernel.py | 15 +- python/quadrants/lang/kernel_impl.py | 4 +- .../lang/ast/test_function_def_transformer.py | 4 +- tests/python/test_ad_dataclass.py | 19 +-- tests/python/test_data_oriented_ndarray.py | 158 ++++++++---------- .../test_data_oriented_qd_func_dataclass.py | 31 ++-- 13 files changed, 174 insertions(+), 203 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 77a2366d4c..c2a5417a5d 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -169,12 +169,12 @@ def dataclass_to_repr( ) -> str | _FailFastcache: """Hash a dataclass instance, optionally narrowed by pruning information. - Returns ``_FAIL_FASTCACHE`` if any field's subtree hits a recognised-but-unsupported tensor type - (``ScalarField`` / ``MatrixField``); otherwise a string. + Returns ``_FAIL_FASTCACHE`` if any field's subtree hits a recognised-but-unsupported tensor type (``ScalarField`` / + ``MatrixField``); otherwise a string. Pruning: if ``pruning_paths`` is non-None, only descend into fields whose flat name is in the set. Pruning's - prefix-expansion step ensures the set already contains all ancestors of used leaves, so checking the - immediate child's flat name is sufficient. + prefix-expansion step ensures the set already contains all ancestors of used leaves, so checking the immediate + child's flat name is sufficient. """ # PERF: For frozen dataclasses the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index a8fadd7f9b..f18d89132d 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -11,14 +11,13 @@ def _flatten_arg_node(node: expr) -> tuple[str, str] | None: - """Flatten an AST arg node into ``(flat_name, root_name_id)`` (or ``None`` if the node isn't a - recognisable name/attribute chain rooted at a plain Name). + """Flatten an AST arg node into ``(flat_name, root_name_id)`` (or ``None`` if the node isn't a recognisable + name/attribute chain rooted at a plain Name). - Returns both the full flat name (e.g. ``__qd_self__qd_dofs`` for ``self.dofs``) and the root - Name's id (``self``). Callers use the root id to distinguish kernel-arg-rooted chains - (``self.dofs`` → root ``self``) from already-flattened dataclass-arg references - (``__qd_self__qd_dofs`` → root ``__qd_self__qd_dofs``). The flat path alone is ambiguous because - ``__qd_self__qd_dofs`` could be either an attribute chain *or* a single flattened Name. + Returns both the full flat name (e.g. ``__qd_self__qd_dofs`` for ``self.dofs``) and the root Name's id (``self``). + Callers use the root id to distinguish kernel-arg-rooted chains (``self.dofs`` → root ``self``) from already- + flattened dataclass-arg references (``__qd_self__qd_dofs`` → root ``__qd_self__qd_dofs``). The flat path alone is + ambiguous because ``__qd_self__qd_dofs`` could be either an attribute chain *or* a single flattened Name. Mirrors ``FlattenAttributeNameTransformer._flatten_attribute_name`` but on the raw call-arg AST. Used by ``record_after_call`` to handle ``f(self.dofs)`` etc. — without this the callee's pruning @@ -65,16 +64,16 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_used_parameters) # only needed for args, not kwargs self.callee_param_by_caller_arg_name_by_func_id: dict[int, dict[str, str]] = defaultdict(dict) - # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. - # Populated by the AST builder when a chain like ``self.x.y`` resolves to an ndarray - # that was pre-declared by ``_predeclare_struct_ndarrays``. On the second (enforcing) - # pass, ``_predeclare_struct_ndarrays`` only registers ndarrays whose id is in this set - # — dropping every reachable-but-unused ndarray from the kernel's parameter list. + # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. Populated by the AST + # builder when a chain like ``self.x.y`` resolves to an ndarray that was pre-declared by + # ``_predeclare_struct_ndarrays``. On the second (enforcing) pass, ``_predeclare_struct_ndarrays`` only + # registers ndarrays whose id is in this set — dropping every reachable-but-unused ndarray from the kernel's + # parameter list. self.used_struct_ndarray_ids: set[int] = set() - # Whether the non-enforcing first pass actually ran for this kernel materialize. - # When fastcache hits, we skip pass 0 entirely and ``used_struct_ndarray_ids`` is - # therefore unreliable — in that case ``_predeclare_struct_ndarrays`` falls back to - # registering every reachable ndarray (same as the historical behavior). + # Whether the non-enforcing first pass actually ran for this kernel materialize. When fastcache hits, we skip + # pass 0 entirely and ``used_struct_ndarray_ids`` is therefore unreliable — in that case + # ``_predeclare_struct_ndarrays`` falls back to registering every reachable ndarray (same as historical + # behavior). self.pass_0_ran: bool = False # Kernel-arg-rooted attribute chains used by each func, in flat-name form (``__qd_self__qd_dofs__qd_x``). # Populated by ``ASTTransformer.build_Attribute`` for non-flattened kernel args (data_oriented / qd.template). @@ -262,8 +261,8 @@ def record_after_call( caller_flat, root_id = flat if not root_id.startswith("__qd_"): callee_param_name = kwarg.arg - # ``kwarg.arg`` is ``None`` for double-star unpacking (``**kwargs``); - # chain propagation requires a concrete parameter name so just skip. + # ``kwarg.arg`` is ``None`` for double-star unpacking (``**kwargs``); chain propagation requires + # a concrete parameter name so just skip. if callee_param_name is not None: self._propagate_chain_paths( callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index e00da3fa7a..5f3fd26549 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -148,9 +148,9 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: # across instances of the same class. That holds for the typical ``@qd.data_oriented`` container, but Genesis # ``FEMSolver`` / ``MPMSolver`` / ``SPHSolver`` and similar can hold polymorphic children (e.g. ``self.material`` # of a different concrete subclass) or swap a ``qd.Tensor``'s underlying impl between an ``Ndarray`` and a - # ``MatrixField``. When the leaf at a cached path is no longer an ``Ndarray`` we silently skip it: ``v.element_type`` - # / ``v.shape`` / ``v._qd_layout`` are Ndarray-only accessors. The per-instance ``weakref(arg)`` part of the spec - # key still ensures correct cache discrimination across instances. + # ``MatrixField``. When the leaf at a cached path is no longer an ``Ndarray`` we silently skip it: + # ``v.element_type`` / ``v.shape`` / ``v._qd_layout`` are Ndarray-only accessors. The per-instance ``weakref(arg)`` + # part of the spec key still ensures correct cache discrimination across instances. for chain in _struct_nd_paths_for(arg): v = arg for a in chain: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 0b22b066e8..637af4cf83 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -684,10 +684,9 @@ def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> from quadrants.lang._ndarray import Ndarray # pylint: disable=C0415 pruning = ctx.global_context.pruning - # Mirror ``build_Name``'s mark_used gate: only mark on the non-enforcing first pass - # and not during synthetic per-leaf argument expansion for ``@qd.func`` calls. The - # callee body's own accesses (which run with ``expanding_dataclass_call_parameters - # = False``) are what we want to count. + # Mirror ``build_Name``'s mark_used gate: only mark on the non-enforcing first pass and not during synthetic + # per-leaf argument expansion for ``@qd.func`` calls. The callee body's own accesses (which run with + # ``expanding_dataclass_call_parameters = False``) are what we want to count. should_mark = not pruning.enforcing and not ctx.expanding_dataclass_call_parameters if isinstance(value, Ndarray): cache = ctx.global_context.ndarray_to_any_array @@ -698,8 +697,8 @@ def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> pruning.used_struct_ndarray_ids.add(key) return arr return value - # Pre-promoted ``AnyArray`` flowing through an inlined ``@qd.func`` body. Mark the - # underlying ndarray as used so it survives the enforcing-pass pruning. + # Pre-promoted ``AnyArray`` flowing through an inlined ``@qd.func`` body. Mark the underlying ndarray as used + # so it survives the enforcing-pass pruning. if should_mark: src_id = getattr(value, "_qd_source_ndarray_id", None) if src_id is not None: diff --git a/python/quadrants/lang/ast/ast_transformer_utils.py b/python/quadrants/lang/ast/ast_transformer_utils.py index 5e0af0ea55..fa784a3522 100644 --- a/python/quadrants/lang/ast/ast_transformer_utils.py +++ b/python/quadrants/lang/ast/ast_transformer_utils.py @@ -247,17 +247,16 @@ def __init__( self.visited_funcdef = False self.is_real_function = is_real_function self.kernel_args: list = [] - # Names of the bare (non-flattened) parameters of a ``@qd.func`` being processed. Used by - # ``build_Name`` to seed ``_qd_arg_chain`` for attribute accesses rooted at a func param - # (e.g. ``static_rigid_sim_config.para_level`` where ``static_rigid_sim_config`` is a - # ``qd.template()`` arg bound to a ``@qd.data_oriented`` instance). Without this, chains - # rooted at func params would not be recorded in pruning, and the args-hasher would skip - # over kernel-read primitive members of nested data_oriented containers — leading to stale - # fastcache hits when those members change between calls. - # ``kernel_args`` only tracks top-level ``@qd.kernel`` args; ``_transform_func_arg`` for a - # ``@qd.func`` does not append to it (see function_def_transformer.py). This separate set - # avoids piggy-backing on ``kernel_args`` so the existing "kernel arg is immutable" - # diagnostic in ``build_assign_annotated`` doesn't start firing for func params. + # Names of the bare (non-flattened) parameters of a ``@qd.func`` being processed. Used by ``build_Name`` to + # seed ``_qd_arg_chain`` for attribute accesses rooted at a func param (e.g. + # ``static_rigid_sim_config.para_level`` where ``static_rigid_sim_config`` is a ``qd.template()`` arg bound to + # a ``@qd.data_oriented`` instance). Without this, chains rooted at func params would not be recorded in + # pruning, and the args-hasher would skip over kernel-read primitive members of nested data_oriented + # containers — leading to stale fastcache hits when those members change between calls. + # ``kernel_args`` only tracks top-level ``@qd.kernel`` args; ``_transform_func_arg`` for a ``@qd.func`` does + # not append to it (see function_def_transformer.py). This separate set avoids piggy-backing on + # ``kernel_args`` so the existing "kernel arg is immutable" diagnostic in ``build_assign_annotated`` doesn't + # start firing for func params. self.fn_param_names: set[str] = set() self.only_parse_function_def: bool = False self.autodiff_mode = autodiff_mode diff --git a/python/quadrants/lang/ast/ast_transformers/call_transformer.py b/python/quadrants/lang/ast/ast_transformers/call_transformer.py index d9ae1a66b7..2bc22e8650 100644 --- a/python/quadrants/lang/ast/ast_transformers/call_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/call_transformer.py @@ -172,14 +172,11 @@ def _expand_Call_dataclass_args( callee_arg_names: list[str] | None = None, ) -> tuple[tuple[ast.stmt, ...], tuple[ast.stmt, ...]]: """ - We require that each node has a .ptr attribute added to it, that contains - the associated Python object. + We require that each node has a .ptr attribute added to it, that contains the associated Python object. - ``called_needed`` and ``callee_arg_names`` are used only for the - attribute-accessed-instance branch (Option A for data_oriented @qd.func calls): - the caller cannot construct a flat name from its own ``arg.id`` (the arg is - an ast.Attribute), so we look up pruning against the callee's parameter name - at the same positional index. + ``called_needed`` and ``callee_arg_names`` are used only for the attribute-accessed-instance branch (Option A + for data_oriented @qd.func calls): the caller cannot construct a flat name from its own ``arg.id`` (the arg is + an ast.Attribute), so we look up pruning against the callee's parameter name at the same positional index. """ args_new = [] added_args = [] @@ -214,15 +211,14 @@ def _expand_Call_dataclass_args( args_new.append(arg_node) added_args.append(arg_node) elif dataclasses.is_dataclass(val) and not isinstance(val, type): - # Dataclass *instance* passed positionally (e.g. ``self.state`` inside a - # @qd.data_oriented kernel method). Expand into per-leaf attribute accesses - # against the same AST node, mirroring the typed-arg (instance-of-type) path - # above but emitting ``ast.Attribute`` children rather than ``ast.Name``. - # ``added_args`` items must not carry ``.ptr`` (build_stmt populates it - # downstream); only the intermediate node used for recursion does. + # Dataclass *instance* passed positionally (e.g. ``self.state`` inside a @qd.data_oriented kernel + # method). Expand into per-leaf attribute accesses against the same AST node, mirroring the typed-arg + # (instance-of-type) path above but emitting ``ast.Attribute`` children rather than ``ast.Name``. + # ``added_args`` items must not carry ``.ptr`` (build_stmt populates it downstream); only the + # intermediate node used for recursion does. dataclass_type = type(val) - # For pruning, match the callee's flat name (it may have pruned unused - # fields). Use the callee's parameter name at this positional index. + # For pruning, match the callee's flat name (it may have pruned unused fields). Use the callee's + # parameter name at this positional index. callee_param = ( callee_arg_names[arg_idx] if (called_needed is not None and callee_arg_names is not None and arg_idx < len(callee_arg_names)) @@ -246,9 +242,8 @@ def _expand_Call_dataclass_args( ) if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): child_node.ptr = child_val - # Recurse, threading the renamed scope: the callee's expanded flat - # name (e.g. ``__qd_state__inner``) is the synthetic param name for - # the nested level. + # Recurse, threading the renamed scope: the callee's expanded flat name (e.g. + # ``__qd_state__inner``) is the synthetic param name for the nested level. nested_callee_param = ( create_flat_name(callee_param, field.name) if callee_param is not None else None ) @@ -321,10 +316,9 @@ def _expand_Call_dataclass_kwargs( kwargs_new.append(kwarg_node) added_kwargs.append(kwarg_node) elif dataclasses.is_dataclass(val) and not isinstance(val, type): - # Dataclass *instance* passed as a keyword arg (e.g. - # ``write(state=self.state)`` inside a @qd.data_oriented kernel method). - # Expand into per-leaf keyword args whose values are attribute accesses - # against the original value node (e.g. ``__qd_state__x=self.state.x``). + # Dataclass *instance* passed as a keyword arg (e.g. ``write(state=self.state)`` inside a + # @qd.data_oriented kernel method). Expand into per-leaf keyword args whose values are attribute + # accesses against the original value node (e.g. ``__qd_state__x=self.state.x``). dataclass_type = type(val) for field in dataclasses.fields(dataclass_type): child_name = create_flat_name(kwarg.arg, field.name) @@ -391,8 +385,8 @@ def build_Call(ctx: ASTTransformerFuncContext, node: ast.Call, build_stmt, build called_func_id_ = func.wrapper.func_id # type: ignore called_needed = pruning.used_vars_by_func_id[called_func_id_] if is_func_base_wrapper: - # callee param names (used by the attribute-instance positional-expansion path - # so it can match the callee's already-pruned flat names). + # callee param names (used by the attribute-instance positional-expansion path so it can match the + # callee's already-pruned flat names). try: callee_arg_names = [m.name for m in func.wrapper.arg_metas] # type: ignore[attr-defined] except AttributeError: diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 2f37ef8c28..ee5cc5accb 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -246,16 +246,15 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: launch_info = ctx.global_context.struct_ndarray_launch_info pruning = ctx.global_context.pruning used_ids = getattr(pruning, "used_struct_ndarray_ids", None) - # Only prune on the enforcing pass when we actually ran pass 0 to populate the - # used-ndarray set. On a fastcache hit pass 0 is skipped and the set is empty. + # Only prune on the enforcing pass when we actually ran pass 0 to populate the used-ndarray set. On a + # fastcache hit pass 0 is skipped and the set is empty. prune = pruning.enforcing and used_ids is not None and getattr(pruning, "pass_0_ran", False) - # On a fastcache hit (enforcing without a pass-0 run), the `id(nd)` set is empty, but the - # *flat-name* set on ``used_vars_by_func_id[KERNEL_FUNC_ID]`` was loaded from cache and - # already contains every kernel-accessed leaf path (folded in by - # ``_fold_struct_nd_paths_into_pruning`` during the compile that produced the cache entry). - # Use that to prune the walk so we register the exact same ndarray set as the originating - # compile produced — without this, every reachable ndarray gets registered, the kernel's - # arg slots get rebound to the wrong ndarrays at launch, and physics silently breaks. + # On a fastcache hit (enforcing without a pass-0 run), the `id(nd)` set is empty, but the *flat-name* set on + # ``used_vars_by_func_id[KERNEL_FUNC_ID]`` was loaded from cache and already contains every kernel-accessed + # leaf path (folded in by ``Pruning.fold_struct_nd_paths`` during the compile that produced the cache entry). + # Use that to prune the walk so we register the exact same ndarray set as the originating compile produced — + # without this, every reachable ndarray gets registered, the kernel's arg slots get rebound to the wrong + # ndarrays at launch, and physics silently breaks. prune_from_flat_names = pruning.enforcing and not getattr(pruning, "pass_0_ran", False) kernel_used_flat_names = ( pruning.used_vars_by_func_id.get(Pruning.KERNEL_FUNC_ID, set()) if prune_from_flat_names else None @@ -325,9 +324,9 @@ def _register_ndarray(nd, arg_idx, attr_chain): _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), _qd_layout=layout, ) - # Tag the AnyArray with the source ndarray id so ``_promote_ndarray_if_declared`` - # can mark this ndarray as used even when the access reaches it via an already- - # promoted AnyArray (e.g. callee bodies bound to per-leaf args by Option A). + # Tag the AnyArray with the source ndarray id so ``_promote_ndarray_if_declared`` can mark this ndarray + # as used even when the access reaches it via an already-promoted AnyArray (e.g. callee bodies bound to + # per-leaf args by Option A). arr._qd_source_ndarray_id = key cache[key] = arr launch_info.append((arg_id_vec[0], arg_idx, attr_chain)) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 880796d905..d7dc4e6db6 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -522,8 +522,8 @@ def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args Called at the end of ``materialize`` once both passes have completed (or once pass 1 has completed with a loaded artifact). Two responsibilities: - 1. If L1 was missing (``self._pruning_paths_from_l1 is None``), write the freshly-computed - pruning info so the next call from a new process can skip the args-walk warm-up. + 1. If L1 was missing (``self._pruning_paths_from_l1 is None``), write the freshly-computed pruning info so + the next call from a new process can skip the args-walk warm-up. 2. If ``fast_checksum`` is still None (which means either L1 was missing, or L1 hit but phase 2 of ``_try_load_fastcache`` saw a FIELD-related FastcacheSkip — in which case we keep ``None`` @@ -587,12 +587,11 @@ def launch_kernel( if self._struct_ndarray_launch_info_by_key: struct_nd_info = self._struct_ndarray_launch_info_by_key.get(key) if struct_nd_info: - # Data_oriented containers marked ``_qd_stable_members = True`` (or decorated - # with ``@qd.data_oriented(stable_members=True)``) promise their ndarray - # members are never reassigned, so we exclude them from the per-call - # ``_resolve_struct_ndarray`` walk that builds ``args_hash``. This is a - # *launch-time perf hint only* and has no fastcache role — fastcache derives - # its key from kernel-pruning info regardless of this flag. + # Data_oriented containers marked ``_qd_stable_members = True`` (or decorated with + # ``@qd.data_oriented(stable_members=True)``) promise their ndarray members are never reassigned, + # so we exclude them from the per-call ``_resolve_struct_ndarray`` walk that builds ``args_hash``. + # This is a *launch-time perf hint only* and has no fastcache role — fastcache derives its key + # from kernel-pruning info regardless of this flag. self._mutable_nd_cached_val = [ (idx, chain) for _, idx, chain in struct_nd_info diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index b9e4cd2980..764435fdf8 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -335,8 +335,8 @@ def _kernel_indirect(self, *args, **kwargs): raise type(e)("\n" + str(e)) from None ret = QuadrantsCallable(fun, _kernel_indirect) - # setattr-after-class doesn't trigger __set_name__; set the name explicitly so - # QuadrantsCallable.__get__ can cache the BoundQuadrantsCallable on instance.__dict__. + # setattr-after-class doesn't trigger __set_name__; set the name explicitly so QuadrantsCallable.__get__ can + # cache the BoundQuadrantsCallable on instance.__dict__. ret._attr_name = attr_name if is_property: ret = property(ret) diff --git a/tests/python/quadrants/lang/ast/test_function_def_transformer.py b/tests/python/quadrants/lang/ast/test_function_def_transformer.py index 1408c99014..20b422b20b 100644 --- a/tests/python/quadrants/lang/ast/test_function_def_transformer.py +++ b/tests/python/quadrants/lang/ast/test_function_def_transformer.py @@ -81,8 +81,8 @@ def test_process_func_arg(argument_name: str, argument_type: Any, expected_varia class MockContext: def __init__(self) -> None: self.variables: dict[str, Any] = {} - # Mirror the real ``ASTTransformerFuncContext.fn_param_names`` so - # ``_transform_func_arg`` can record bare param names without crashing. + # Mirror the real ``ASTTransformerFuncContext.fn_param_names`` so ``_transform_func_arg`` can record bare + # param names without crashing. self.fn_param_names: set[str] = set() def create_variable(self, name: str, data: Any) -> None: diff --git a/tests/python/test_ad_dataclass.py b/tests/python/test_ad_dataclass.py index dfadb24a9b..5b5d6dab27 100644 --- a/tests/python/test_ad_dataclass.py +++ b/tests/python/test_ad_dataclass.py @@ -7,14 +7,14 @@ * ``qd.field`` — ``qd.template()`` path; gradient via ``qd.ad.Tape``. * ``qd.tensor(backend=NDARRAY)`` — same path as ``qd.ndarray``; the dispatcher returns a wrapper whose ndarray ``_impl`` is unwrapped by the dataclass-annotation infrastructure. -* ``qd.tensor(backend=FIELD)`` — works when the dataclass member is annotated ``qd.Tensor`` - (or ``qd.template()``). With ``object`` / no annotation the wrapper survives into kernel scope - and host-side ``__getitem__`` asserts. +* ``qd.tensor(backend=FIELD)`` — works when the dataclass member is annotated ``qd.Tensor`` (or + ``qd.template()``). With ``object`` / no annotation the wrapper survives into kernel scope and host-side + ``__getitem__`` asserts. * mixed — single dataclass holding both a ``qd.ndarray`` and a ``qd.field`` member. Pattern mirrors ``test_ad_ndarray.py`` (ndarray) and ``test_ad_basics.py`` (field). See -``docs/source/user_guide/compound_types.md`` overview table — column "supports differentiation?" -for ``dataclasses.dataclass``. +``docs/source/user_guide/compound_types.md`` overview table — column "supports differentiation?" for +``dataclasses.dataclass``. """ import dataclasses @@ -183,11 +183,10 @@ def compute(s: State): def test_ad_dataclass_tensor_field_backend_tape(): """dataclass holding qd.tensor(..., backend=FIELD) members; field-AD via qd.ad.Tape. - Note: members must be annotated as ``qd.Tensor`` (not ``object``) when the value is a - ``qd.tensor(...)`` wrapper. The typed-dataclass / template machinery uses the member - annotation to decide whether to unwrap the wrapper into its underlying impl before the - kernel sees ``s.x[i]``. With ``object`` annotation the wrapper survives into kernel scope - and its host-side ``__getitem__`` asserts. + Note: members must be annotated as ``qd.Tensor`` (not ``object``) when the value is a ``qd.tensor(...)`` wrapper. + The typed-dataclass / template machinery uses the member annotation to decide whether to unwrap the wrapper into + its underlying impl before the kernel sees ``s.x[i]``. With ``object`` annotation the wrapper survives into + kernel scope and its host-side ``__getitem__`` asserts. """ N = 5 diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index fb5717d924..e844e73554 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -1,17 +1,16 @@ """Tests for ``@qd.data_oriented`` classes whose members are raw ``qd.ndarray`` (not ``qd.field``, not ``qd.Tensor`` wrappers). -The user-guide doc ``docs/source/user_guide/compound_types.md`` claims this pattern is not supported -("can contain ndarray? no" for ``@qd.data_oriented``). But the in-tree error message in -``python/quadrants/lang/impl.py`` lists ``@qd.data_oriented / frozen-dataclass template`` as a -*supported* route, and the ndarray-in-struct infrastructure added by ``#561 [Type] Tensor 24`` -(2026-04-28) — specifically ``_predeclare_struct_ndarrays`` in +The user-guide doc ``docs/source/user_guide/compound_types.md`` claims this pattern is not supported ("can contain +ndarray? no" for ``@qd.data_oriented``). But the in-tree error message in ``python/quadrants/lang/impl.py`` lists +``@qd.data_oriented / frozen-dataclass template`` as a *supported* route, and the ndarray-in-struct infrastructure +added by ``#561 [Type] Tensor 24`` (2026-04-28) — specifically ``_predeclare_struct_ndarrays`` in ``python/quadrants/lang/ast/ast_transformers/function_def_transformer.py`` — explicitly walks both -``dataclasses.is_dataclass(val)`` and ``hasattr(val, "__dict__")`` containers, the latter being the -data_oriented case. +``dataclasses.is_dataclass(val)`` and ``hasattr(val, "__dict__")`` containers, the latter being the data_oriented +case. -This file pins what actually works, and documents the gaps. See -``perso_hugh/doc/data_oriented_ndarray.md`` for the design analysis. +This file pins what actually works, and documents the gaps. See ``perso_hugh/doc/data_oriented_ndarray.md`` for the +design analysis. """ import dataclasses @@ -285,10 +284,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9b. Fastcache end-to-end with ``@qd.data_oriented`` holding ndarrays. Pattern adapted from -# ``test_cache.test_fastcache``: call ``qd_init_same_arch`` twice with the same cache directory -# to simulate two processes, monkeypatch ``launch_kernel`` to capture whether -# ``compiled_kernel_data`` was loaded from disk. On the second init the data_oriented + ndarray -# kernel should be served from the on-disk fastcache. +# ``test_cache.test_fastcache``: call ``qd_init_same_arch`` twice with the same cache directory to simulate two +# processes, monkeypatch ``launch_kernel`` to capture whether ``compiled_kernel_data`` was loaded from disk. On +# the second init the data_oriented + ndarray kernel should be served from the on-disk fastcache. # --------------------------------------------------------------------------- @@ -300,10 +298,9 @@ def test_data_oriented_ndarray_fastcache_cross_init(tmp_path, monkeypatch): captured_compiled_kernel_data = [] def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): - # Filter to the user kernel only; .to_numpy() launches an internal - # ``ndarray_to_ext_arr`` kernel that is not fastcache-eligible - # (is_pure=False) and would always make compiled_kernel_data=None, - # masking the actual fastcache behaviour of ``run``. + # Filter to the user kernel only; .to_numpy() launches an internal ``ndarray_to_ext_arr`` kernel that is not + # fastcache-eligible (is_pure=False) and would always make compiled_kernel_data=None, masking the actual + # fastcache behaviour of ``run``. if self.func.__name__ == "run": captured_compiled_kernel_data.append(compiled_kernel_data) return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) @@ -334,9 +331,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9c. Same as 9b but with a *nested* ``@qd.data_oriented`` holding an ndarray. Pins that the -# fastcache args_hasher recursion handles nested data_oriented containers correctly across -# processes. +# 9c. Same as 9b but with a *nested* ``@qd.data_oriented`` holding an ndarray. Pins that the fastcache args_hasher +# recursion handles nested data_oriented containers correctly across processes. # --------------------------------------------------------------------------- @@ -383,9 +379,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9d. Fastcache key is dtype-sensitive: same kernel source, different ndarray dtype in the -# data_oriented member -> two distinct disk cache entries. Pins the args_hasher's -# ``[nd-{dtype}-{ndim}{layout}]`` repr. +# 9d. Fastcache key is dtype-sensitive: same kernel source, different ndarray dtype in the data_oriented member -> +# two distinct disk cache entries. Pins the args_hasher's ``[nd-{dtype}-{ndim}{layout}]`` repr. # --------------------------------------------------------------------------- @@ -433,10 +428,10 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9e. Documented fallback: a @qd.data_oriented containing a qd.field disables fastcache for the -# whole call (args_hasher returns None for ScalarField). The kernel still runs correctly via -# non-fastcache compilation. This test pins the documented fallback so a future "support -# fields in fastcache" change explicitly chooses to update this test. +# 9e. Documented fallback: a @qd.data_oriented containing a qd.field disables fastcache for the whole call +# (args_hasher returns None for ScalarField). The kernel still runs correctly via non-fastcache compilation. +# This test pins the documented fallback so a future "support fields in fastcache" change explicitly chooses to +# update this test. # --------------------------------------------------------------------------- @@ -517,9 +512,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 11. Counter-test: confirm a dataclass-of-NDArray works (sanity check that the existing supported -# route still works; if this fails, the test environment itself is broken, not the data_oriented -# path). +# 11. Counter-test: confirm a dataclass-of-NDArray works (sanity check that the existing supported route still +# works; if this fails, the test environment itself is broken, not the data_oriented path). # --------------------------------------------------------------------------- @@ -544,9 +538,8 @@ def run(s: State): # --------------------------------------------------------------------------- -# 12. data_oriented holding a (frozen) dataclass that holds an ndarray. -# Exercises the ``else`` branch of ``_walk_obj`` recursing through a dataclass child — added by -# the Bug 1 fix. +# 12. data_oriented holding a (frozen) dataclass that holds an ndarray. Exercises the ``else`` branch of +# ``_walk_obj`` recursing through a dataclass child — added by the Bug 1 fix. # --------------------------------------------------------------------------- @@ -721,9 +714,9 @@ def fill_y_from_x(s: qd.template()): # --------------------------------------------------------------------------- -# 17. data_oriented + ndarray + @qd.func sub-call. Pins that the AST-time attribute resolution in -# ``build_Attribute`` (which uses the predeclared AnyArray cache) works when the access happens -# inside a func, not just the top-level kernel. +# 17. data_oriented + ndarray + @qd.func sub-call. Pins that the AST-time attribute resolution in ``build_Attribute`` +# (which uses the predeclared AnyArray cache) works when the access happens inside a func, not just the top-level +# kernel. # --------------------------------------------------------------------------- @@ -836,10 +829,9 @@ def run_f32(s: qd.template()): # --------------------------------------------------------------------------- -# 21. Typed-dataclass kernel arg with a ``@qd.data_oriented`` field type — should error clearly -# pointing the user to ``qd.template()``. The two patterns are incompatible at the kernel-arg -# layer: dataclass kernel args are flattened using annotations, data_oriented containers need a -# value-driven walk. Pins the helpful error message. +# 21. Typed-dataclass kernel arg with a ``@qd.data_oriented`` field type — should error clearly pointing the user to +# ``qd.template()``. The two patterns are incompatible at the kernel-arg layer: dataclass kernel args are +# flattened using annotations, data_oriented containers need a value-driven walk. Pins the helpful error message. # --------------------------------------------------------------------------- @@ -891,13 +883,11 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 22. Robustness: object graphs with Pydantic-style metaclass ``__getattr__`` recursion, -# and cyclic attribute references. Real-world container classes (notably Genesis's -# ``RigidOptions`` / ``SimOptions``) inherit from ``pydantic.BaseModel`` whose -# ``ModelMetaclass.__getattr__`` recurses infinitely on missing class attributes. -# Quadrants' walker must not blow the stack when it traverses a ``data_oriented`` arg -# that contains such an object, or that contains a back-reference to itself / its -# parent (e.g. ``solver.scene.solver``). +# 22. Robustness: object graphs with Pydantic-style metaclass ``__getattr__`` recursion, and cyclic attribute +# references. Real-world container classes (notably Genesis's ``RigidOptions`` / ``SimOptions``) inherit from +# ``pydantic.BaseModel`` whose ``ModelMetaclass.__getattr__`` recurses infinitely on missing class attributes. +# Quadrants' walker must not blow the stack when it traverses a ``data_oriented`` arg that contains such an +# object, or that contains a back-reference to itself / its parent (e.g. ``solver.scene.solver``). # --------------------------------------------------------------------------- @@ -996,8 +986,8 @@ def run_field(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_with_cyclic_attr_graph(): - """A ``@qd.data_oriented`` class whose attribute graph contains a cycle - (``parent.child.parent is parent``). Walker must not re-enter the cycle.""" + """A ``@qd.data_oriented`` class whose attribute graph contains a cycle (``parent.child.parent is parent``). + Walker must not re-enter the cycle.""" N = 4 @qd.data_oriented @@ -1027,8 +1017,7 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # Pruning-driven fastcache behaviour for @qd.data_oriented containers. # -# These pin the three rules enforced by the args hasher (see fastcache.md -# "Pruning-driven argument hashing"): +# These pin the three rules enforced by the args hasher (see fastcache.md "Pruning-driven argument hashing"): # 1. The cache key may only include contributions from kernel-pruned paths. # 2. Unrecognised types at kernel-read paths must not be silently dropped. # 3. Fastcache works for @qd.data_oriented kernel args end-to-end. @@ -1039,9 +1028,9 @@ def run(s: qd.template()): def test_data_oriented_kernel_unused_opaque_member_does_not_affect_cache(tmp_path, monkeypatch): """Rule 1: kernel-unused opaque members do not affect the fastcache key. - Two ``State`` instances differ only in an opaque ``uuid`` member that the kernel never reads. - Both must hit the same compiled artifact on the second process — proof that the args hasher's - pruning narrow walk skips the opaque attribute (no qualname-fallback, no spurious miss).""" + Two ``State`` instances differ only in an opaque ``uuid`` member that the kernel never reads. Both must hit the + same compiled artifact on the second process — proof that the args hasher's pruning narrow walk skips the opaque + attribute (no qualname-fallback, no spurious miss).""" import uuid from quadrants._test_tools import qd_init_same_arch @@ -1073,8 +1062,8 @@ def run(s: qd.template()): run(a) run(b) - # Second process: cold-start, must load from disk. If the uuid had leaked into the cache key, - # different uuid → different L2 key → no artifact would load. + # Second process: cold-start, must load from disk. If the uuid had leaked into the cache key, different uuid → + # different L2 key → no artifact would load. qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) a = State(x=qd.ndarray(qd.i32, shape=(4,))) b = State(x=qd.ndarray(qd.i32, shape=(4,))) @@ -1087,9 +1076,8 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_kernel_read_opaque_member_fails_fastcache(tmp_path, capfd) -> None: - """Rule 2: when the kernel actually reads an unrecognised-type member, fastcache fails loudly - with [UNKNOWN_TYPE] + [INVALID_FUNC] — no silent drop, no qualname fallback. The kernel still - runs via normal compilation.""" + """Rule 2: when the kernel actually reads an unrecognised-type member, fastcache fails loudly with [UNKNOWN_TYPE] + + [INVALID_FUNC] — no silent drop, no qualname fallback. The kernel still runs via normal compilation.""" from quadrants._test_tools import qd_init_same_arch from quadrants.lang._fast_caching.args_hasher import reset_unknown_type_warn_state @@ -1128,10 +1116,9 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_kernel_read_primitive_distinguishes_cache_key(tmp_path, monkeypatch) -> None: - """Rule 3 (data_oriented works) + pruning correctness: when the kernel reads a primitive member, - its value is baked into the kernel and must drive a distinct cache entry per value. Two State - instances differing only in ``n`` (read by the kernel) cold-compile separately and both load - from disk on the second process.""" + """Rule 3 (data_oriented works) + pruning correctness: when the kernel reads a primitive member, its value is + baked into the kernel and must drive a distinct cache entry per value. Two State instances differing only in + ``n`` (read by the kernel) cold-compile separately and both load from disk on the second process.""" from quadrants._test_tools import qd_init_same_arch launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel @@ -1174,8 +1161,8 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_kernel_unread_primitive_does_not_affect_cache(tmp_path, monkeypatch) -> None: - """Rule 1: kernel-unused primitive members do not affect the cache key. Mirror of the opaque - case for primitives. Two State instances differing only in ``unused_n`` must share the cache.""" + """Rule 1: kernel-unused primitive members do not affect the cache key. Mirror of the opaque case for + primitives. Two State instances differing only in ``unused_n`` must share the cache.""" from quadrants._test_tools import qd_init_same_arch launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel @@ -1216,10 +1203,10 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_qd_func_chain_propagation_distinguishes_cache_key(tmp_path, monkeypatch) -> None: - """Pruning chain propagation through ``@qd.func`` calls (``record_after_call`` extension): - when the kernel calls ``f(self.dofs)`` and ``f`` reads ``s.x``, the kernel's pruning set - must include ``__qd_self__qd_dofs__qd_x`` so that changes to the inner ndarray's dtype - invalidate the cache. Two States differing in ``dofs.x``'s dtype must cold-compile separately.""" + """Pruning chain propagation through ``@qd.func`` calls (``record_after_call`` extension): when the kernel calls + ``f(self.dofs)`` and ``f`` reads ``s.x``, the kernel's pruning set must include ``__qd_self__qd_dofs__qd_x`` so + that changes to the inner ndarray's dtype invalidate the cache. Two States differing in ``dofs.x``'s dtype must + cold-compile separately.""" from quadrants._test_tools import qd_init_same_arch launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel @@ -1267,23 +1254,22 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_nested_primitive_via_qd_func_distinguishes_cache_key(tmp_path, monkeypatch) -> None: - """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested - data_oriented containers. - - Regression test for a bug where ``record_after_call`` skipped chain-path propagation whenever the - caller-side arg flattened to a ``__qd_*``-prefixed name (which Attribute chains always do — - ``self.cfg`` → ``__qd_self__qd_cfg``). When that happened, primitive members read inside the - callee (``cfg.n`` → ``__qd_cfg__qd_n`` in the callee's chain set) never made it into the kernel's - pruning set, so the args-hasher walked ``self.cfg`` as data_oriented and found no pruned children, - yielding an identical hash for *any* value of ``cfg.n``. Two configs that should produce - different kernels (different ``range(s.cfg.n)`` trip counts baked into codegen) would then share - a fastcache entry — leading to stale-kernel hits and silent miscompiles (e.g. Genesis' - ``test_ndarray_no_compile`` was failing with iter-N kernels reused for iter-N+1 scenes that have - a different ``RigidSimStaticConfig.para_level`` baked into their ``qd.static`` branches). - - The fix in ``_pruning.py`` gates propagation on the *root Name* of the chain (``self``, not the - flat result), so both ``f(self)`` and ``f(self.cfg)`` propagate, while already-flattened - dataclass refs (``Name('__qd_state__qd_x')``) are still skipped.""" + """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested data_oriented + containers. + + Regression test for a bug where ``record_after_call`` skipped chain-path propagation whenever the caller-side arg + flattened to a ``__qd_*``-prefixed name (which Attribute chains always do — ``self.cfg`` → + ``__qd_self__qd_cfg``). When that happened, primitive members read inside the callee (``cfg.n`` → + ``__qd_cfg__qd_n`` in the callee's chain set) never made it into the kernel's pruning set, so the args-hasher + walked ``self.cfg`` as data_oriented and found no pruned children, yielding an identical hash for *any* value of + ``cfg.n``. Two configs that should produce different kernels (different ``range(s.cfg.n)`` trip counts baked into + codegen) would then share a fastcache entry — leading to stale-kernel hits and silent miscompiles (e.g. Genesis' + ``test_ndarray_no_compile`` was failing with iter-N kernels reused for iter-N+1 scenes that have a different + ``RigidSimStaticConfig.para_level`` baked into their ``qd.static`` branches). + + The fix in ``_pruning.py`` gates propagation on the *root Name* of the chain (``self``, not the flat result), so + both ``f(self)`` and ``f(self.cfg)`` propagate, while already-flattened dataclass refs + (``Name('__qd_state__qd_x')``) are still skipped.""" from quadrants._test_tools import qd_init_same_arch launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel diff --git a/tests/python/test_data_oriented_qd_func_dataclass.py b/tests/python/test_data_oriented_qd_func_dataclass.py index 20d620aa97..b504aae31c 100644 --- a/tests/python/test_data_oriented_qd_func_dataclass.py +++ b/tests/python/test_data_oriented_qd_func_dataclass.py @@ -1,14 +1,12 @@ -"""Tests for calling @qd.func that takes a typed-dataclass arg, from a @qd.kernel -method of a @qd.data_oriented class, passing ``self.dataclass_member`` as the arg. +"""Tests for calling @qd.func that takes a typed-dataclass arg, from a @qd.kernel method of a @qd.data_oriented +class, passing ``self.dataclass_member`` as the arg. -Genesis's @qd.func helpers declare typed-dataclass parameters (e.g. -``def func(links_state: LinksState, ...):``) and are designed to be called from kernels -that also take typed-dataclass kernel args (so the dataclass is flattened into per-leaf -kernel-locals on both sides of the call boundary). +Genesis's @qd.func helpers declare typed-dataclass parameters (e.g. ``def func(links_state: LinksState, ...):``) and +are designed to be called from kernels that also take typed-dataclass kernel args (so the dataclass is flattened into +per-leaf kernel-locals on both sides of the call boundary). -When migrating Genesis modules to @qd.data_oriented, we'd like to call the same @qd.func -helpers from a data_oriented kernel method, passing ``self.links_state`` as the arg. -Today this fails at AST resolution: +When migrating Genesis modules to @qd.data_oriented, we'd like to call the same @qd.func helpers from a data_oriented +kernel method, passing ``self.links_state`` as the arg. Today this fails at AST resolution: Missing argument '__qd_links_state__qd_cinr_inertial'. Unexpected argument 'links_state'. @@ -185,8 +183,8 @@ def run(self): @test_utils.test(arch=qd.cpu) def test_data_oriented_method_calls_qd_func_with_nested_dataclass_member(): - """data_oriented holds an Outer{ Inner{ ndarray } } and passes ``self.outer`` to a - @qd.func that expands the nested dataclass into flat leaves on both sides.""" + """data_oriented holds an Outer{ Inner{ ndarray } } and passes ``self.outer`` to a @qd.func that expands the + nested dataclass into flat leaves on both sides.""" N = 4 @dataclasses.dataclass @@ -253,10 +251,9 @@ def run(self): @test_utils.test(arch=qd.cpu) def test_data_oriented_method_qd_func_chain_with_dataclass_member(): - """data_oriented kernel calls outer @qd.func, which in turn calls inner @qd.func, - threading the same dataclass arg through. Both qd.funcs have the typed-dataclass - parameter; only the outermost call site (data_oriented method body) uses self.X. - The two inner call sites use the typed-arg path that already worked.""" + """data_oriented kernel calls outer @qd.func, which in turn calls inner @qd.func, threading the same dataclass + arg through. Both qd.funcs have the typed-dataclass parameter; only the outermost call site (data_oriented method + body) uses self.X. The two inner call sites use the typed-arg path that already worked.""" N = 4 @dataclasses.dataclass @@ -288,8 +285,8 @@ def run(self): @test_utils.test(arch=qd.cpu) def test_data_oriented_method_qd_func_chain_with_nested_dataclass_member(): - """Combination: nested dataclass passed through a chain of two @qd.func calls - from a @qd.data_oriented self-method via self.outer.""" + """Combination: nested dataclass passed through a chain of two @qd.func calls from a @qd.data_oriented + self-method via self.outer.""" N = 4 @dataclasses.dataclass From 197d150c99843b3a776933aa68dc8d969fb2c7d9 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 22:52:06 -0700 Subject: [PATCH 33/46] [Style] Reflow more underwrapped prose to 120c (round 2) More aggressive pass through the find_underwrapped report: pack docstrings, multi-line comments and section dividers to the 120c target where the prose isn't intentionally on separate lines (bullet items, paragraph breaks, side-by-side error-message reproductions are left alone). --- .../lang/_fast_caching/args_hasher.py | 79 ++++++++-------- .../lang/_fast_caching/src_hasher.py | 54 +++++------ python/quadrants/lang/_pruning.py | 94 +++++++++---------- .../lang/_template_mapper_hotpath.py | 4 +- python/quadrants/lang/ast/ast_transformer.py | 28 +++--- .../function_def_transformer.py | 17 ++-- python/quadrants/lang/kernel.py | 37 ++++---- python/quadrants/lang/kernel_impl.py | 4 +- .../test_fastcache_field_warnings.py | 9 +- .../lang/fast_caching/test_src_ll_cache.py | 12 +-- tests/python/test_ad_dataclass.py | 5 +- tests/python/test_data_oriented_ndarray.py | 73 +++++++------- tests/python/test_template_typing.py | 8 +- 13 files changed, 202 insertions(+), 222 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index c2a5417a5d..91e3c72029 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -46,8 +46,8 @@ # - Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). # - Unrecognised type at a kernel-read path (no qualname fallback — see rules in fastcache.md). # -# Containers (``dataclass_to_repr``, ``data_oriented`` branch, top-level ``hash_args`` loop) must propagate it -# upward — fastcache is disabled for the whole call and the caller writes the appropriate diagnostic. +# Containers (``dataclass_to_repr``, ``data_oriented`` branch, top-level ``hash_args`` loop) must propagate it upward +# — fastcache is disabled for the whole call and the caller writes the appropriate diagnostic. class _FailFastcache: """Singleton sentinel; identity-compared.""" @@ -178,12 +178,11 @@ def dataclass_to_repr( """ # PERF: For frozen dataclasses the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py - # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. - # A cached ``_DC_REPR_NONE`` sentinel distinguishes "computed but not fast-cacheable" from "not yet computed". + # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. A + # cached ``_DC_REPR_NONE`` sentinel distinguishes "computed but not fast-cacheable" from "not yet computed". # - # The cache is keyed by ``(is_frozen, pruning_paths is None)`` because a frozen dataclass's pruned repr - # depends on the pruning_paths set — we use separate cache slots for pruned vs unpruned to avoid serving - # the wrong narrowing. + # The cache is keyed by ``(is_frozen, pruning_paths is None)`` because a frozen dataclass's pruned repr depends on + # the pruning_paths set — we use separate cache slots for pruned vs unpruned to avoid serving the wrong narrowing. cache_attr = "_qd_dc_repr" if pruning_paths is None else "_qd_dc_repr_narrow" is_frozen = type(arg).__hash__ is not None if is_frozen: @@ -253,39 +252,37 @@ def stringify_obj_type( * Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). * Unrecognised type at this kernel-read path (see ``_fail_unknown_type``). - Two rules from ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing" govern this - function: + Two rules from ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing" govern this function: 1. The cache key may *only* include contributions from paths that pruning has marked kernel-accessed - (``pruning_paths``). Container walkers (dataclass + data_oriented) check ``_is_path_used`` per - child and skip non-pruned subtrees — kernel-unread paths are *guaranteed* not to affect codegen so - this is safe by construction. + (``pruning_paths``). Container walkers (dataclass + data_oriented) check ``_is_path_used`` per child and + skip non-pruned subtrees — kernel-unread paths are *guaranteed* not to affect codegen so this is safe by + construction. - 2. At paths the kernel *does* read, unrecognised types must not be silently dropped or hashed by - type-name — fastcache fails the call (loudly, with a one-shot warning) so the gap can be closed. + 2. At paths the kernel *does* read, unrecognised types must not be silently dropped or hashed by type-name — + fastcache fails the call (loudly, with a one-shot warning) so the gap can be closed. Parameters: - - ``arg_meta``: non-``None`` only for top-level kernel args and for ``@qd.data_oriented`` members. - Determines whether primitive values are baked into the cache key (template-position primitives and - all primitive members of data-oriented containers). + - ``arg_meta``: non-``None`` only for top-level kernel args and for ``@qd.data_oriented`` members. Determines + whether primitive values are baked into the cache key (template-position primitives and all primitive members + of data-oriented containers). - ``pruning_paths``: optional set of kernel-accessed flat names from L1 cache. When provided, - ``dataclass_to_repr`` and the ``data_oriented`` branch below descend only into children whose flat - name is in the set. Pruning info is populated by ``ASTTransformer.build_Name`` / - ``build_Attribute`` (kernel-arg-rooted chains) plus ``Kernel._fold_struct_nd_paths_into_pruning`` - (ndarray accesses through data_oriented containers). - - ``parent_flat``: the flat-name prefix for ``obj``'s children (e.g. ``__qd_self`` if ``obj`` is the - ``self`` arg of a data_oriented kernel). Used together with ``pruning_paths`` to compute each child's - flat name for the narrow-walk lookup. + ``dataclass_to_repr`` and the ``data_oriented`` branch below descend only into children whose flat name is in + the set. Pruning info is populated by ``ASTTransformer.build_Name`` / ``build_Attribute`` (kernel-arg-rooted + chains) plus ``Pruning.fold_struct_nd_paths`` (ndarray accesses through data_oriented containers). + - ``parent_flat``: the flat-name prefix for ``obj``'s children (e.g. ``__qd_self`` if ``obj`` is the ``self`` + arg of a data_oriented kernel). Used together with ``pruning_paths`` to compute each child's flat name for + the narrow-walk lookup. """ # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` # strips wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / # data-oriented walkers below do raw ``getattr`` to fetch struct fields, so a wrapper stored as a struct field - # arrives here un-stripped. Without this branch the hasher would hash the wrapper as an unknown type instead - # of unwrapping to the recognised impl. See ``perso_hugh/doc/quadrants-tensor.md`` §8.14. + # arrives here un-stripped. Without this branch the hasher would hash the wrapper as an unknown type instead of + # unwrapping to the recognised impl. See ``perso_hugh/doc/quadrants-tensor.md`` §8.14. # - # PERF-CRITICAL: the ``_any_tensor_constructed`` guard makes this check zero-cost when no ``qd.Tensor`` has - # been created. ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a - # pointer comparison (~10 ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. + # PERF-CRITICAL: the ``_any_tensor_constructed`` guard makes this check zero-cost when no ``qd.Tensor`` has been + # created. ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer + # comparison (~10 ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. if ( _tensor_wrapper._any_tensor_constructed and type(obj) in _TENSOR_WRAPPER_TYPES ): # pyright: ignore[reportOptionalMemberAccess] @@ -298,8 +295,8 @@ def stringify_obj_type( if isinstance(obj, VectorNdarray): return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): - # Recognised-but-unsupported: shape/dtype affect kernel codegen but fastcache doesn't yet hash them. - # Disable fastcache for the whole call. + # Recognised-but-unsupported: shape/dtype affect kernel codegen but fastcache doesn't yet hash them. Disable + # fastcache for the whole call. # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) return _FAIL_FASTCACHE @@ -319,13 +316,12 @@ def stringify_obj_type( raise_on_templated_floats, path, obj, pruning_paths=pruning_paths, parent_flat=parent_flat ) if is_data_oriented(obj): - # Walk the data_oriented container's members, narrowed by pruning info — the kernel-compile path - # records every kernel-accessed attribute chain (ndarrays via ``_promote_ndarray_if_declared`` + - # ``_fold_struct_nd_paths_into_pruning``; primitives, opaque members, nested structs via - # ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` propagation calling - # ``pruning.mark_used``). Members not in ``pruning_paths`` are *guaranteed* not to affect kernel - # codegen because the kernel cannot read them. Dropping them from the hash satisfies rule 1 - # (cache only pruned paths). + # Walk the data_oriented container's members, narrowed by pruning info — the kernel-compile path records + # every kernel-accessed attribute chain (ndarrays via ``_promote_ndarray_if_declared`` + + # ``Pruning.fold_struct_nd_paths``; primitives, opaque members, nested structs via + # ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` propagation calling ``pruning.mark_used``). Members + # not in ``pruning_paths`` are *guaranteed* not to affect kernel codegen because the kernel cannot read them. + # Dropping them from the hash satisfies rule 1 (cache only pruned paths). child_repr_l = ["da"] try: _asdict = getattr(obj, "_asdict") @@ -333,10 +329,9 @@ def stringify_obj_type( except AttributeError: _dict = obj.__dict__ for k, v in _dict.items(): - # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the - # per-instance ``BoundQuadrantsCallable`` on ``instance.__dict__`` so subsequent ``instance.method`` - # lookups skip the descriptor allocation; those entries are not data and must not invalidate the - # fastcache key. + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the per-instance + # ``BoundQuadrantsCallable`` on ``instance.__dict__`` so subsequent ``instance.method`` lookups skip the + # descriptor allocation; those entries are not data and must not invalidate the fastcache key. v_type = type(v) if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: continue diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 2d93245397..1e15eff0f0 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -2,44 +2,42 @@ Background (pre-refactor) ------------------------- -Fastcache used a single cache key derived from source + config + a *wide* args hash that walked every member -of every container argument. That walk was brittle: +Fastcache used a single cache key derived from source + config + a *wide* args hash that walked every member of +every container argument. That walk was brittle: - Encountering any unrecognised type silently disabled fastcache (``[FASTCACHE][PARAM_INVALID]`` warning + ``None`` return); a single Genesis ``RigidSolver._uid`` member killed the cache for the whole solver. - - Working around it via ``@qd.data_oriented(stable_members=True)`` opt-in only swapped one brittleness for - another: a new tensor-like type added later but missed in args_hasher's recognised set would be silently - skipped, serving stale cached results. + - Working around it via ``@qd.data_oriented(stable_members=True)`` opt-in only swapped one brittleness for another: + a new tensor-like type added later but missed in args_hasher's recognised set would be silently skipped, serving + stale cached results. -Both fundamentally stem from the wide walk *blindly* visiting paths the kernel never reads. The pre-refactor -design had no way to know which paths actually mattered before compile. +Both fundamentally stem from the wide walk *blindly* visiting paths the kernel never reads. The pre-refactor design +had no way to know which paths actually mattered before compile. Two-level cache --------------- -The fastcache now exposes pruning information (already produced during compile) as a first-class lookup so -the args hash can walk *only* paths the kernel reads: +The fastcache now exposes pruning information (already produced during compile) as a first-class lookup so the args +hash can walk *only* paths the kernel reads: - - L1 (this module's ``make_source_config_key`` + ``load_pruning_info`` / ``store_pruning_info``): - keyed by source+config only (no args). Stores ``PruningInfo`` — the set of kernel-accessed flat names - (e.g. ``__qd_state__qd_x``) plus the ``graph_do_while_arg`` (also a kernel-source property). + - L1 (this module's ``make_source_config_key`` + ``load_pruning_info`` / ``store_pruning_info``): keyed by + source+config only (no args). Stores ``PruningInfo`` — the set of kernel-accessed flat names (e.g. + ``__qd_state__qd_x``) plus the ``graph_do_while_arg`` (also a kernel-source property). - - L2 (``make_full_cache_key`` + ``load_full`` / ``store_full``): keyed by L1 key + the *narrow* args hash - computed with pruning info from L1. Stores the C++ ``frontend_cache_key`` that names the compiled - artifact. + - L2 (``make_full_cache_key`` + ``load_full`` / ``store_full``): keyed by L1 key + the *narrow* args hash computed + with pruning info from L1. Stores the C++ ``frontend_cache_key`` that names the compiled artifact. Lookup flow on a warm call: L1 lookup → narrow args hash (paths from L1) → L2 lookup → load artifact. -Cold compile flow: L1 miss → cold compile (pass 0 + pass 1) → store L1 → compute narrow args hash → store -L2. +Cold compile flow: L1 miss → cold compile (pass 0 + pass 1) → store L1 → compute narrow args hash → store L2. Safety implication ------------------ -A kernel-unused path's contents (any type, including unrecognised tensor-likes) is *guaranteed* not to affect -kernel codegen, so dropping it from the hash is correct by construction. Paths the kernel *does* read still go -through ``args_hasher.stringify_obj_type`` which falls back to a ``type(v).__qualname__``-based string for -unrecognised types and emits a one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning, so a missed type registration -is impossible to miss but doesn't silently disable fastcache. +A kernel-unused path's contents (any type, including unrecognised tensor-likes) is *guaranteed* not to affect kernel +codegen, so dropping it from the hash is correct by construction. Paths the kernel *does* read still go through +``args_hasher.stringify_obj_type`` which falls back to a ``type(v).__qualname__``-based string for unrecognised types +and emits a one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning, so a missed type registration is impossible to miss but +doesn't silently disable fastcache. """ import json @@ -61,10 +59,10 @@ from .hash_utils import hash_iterable_strings from .python_side_cache import PythonSideCache -# Prefix bytes mixed into L1 / L2 keys so they cannot collide even if the underlying inputs happen to hash to -# the same string. The original single-level cache key (kept for backward-compat reads via ``load`` below) had -# no such prefix; the new two-level scheme uses ``l1:`` and ``l2:`` markers so old single-level entries from -# prior Quadrants installs are simply ignored rather than mis-served. +# Prefix bytes mixed into L1 / L2 keys so they cannot collide even if the underlying inputs happen to hash to the +# same string. The original single-level cache key (kept for backward-compat reads via ``load`` below) had no such +# prefix; the new two-level scheme uses ``l1:`` and ``l2:`` markers so old single-level entries from prior Quadrants +# installs are simply ignored rather than mis-served. _L1_MARKER = "l1" _L2_MARKER = "l2" @@ -73,8 +71,8 @@ def make_source_config_key(kernel_source_info: FunctionSourceInfo) -> str: """Build the L1 cache key: source + config + version, with no dependence on args. Used by ``_try_load_fastcache`` before any args walking. The same key drives ``load_pruning_info`` / - ``store_pruning_info``; the matching ``make_full_cache_key`` derives the L2 key from this plus the narrow - args hash. + ``store_pruning_info``; the matching ``make_full_cache_key`` derives the L2 key from this plus the narrow args + hash. """ kernel_hash = function_hasher.hash_kernel(kernel_source_info) config_hash = config_hasher.hash_compile_config() diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index f18d89132d..aaa71620ce 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -64,8 +64,8 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_used_parameters) # only needed for args, not kwargs self.callee_param_by_caller_arg_name_by_func_id: dict[int, dict[str, str]] = defaultdict(dict) - # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. Populated by the AST - # builder when a chain like ``self.x.y`` resolves to an ndarray that was pre-declared by + # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. Populated by the + # AST builder when a chain like ``self.x.y`` resolves to an ndarray that was pre-declared by # ``_predeclare_struct_ndarrays``. On the second (enforcing) pass, ``_predeclare_struct_ndarrays`` only # registers ndarrays whose id is in this set — dropping every reachable-but-unused ndarray from the kernel's # parameter list. @@ -76,10 +76,10 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: # behavior). self.pass_0_ran: bool = False # Kernel-arg-rooted attribute chains used by each func, in flat-name form (``__qd_self__qd_dofs__qd_x``). - # Populated by ``ASTTransformer.build_Attribute`` for non-flattened kernel args (data_oriented / qd.template). - # Kept *separate* from ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on the enforcing - # pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite ``s.x`` → - # ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a ``QuadrantsNameError: Name + # Populated by ``ASTTransformer.build_Attribute`` for non-flattened kernel args (data_oriented / + # qd.template). Kept *separate* from ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on + # the enforcing pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite ``s.x`` + # → ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a ``QuadrantsNameError: Name # "__qd_s__qd_x" is not defined``. ``record_after_call`` propagates entries from callee to caller (so # ``f(self.dofs)`` where ``f`` reads ``s.x`` ends up with ``__qd_self__qd_dofs__qd_x`` in the kernel's set). # After both compile passes, ``Pruning.fold_kernel_arg_chain_paths`` merges the kernel's set into @@ -102,28 +102,27 @@ def mark_kernel_arg_chain_used(self, func_id: int, chain_flat_name: str) -> None def fold_struct_nd_paths( self, struct_ndarray_launch_info: list[tuple[Any, int, tuple[str, ...]]], arg_metas: list[ArgMetadata] ) -> None: - """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat - name set so ``args_hasher.hash_args`` narrow-walks them correctly. - - Background: ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat names - produced by ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* - args. ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` - in the AST and don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths - *are* recorded — in ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — - but only for ndarray members. - - Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd___qd___qd_…`` - and union all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the - same convention used for dataclass args. - - Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at - compile, opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk - cannot distinguish "kernel reads this primitive" from "kernel does not read this primitive". The - ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking - *all* attrs of a data_oriented container — narrowing only suppresses subtrees explicitly absent - from the pruning set. So for a data_oriented arg with mostly-ndarray members, the cache key - correctly depends on the ndarray paths it uses; for one with primitive members whose values - matter, those members are still folded into the hash (qualname-fallback / value paths). + """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat name set so + ``args_hasher.hash_args`` narrow-walks them correctly. + + Background: ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat names produced by + ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* args. + ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` in the AST and + don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths *are* recorded — in + ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — but only for ndarray members. + + Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd___qd___qd_…`` and union + all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the same convention used + for dataclass args. + + Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at compile, + opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk cannot distinguish + "kernel reads this primitive" from "kernel does not read this primitive". The + ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking *all* attrs of + a data_oriented container — narrowing only suppresses subtrees explicitly absent from the pruning set. So for + a data_oriented arg with mostly-ndarray members, the cache key correctly depends on the ndarray paths it + uses; for one with primitive members whose values matter, those members are still folded into the hash + (qualname-fallback / value paths). """ if not struct_ndarray_launch_info: return @@ -140,21 +139,20 @@ def fold_struct_nd_paths( kernel_used.add(flat) def fold_kernel_arg_chain_paths(self) -> None: - """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both - compile passes have completed. - - Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain - (e.g. ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into ``kernel_arg_chain_paths_by_func_id`` - rather than ``used_vars_by_func_id``, because the latter is read on the enforcing pass to build - ``struct_locals`` for ``FlattenAttributeNameTransformer``. If chain names appeared there, the - transformer would rewrite ``self.n`` into ``Name('__qd_self__qd_n')`` and ``build_Name`` would - fail to find such a variable. - - Doing the merge here — after pass 1, just like ``fold_struct_nd_paths`` — avoids that interaction - while still making the chain paths available to the fastcache args-hash narrow walk. The set on + """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both compile + passes have completed. + + Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain (e.g. + ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into ``kernel_arg_chain_paths_by_func_id`` rather than + ``used_vars_by_func_id``, because the latter is read on the enforcing pass to build ``struct_locals`` for + ``FlattenAttributeNameTransformer``. If chain names appeared there, the transformer would rewrite ``self.n`` + into ``Name('__qd_self__qd_n')`` and ``build_Name`` would fail to find such a variable. + + Doing the merge here — after pass 1, just like ``fold_struct_nd_paths`` — avoids that interaction while + still making the chain paths available to the fastcache args-hash narrow walk. The set on ``used_py_dataclass_parameters_by_key_enforcing[key]`` is the *same* object as - ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so updating - one updates both. + ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so updating one updates + both. """ kernel_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(Pruning.KERNEL_FUNC_ID) if not kernel_chain_paths: @@ -229,12 +227,12 @@ def record_after_call( callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) - # Propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) - # AND through plain-Name args of non-flattened types (``f(self)``). Gate on the *root* - # Name id, not the resulting flat string: ``self.dofs`` flattens to ``__qd_self__qd_dofs`` - # (which starts with ``__qd_``) but its root is the bare kernel arg ``self`` — we still - # need to propagate. Already-flattened dataclass refs like ``Name('__qd_self__qd_dofs')`` - # have a ``__qd_*`` root and are handled by the ``vars_to_unprune`` path above. + # Propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) AND through + # plain-Name args of non-flattened types (``f(self)``). Gate on the *root* Name id, not the resulting + # flat string: ``self.dofs`` flattens to ``__qd_self__qd_dofs`` (which starts with ``__qd_``) but its + # root is the bare kernel arg ``self`` — we still need to propagate. Already-flattened dataclass refs + # like ``Name('__qd_self__qd_dofs')`` have a ``__qd_*`` root and are handled by the ``vars_to_unprune`` + # path above. flat = _flatten_arg_node(arg) if flat is not None: caller_flat, root_id = flat diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index 5f3fd26549..cc38bb5b1a 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -243,8 +243,8 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati # Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is # a no-op for the existing data_oriented + qd.field workloads (genesis field-backend). # - # Opt-out: ``_qd_stable_members = True`` on the class (or - # ``@qd.data_oriented(stable_members=True)``) skips the per-call descriptor walk. + # Opt-out: ``_qd_stable_members = True`` on the class (or ``@qd.data_oriented(stable_members=True)``) + # skips the per-call descriptor walk. if type(arg).__dict__.get("_qd_stable_members"): return weakref.ref(arg) nd_descriptors: list = [] diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 637af4cf83..75c3f88ef8 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -87,17 +87,16 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name): if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters and node.id.startswith("__qd_"): ctx.global_context.pruning.mark_used(ctx.func.func_id, node.id) # Track chains rooted at non-flattened parameter names: top-level ``@qd.kernel`` args - # (``ctx.kernel_args``) and ``@qd.func`` params (``ctx.fn_param_names``). Both appear in the - # AST as bare names (``self`` for a data_oriented kernel arg; ``static_rigid_sim_config`` for - # a ``qd.template()`` func arg bound to a ``@qd.data_oriented`` instance). - # ``build_Attribute`` propagates this annotation through ``state.dofs.x`` chains and - # ``mark_kernel_arg_chain_used``-s the flat name. The kernel's pruning narrow walk picks them - # up directly (kernel case) or after ``record_after_call`` propagates the callee's func-arg - # chains back through the call boundary (func case): e.g. ``func(s=self._sub)`` where ``func`` - # reads ``s.x`` ends up with ``__qd_self__qd__sub__qd_x`` recorded in the kernel's pruning, - # so the args-hasher hashes that primitive value into the fastcache key. - # Dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as - # already-flat ``__qd_…`` Names, handled by the block above via ``mark_used``. + # (``ctx.kernel_args``) and ``@qd.func`` params (``ctx.fn_param_names``). Both appear in the AST as bare + # names (``self`` for a data_oriented kernel arg; ``static_rigid_sim_config`` for a ``qd.template()`` func + # arg bound to a ``@qd.data_oriented`` instance). ``build_Attribute`` propagates this annotation through + # ``state.dofs.x`` chains and ``mark_kernel_arg_chain_used``-s the flat name. The kernel's pruning narrow + # walk picks them up directly (kernel case) or after ``record_after_call`` propagates the callee's func-arg + # chains back through the call boundary (func case): e.g. ``func(s=self._sub)`` where ``func`` reads ``s.x`` + # ends up with ``__qd_self__qd__sub__qd_x`` recorded in the kernel's pruning, so the args-hasher hashes that + # primitive value into the fastcache key. + # Dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as already-flat + # ``__qd_…`` Names, handled by the block above via ``mark_used``. if not node.id.startswith("__qd_") and (node.id in ctx.kernel_args or node.id in ctx.fn_param_names): node._qd_arg_chain = node.id # type: ignore[attr-defined] else: @@ -812,10 +811,9 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): warnings.warn(message) else: raise exception.QuadrantsCompilationError(message) - # Propagate the kernel-arg-rooted chain annotation and record this access in pruning's *separate* - # chain-paths set. ``build_Name`` sets ``_qd_arg_chain`` on non-flattened kernel args (e.g. - # data_oriented ``self``); each Attribute access in the chain extends it - # (``self`` → ``__qd_self__qd_x`` → ``__qd_self__qd_x__qd_y``). + # Propagate the kernel-arg-rooted chain annotation and record this access in pruning's *separate* chain-paths + # set. ``build_Name`` sets ``_qd_arg_chain`` on non-flattened kernel args (e.g. data_oriented ``self``); each + # Attribute access in the chain extends it (``self`` → ``__qd_self__qd_x`` → ``__qd_self__qd_x__qd_y``). # # Why not ``mark_used``? On the enforcing pass, ``Kernel.materialize`` uses ``pruning.used_vars_by_func_id`` as # ``struct_locals``, which drives ``FlattenAttributeNameTransformer`` — adding ``__qd_self__qd_x`` there would diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index ee5cc5accb..7668620467 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -156,8 +156,8 @@ def _transform_kernel_arg( elif isinstance(field.type, type) and getattr(field.type, "_data_oriented", False): # ``@qd.data_oriented`` field type inside a typed-dataclass kernel arg. The two patterns are # semantically incompatible at this layer: dataclass kernel-arg recursion uses annotations to - # flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers don't - # carry per-attribute type annotations — they need a value-driven walk + # flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers + # don't carry per-attribute type annotations — they need a value-driven walk # (``_predeclare_struct_ndarrays``), which only fires for ``qd.template()`` / ``qd.Tensor`` # annotations. Rather than silently miscompile, raise a clear error pointing users to the # recommended pattern. @@ -362,13 +362,12 @@ def _transform_func_arg( argument_type: Any, data: Any, ) -> None: - # Record the bare (non-flattened) func param name so ``build_Name`` can seed ``_qd_arg_chain`` - # for attribute accesses rooted at this param. Critical for ``qd.template()`` args bound to - # ``@qd.data_oriented`` instances (e.g. ``static_rigid_sim_config.para_level`` inside a - # ``@qd.func``): without this, the kernel's pruning set never learns about ``.para_level``, - # the args-hasher skips the value, and different ``para_level`` configurations collide in the - # fastcache key. Flat names starting with ``__qd_`` arrive here too via the dataclass-flatten - # recursion below; they're harmless to add (``build_Name``'s chain branch gates on + # Record the bare (non-flattened) func param name so ``build_Name`` can seed ``_qd_arg_chain`` for attribute + # accesses rooted at this param. Critical for ``qd.template()`` args bound to ``@qd.data_oriented`` instances + # (e.g. ``static_rigid_sim_config.para_level`` inside a ``@qd.func``): without this, the kernel's pruning set + # never learns about ``.para_level``, the args-hasher skips the value, and different ``para_level`` + # configurations collide in the fastcache key. Flat names starting with ``__qd_`` arrive here too via the + # dataclass-flatten recursion below; they're harmless to add (``build_Name``'s chain branch gates on # ``not node.id.startswith("__qd_")``) but the bare-name entries are what enables propagation. ctx.fn_param_names.add(argument_name) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index d7dc4e6db6..b00165337f 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -479,22 +479,21 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . ctx.global_context, "struct_ndarray_launch_info", [] ) # Fold data_oriented ndarray attribute chains into the kernel's used-flat-names set so - # ``args_hasher.hash_args`` can narrow data_oriented walks too. ``used_vars_by_func_id`` - # only contains flat names from dataclass-arg expansion in - # ``extract_struct_locals_from_context``; data_oriented args don't go through that - # expansion, so accesses like ``self.x`` on an ndarray member are only tracked via - # ``struct_ndarray_launch_info``. Without this fold, narrow hashing for data_oriented - # args walks nothing — every (arg_idx, attr_chain) pair gets the same hash regardless - # of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the + # ``args_hasher.hash_args`` can narrow data_oriented walks too. ``used_vars_by_func_id`` only + # contains flat names from dataclass-arg expansion in ``extract_struct_locals_from_context``; + # data_oriented args don't go through that expansion, so accesses like ``self.x`` on an ndarray + # member are only tracked via ``struct_ndarray_launch_info``. Without this fold, narrow hashing + # for data_oriented args walks nothing — every (arg_idx, attr_chain) pair gets the same hash + # regardless of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the # ``test_data_oriented_ndarray_fastcache_dtype_key_distinct`` pin caught this). pruning.fold_struct_nd_paths(self._struct_ndarray_launch_info_by_key.get(key, []), self.arg_metas) - # Fold non-ndarray kernel-arg-rooted chain paths (primitives, opaque members, nested - # struct paths) collected by ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` - # tracking. Kept separate from ``used_vars_by_func_id`` during compile (would otherwise - # poison ``struct_locals`` and break codegen) — see the field-level docstring on - # ``Pruning.kernel_arg_chain_paths_by_func_id``. This fold + the existing ``used_vars`` - # assignment to ``used_py_dataclass_parameters_by_key_enforcing`` share the same set - # by reference, so the final fastcache L1 entry sees all kernel-accessed paths. + # Fold non-ndarray kernel-arg-rooted chain paths (primitives, opaque members, nested struct + # paths) collected by ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` tracking. Kept + # separate from ``used_vars_by_func_id`` during compile (would otherwise poison ``struct_locals`` + # and break codegen) — see the field-level docstring on + # ``Pruning.kernel_arg_chain_paths_by_func_id``. This fold + the existing ``used_vars`` assignment + # to ``used_py_dataclass_parameters_by_key_enforcing`` share the same set by reference, so the + # final fastcache L1 entry sees all kernel-accessed paths. pruning.fold_kernel_arg_chain_paths() else: for used_parameters in pruning.used_vars_by_func_id.values(): @@ -525,11 +524,11 @@ def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args 1. If L1 was missing (``self._pruning_paths_from_l1 is None``), write the freshly-computed pruning info so the next call from a new process can skip the args-walk warm-up. - 2. If ``fast_checksum`` is still None (which means either L1 was missing, or L1 hit but phase 2 - of ``_try_load_fastcache`` saw a FIELD-related FastcacheSkip — in which case we keep ``None`` - and the post-compile ``src_hasher.store`` is skipped), compute the narrow args hash *now* - using the just-populated pruning info and derive the L2 key. The post-launch ``src_hasher.store`` - call uses ``self.fast_checksum`` as the L2 key. + 2. If ``fast_checksum`` is still None (which means either L1 was missing, or L1 hit but phase 2 of + ``_try_load_fastcache`` saw a FIELD-related FastcacheSkip — in which case we keep ``None`` and the + post-compile ``src_hasher.store`` is skipped), compute the narrow args hash *now* using the just- + populated pruning info and derive the L2 key. The post-launch ``src_hasher.store`` call uses + ``self.fast_checksum`` as the L2 key. Side-effect helper; split out from ``materialize`` to keep the compile loop readable. """ diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 764435fdf8..f3dedca01e 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -303,8 +303,8 @@ def data_oriented(cls=None, *, stable_members: bool = False): stable_members (bool): launch-context perf hint — if ``True``, declares that the class's ndarray-typed members are allocated once and never reassigned between kernel calls. Quadrants will skip the per-call ndarray- reference walk that ``Kernel.launch_kernel`` uses to detect ndarray reassignment on mutable containers - (~1-2 us/call savings on Genesis-style containers with dozens of ndarray attrs). Reassigning a member on a - ``stable_members`` class is undefined behaviour — the previously-compiled kernel will be reused even if + (~1-2 us/call savings on Genesis-style containers with dozens of ndarray attrs). Reassigning a member on + a ``stable_members`` class is undefined behaviour — the previously-compiled kernel will be reused even if the new ndarray has different dtype/ndim/layout. May also be set as a class-level attribute ``_qd_stable_members = True`` (equivalent). diff --git a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py index b7ff707b60..f6f4b010aa 100644 --- a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py +++ b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py @@ -163,11 +163,10 @@ def test_fastcache_field_warnings_warn_struct_template_field(tmp_path, capfd): """Struct with qd.Template-annotated field containing a Field — warning should fire when the field is actually read by the kernel. - Pruning-driven narrowing of args hashing only walks members the kernel reads; an unused dataclass field - cannot affect kernel codegen so it's correctly omitted from the hash (and from the - Field-disables-fastcache check). For the warning path to fire, the kernel must reference the Field — that - matches the user-visible contract that fastcache fails iff a "live" Field argument prevents safe - parametrisation. + Pruning-driven narrowing of args hashing only walks members the kernel reads; an unused dataclass field cannot + affect kernel codegen so it's correctly omitted from the hash (and from the Field-disables-fastcache check). For + the warning path to fire, the kernel must reference the Field — that matches the user-visible contract that + fastcache fails iff a "live" Field argument prevents safe parametrisation. """ qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 5736a6b96e..217c66ac30 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -142,12 +142,12 @@ def k1(foo: qd.Template) -> None: k1(foo=RandomClass()) _out, err = capfd.readouterr() - # Unrecognised types at a (top-level) kernel-read path now fail fastcache loudly: a one-shot - # ``[UNKNOWN_TYPE]`` warning identifies the offending type, and ``[INVALID_FUNC]`` then reports the - # disabled cache. The old silent ``[PARAM_INVALID]`` dead-end is gone — the two rules driving this - # are documented in ``args_hasher.py::_fail_unknown_type`` and ``fastcache.md`` "Pruning-driven - # argument hashing": (1) only pruned paths may contribute to the cache key (so no qualname fallback), - # (2) unrecognised types at pruned paths must not be silently dropped. + # Unrecognised types at a (top-level) kernel-read path now fail fastcache loudly: a one-shot ``[UNKNOWN_TYPE]`` + # warning identifies the offending type, and ``[INVALID_FUNC]`` then reports the disabled cache. The old silent + # ``[PARAM_INVALID]`` dead-end is gone — the two rules driving this are documented in + # ``args_hasher.py::_fail_unknown_type`` and ``fastcache.md`` "Pruning-driven argument hashing": (1) only pruned + # paths may contribute to the cache key (so no qualname fallback), (2) unrecognised types at pruned paths must + # not be silently dropped. assert "[FASTCACHE][UNKNOWN_TYPE]" in err assert RandomClass.__name__ in err assert "[FASTCACHE][INVALID_FUNC]" in err diff --git a/tests/python/test_ad_dataclass.py b/tests/python/test_ad_dataclass.py index 5b5d6dab27..b59160b6ff 100644 --- a/tests/python/test_ad_dataclass.py +++ b/tests/python/test_ad_dataclass.py @@ -7,9 +7,8 @@ * ``qd.field`` — ``qd.template()`` path; gradient via ``qd.ad.Tape``. * ``qd.tensor(backend=NDARRAY)`` — same path as ``qd.ndarray``; the dispatcher returns a wrapper whose ndarray ``_impl`` is unwrapped by the dataclass-annotation infrastructure. -* ``qd.tensor(backend=FIELD)`` — works when the dataclass member is annotated ``qd.Tensor`` (or - ``qd.template()``). With ``object`` / no annotation the wrapper survives into kernel scope and host-side - ``__getitem__`` asserts. +* ``qd.tensor(backend=FIELD)`` — works when the dataclass member is annotated ``qd.Tensor`` (or ``qd.template()``). + With ``object`` / no annotation the wrapper survives into kernel scope and host-side ``__getitem__`` asserts. * mixed — single dataclass holding both a ``qd.ndarray`` and a ``qd.field`` member. Pattern mirrors ``test_ad_ndarray.py`` (ndarray) and ``test_ad_basics.py`` (field). See diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index e844e73554..5b8283e8b8 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -168,10 +168,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 6. Mutation: same instance, reassign ndarray attribute to a *same-shape* ndarray between calls. -# The launch-time stale-cache guard (``_mutable_nd_cached_val`` in kernel.py) is supposed to fold the -# live ndarray id into args_hash so the launch context is not served stale. We pin that behaviour -# here for the data_oriented case. +# 6. Mutation: same instance, reassign ndarray attribute to a *same-shape* ndarray between calls. The launch-time +# stale-cache guard (``_mutable_nd_cached_val`` in kernel.py) is supposed to fold the live ndarray id into +# args_hash so the launch context is not served stale. We pin that behaviour here for the data_oriented case. # --------------------------------------------------------------------------- @@ -203,11 +202,11 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 7. Mutation cross-shape: reassign ndarray attribute to a *different-dtype* ndarray. -# The template-mapper specialisation key (in ``_template_mapper_hotpath._extract_arg``) returns -# ``weakref.ref(arg)`` for ``is_data_oriented(arg)``; it does NOT descend into ndarray children to -# compute a dtype/ndim-dependent spec key. So if the data_oriented instance's id is unchanged but -# its ndarray attribute is reassigned to a different dtype, we expect either: +# 7. Mutation cross-shape: reassign ndarray attribute to a *different-dtype* ndarray. The template-mapper +# specialisation key (in ``_template_mapper_hotpath._extract_arg``) returns ``weakref.ref(arg)`` for +# ``is_data_oriented(arg)``; it does NOT descend into ndarray children to compute a dtype/ndim-dependent spec key. +# So if the data_oriented instance's id is unchanged but its ndarray attribute is reassigned to a different dtype, +# we expect either: # - a graceful recompile/raise, or # - silent miscompilation (the bug case — current expected outcome per static analysis). # Mark xfail with strict=False so we record the actual outcome without breaking CI. @@ -241,10 +240,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 8. Distinct instances of same class -> spec-key behaviour. Documents that today each fresh instance -# triggers a recompile (because the spec key is ``weakref.ref(arg)`` identity). This is a perf -# concern, not a correctness one. We assert correctness here; the recompile count is documented as -# a perf note. +# 8. Distinct instances of same class -> spec-key behaviour. Documents that today each fresh instance triggers a +# recompile (because the spec key is ``weakref.ref(arg)`` identity). This is a perf concern, not a correctness +# one. We assert correctness here; the recompile count is documented as a perf note. # --------------------------------------------------------------------------- @@ -275,10 +273,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9. Fastcache cold then warm. Per the fastcache doc (``user_guide/fastcache.md`` line 129), -# ``@qd.data_oriented`` objects are supported in the cache key. We don't assert cross-process here -# (that requires a fresh interpreter); we assert that ``cache_stored`` becomes True on the first -# call and ``cache_key_generated`` is True (i.e. no PARAM_INVALID fallthrough due to the ndarray -# member). +# ``@qd.data_oriented`` objects are supported in the cache key. We don't assert cross-process here (that requires +# a fresh interpreter); we assert that ``cache_stored`` becomes True on the first call and +# ``cache_key_generated`` is True (i.e. no PARAM_INVALID fallthrough due to the ndarray member). # --------------------------------------------------------------------------- @@ -429,9 +426,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9e. Documented fallback: a @qd.data_oriented containing a qd.field disables fastcache for the whole call -# (args_hasher returns None for ScalarField). The kernel still runs correctly via non-fastcache compilation. -# This test pins the documented fallback so a future "support fields in fastcache" change explicitly chooses to -# update this test. +# (args_hasher returns None for ScalarField). The kernel still runs correctly via non-fastcache compilation. This +# test pins the documented fallback so a future "support fields in fastcache" change explicitly chooses to update +# this test. # --------------------------------------------------------------------------- @@ -483,9 +480,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 10. Pure validation: a @qd.pure @qd.kernel taking a data_oriented arg with an ndarray member should -# compile and run, mirroring the existing ``test_pure_validation_data_oriented_as_param`` test -# which only covers ``qd.field``. +# 10. Pure validation: a @qd.pure @qd.kernel taking a data_oriented arg with an ndarray member should compile and +# run, mirroring the existing ``test_pure_validation_data_oriented_as_param`` test which only covers +# ``qd.field``. # --------------------------------------------------------------------------- @@ -569,13 +566,12 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 13. Frozen dataclass holding a data_oriented holding an ndarray, kernel-arg via ``qd.template()``. -# Exercises the dataclass branch of ``_walk_obj`` recursing through a data_oriented child — added -# by the Bug 1 fix. The outer dataclass must be frozen because (i) non-frozen dataclasses are -# unhashable in Python (``__hash__ is None``) and the template-mapper key tuple needs the value -# to be hashable, and (ii) the typed-dataclass-arg form (``def run(s: Outer):``) goes through -# ``_transform_kernel_arg`` which does not currently recurse on data_oriented field *types* (as -# opposed to values) — that's a separate follow-up. +# 13. Frozen dataclass holding a data_oriented holding an ndarray, kernel-arg via ``qd.template()``. Exercises the +# dataclass branch of ``_walk_obj`` recursing through a data_oriented child — added by the Bug 1 fix. The outer +# dataclass must be frozen because (i) non-frozen dataclasses are unhashable in Python (``__hash__ is None``) and +# the template-mapper key tuple needs the value to be hashable, and (ii) the typed-dataclass-arg form (``def +# run(s: Outer):``) goes through ``_transform_kernel_arg`` which does not currently recurse on data_oriented +# field *types* (as opposed to values) — that's a separate follow-up. # --------------------------------------------------------------------------- @@ -944,13 +940,13 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_polymorphic_attr_across_instances(): - """The path cache in ``_struct_nd_paths_cache`` is keyed on ``type(arg)`` and assumes the set of - ndarray-reachable attribute chains is stable across instances. Some real-world ``@qd.data_oriented`` - containers (Genesis FEMSolver / MPMSolver / SPHSolver, etc.) hold polymorphic children whose - types differ between instances — e.g. ``self.material.x`` is an ``Ndarray`` on instance A and - a ``qd.field`` (``MatrixField``) on instance B. ``_collect_struct_nd_descriptors`` walks cached - paths verbatim and must not crash with ``'MatrixField' object has no attribute 'element_type'`` - when a path's leaf is no longer an ``Ndarray``; it should silently skip the stale entry.""" + """The path cache in ``_struct_nd_paths_cache`` is keyed on ``type(arg)`` and assumes the set of ndarray- + reachable attribute chains is stable across instances. Some real-world ``@qd.data_oriented`` containers (Genesis + FEMSolver / MPMSolver / SPHSolver, etc.) hold polymorphic children whose types differ between instances — e.g. + ``self.material.x`` is an ``Ndarray`` on instance A and a ``qd.field`` (``MatrixField``) on instance B. + ``_collect_struct_nd_descriptors`` walks cached paths verbatim and must not crash with ``'MatrixField' object has + no attribute 'element_type'`` when a path's leaf is no longer an ``Ndarray``; it should silently skip the stale + entry.""" N = 4 @qd.data_oriented @@ -1254,8 +1250,7 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_nested_primitive_via_qd_func_distinguishes_cache_key(tmp_path, monkeypatch) -> None: - """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested data_oriented - containers. + """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested data_oriented containers. Regression test for a bug where ``record_after_call`` skipped chain-path propagation whenever the caller-side arg flattened to a ``__qd_*``-prefixed name (which Attribute chains always do — ``self.cfg`` → diff --git a/tests/python/test_template_typing.py b/tests/python/test_template_typing.py index c4f00a081e..11e9d5da72 100644 --- a/tests/python/test_template_typing.py +++ b/tests/python/test_template_typing.py @@ -66,10 +66,10 @@ def __init__(self) -> None: self.a_float = 1.23 self.scratch = qd.ndarray(qd.i32, shape=(1,)) - # Read the primitive members so the fastcache narrow walk includes them in the hash. Pre-pruning - # the args_hasher walked every member of every container arg blindly; with pruning the kernel must - # actually access ``a.a_float`` for the raise-on-templated-floats guard to fire (the value being - # baked-in only matters when the kernel reads it). + # Read the primitive members so the fastcache narrow walk includes them in the hash. Pre-pruning the args_hasher + # walked every member of every container arg blindly; with pruning the kernel must actually access ``a.a_float`` + # for the raise-on-templated-floats guard to fire (the value being baked-in only matters when the kernel reads + # it). @qd.kernel(fastcache=True) def k1f(a: qd.Template) -> None: a.scratch[0] = qd.cast(a.a_float, qd.i32) From a47a5abb8ec103de3eed4c1aef3d1183a91e8f63 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 23:46:56 -0700 Subject: [PATCH 34/46] [Doc] fastcache.md: restore prose phrasing in unsupported-type + arg-type bullets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Drop "There is no qualname-fallback" qualifier from rule 2; the sentence that follows still explains why type-name fallback would be unsafe. - Restore the table's @qd.data_oriented row reference to the "Advanced — compound-type cache keying" section (types and values baked in). - Restore the original "If any kernel-used parameter is of an unsupported type..." paragraph in place of the two-failure-modes bullet rewrite. - Restore the simpler "argument types (e.g. switching from f32 to f64..." bullet. --- docs/source/user_guide/fastcache.md | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 4453bd9e47..21d1fd5e33 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -96,17 +96,14 @@ Fastcache supports the following parameter types: | `torch.Tensor` | Yes | dtype, ndim | | `numpy.ndarray` | Yes | dtype, ndim | | `dataclasses.dataclass` | Yes | member types recursively; member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | -| `@qd.data_oriented` objects | Yes | member types recursively, narrowed by pruning (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing)); primitive member values baked into kernel | +| `@qd.data_oriented` objects | Yes | member types recursively, narrowed by pruning (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing)); primitive member types and values baked into kernel (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | | `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) | | Non-template primitives (int, float, bool) | Yes | type only | | `enum.Enum` | Yes | name and value | | `qd.field` / `ScalarField` / `MatrixField` at a kernel-read path | **No** | — | | Anything else at a kernel-read path | **No** | — | -Two failure modes — both loud, never silent: - -- **Recognised-but-unsupported** tensor-like types (`qd.field` / `ScalarField` / `MatrixField`) reached at a path the kernel actually reads → fastcache disabled for the call, kernel falls back to normal compilation. For these arriving through a `qd.Tensor`-annotated parameter the diagnostic is silent (normal usage); for other annotations a `[FASTCACHE][INVALID_FUNC]` log line identifies the offending parameter. -- **Unrecognised** types at a kernel-read path → fastcache disabled for the call, with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning per type identifying the offending class plus an `[INVALID_FUNC]` log line confirming the cache is off. To make a type cache-eligible, add explicit handling for it to `quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type`, or refactor the kernel so it does not read this member (pruning will then skip it). +If any kernel-used parameter is of an unsupported type, fastcache is disabled for that call and the kernel falls back to normal compilation. For `qd.field` / `ScalarField` / `MatrixField` arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted. For other unsupported types, a warning is logged at the `warn` level identifying the offending parameter. Kernel-unused members of any type — including unrecognised ones — do **not** disable fastcache. The pruning narrowing in the args hasher skips them entirely, so opaque metadata (UUIDs, Pydantic configs, parent back-pointers) attached to a `@qd.data_oriented` instance is harmless as long as the kernel doesn't read it. @@ -120,7 +117,7 @@ Each compiled artifact is stored under a key derived from all of the following: - The **Quadrants version** (`quadrants.__version__`). - The **source code** of the kernel function or any `@qd.func` it calls. -- The **argument types at paths the kernel actually reads** (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing) below). +- The **argument types** (e.g. switching from `f32` to `f64`, or changing ndarray dimensionality). - The **compiler configuration** (e.g. `arch`, `debug`, `opt_level`, `fast_math`). - **Template parameter values** (since they are baked into the compiled kernel). @@ -144,7 +141,7 @@ The args hasher enforces two strict invariants: Paths *not* in the pruning set are skipped by the args hasher — they are guaranteed not to affect kernel codegen because the kernel cannot read them. -2. **Unrecognised types at kernel-read paths must not be silently dropped or hashed by type-name.** If pruning says the kernel reads a path and the value at that path is a type the args hasher doesn't explicitly handle (Pydantic models, UUIDs, third-party tensor wrappers, …), fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning identifying the offending type plus an `[INVALID_FUNC]` log line confirming the cache is off. There is no qualname-fallback — capturing type identity without type parameters (dtype/shape on a hypothetical tensor type) would silently mask a value-affecting change. +2. **Unrecognised types at kernel-read paths must not be silently dropped or hashed by type-name.** If pruning says the kernel reads a path and the value at that path is a type the args hasher doesn't explicitly handle (Pydantic models, UUIDs, third-party tensor wrappers, …), fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning identifying the offending type plus an `[INVALID_FUNC]` log line confirming the cache is off. Capturing type identity without type parameters (dtype/shape on a hypothetical tensor type) would silently mask a value-affecting change. #### Practical implications From 5debfe4ae18a2e9af704f314e3697b81d7ebb7b4 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 23:48:16 -0700 Subject: [PATCH 35/46] [Doc] fastcache.md: drop redundant 'every child is subject to pruning' qualifier --- docs/source/user_guide/fastcache.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 21d1fd5e33..cea42d9c27 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -178,7 +178,7 @@ On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules: -**`@qd.data_oriented`:** the walker descends into `vars(obj)`, narrowed by pruning info — *every* child (ndarray, primitive, opaque, nested struct) is subject to the pruning check. For each walked child: +**`@qd.data_oriented`:** the walker descends into `vars(obj)`, narrowed by pruning info. For each walked child: - `qd.ndarray` member, kernel-read — `(dtype, ndim, layout)` is included in the cache key. Element values are not. - `qd.ndarray` member, kernel-unused — *skipped*. Changes to its dtype/ndim/layout don't invalidate the cache. From 4e714c7997503da500d3dacd21b91f494dba8eb2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 23:49:14 -0700 Subject: [PATCH 36/46] [Doc] fastcache.md: revert @qd.data_oriented child-rule bullets to original 3-bullet form --- docs/source/user_guide/fastcache.md | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index cea42d9c27..d1e6556c14 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -180,15 +180,9 @@ The args hasher walks compound-type kernel parameters recursively. For each leaf **`@qd.data_oriented`:** the walker descends into `vars(obj)`, narrowed by pruning info. For each walked child: -- `qd.ndarray` member, kernel-read — `(dtype, ndim, layout)` is included in the cache key. Element values are not. -- `qd.ndarray` member, kernel-unused — *skipped*. Changes to its dtype/ndim/layout don't invalidate the cache. -- Primitive (`int` / `float` / `bool` / `enum.Enum`) member, kernel-read — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values that the kernel reads get different cache entries. -- Primitive member, kernel-unused — *skipped* (the kernel cannot read it so its value cannot affect codegen). -- Nested `@qd.data_oriented` member — recurses (with these same rules). -- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). -- Opaque member (anything fastcache doesn't recognise), kernel-unused — *skipped*. -- Opaque member, kernel-read — fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning. To make the type cacheable, add explicit handling to `args_hasher.py::stringify_obj_type`. -- `qd.field` member, kernel-read — fastcache is disabled for the call (recognised-but-unsupported). A `qd.field` member at a kernel-*unused* path is simply skipped (no diagnostic). +- `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not. +- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. +- Nested `@qd.data_oriented` member — recurses. **`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: From 39602c6128d31828c23e438de54d4be59cf53069 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 23:51:46 -0700 Subject: [PATCH 37/46] [Doc] fastcache.md: tighten recognised-but-unsupported sentence --- docs/source/user_guide/fastcache.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index d1e6556c14..31efdcb1d1 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -149,7 +149,7 @@ The args hasher enforces two strict invariants: - **Kernel-unused members of unrecognised types are also fine.** Pruning narrowing skips them before the type-recognition check runs. - **Kernel-read members of unrecognised types fail fastcache loudly.** Either add explicit handling in `quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type` (for new tensor-like types whose dtype/shape matter), or move the access out of the kernel-read path (for opaque metadata that shouldn't be there in the first place). -`qd.field` / `ScalarField` / `MatrixField` are *recognised-but-unsupported*: their shape/dtype would affect codegen but fastcache doesn't yet know how to safely include them, so encountering one at a kernel-read path disables fastcache for the call (with a warn-level diagnostic). +`qd.field` / `ScalarField` / `MatrixField` are *recognised-but-unsupported*: encountering one at a kernel-read path disables fastcache for the call (with a warn-level diagnostic). ## Advanced From bd37c943ef515fedbbd2f9f70cb6c51fcb671a55 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 23:53:16 -0700 Subject: [PATCH 38/46] [Doc] fastcache.md: restore nested-dataclass + qd.field bullets in data_oriented child-rules --- docs/source/user_guide/fastcache.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 31efdcb1d1..b5dee08bad 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -183,6 +183,8 @@ The args hasher walks compound-type kernel parameters recursively. For each leaf - `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not. - Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. - Nested `@qd.data_oriented` member — recurses. +- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). +- `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. **`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: From 59ce5ff52289df59eaedd5e8dca5e98d1d1bfa2b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Mon, 18 May 2026 23:58:14 -0700 Subject: [PATCH 39/46] [Style] args_hasher: restore original 'field offset' comments on ScalarField/MatrixField --- python/quadrants/lang/_fast_caching/args_hasher.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 91e3c72029..1792f678ee 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -295,8 +295,8 @@ def stringify_obj_type( if isinstance(obj, VectorNdarray): return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): - # Recognised-but-unsupported: shape/dtype affect kernel codegen but fastcache doesn't yet hash them. Disable - # fastcache for the whole call. + # disabled for now, because we need to think about how to handle field offset + # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) return _FAIL_FASTCACHE @@ -307,7 +307,8 @@ def stringify_obj_type( if isinstance(obj, np.ndarray): return f"[np-{obj.dtype}-{obj.ndim}]" if isinstance(obj, MatrixField): - # Recognised-but-unsupported, same as ScalarField above. + # disabled for now, because we need to think about how to handle field offset + # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) return _FAIL_FASTCACHE From a63b834856e6c595a38bcea8194c6cd5f7cff18a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 00:01:00 -0700 Subject: [PATCH 40/46] [Docs] src_hasher: remove pre-refactor background paragraph from module docstring --- python/quadrants/lang/_fast_caching/src_hasher.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 1e15eff0f0..233c2c345d 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -1,20 +1,5 @@ """Two-level fastcache key derivation and persistence. -Background (pre-refactor) -------------------------- -Fastcache used a single cache key derived from source + config + a *wide* args hash that walked every member of -every container argument. That walk was brittle: - - - Encountering any unrecognised type silently disabled fastcache (``[FASTCACHE][PARAM_INVALID]`` warning + - ``None`` return); a single Genesis ``RigidSolver._uid`` member killed the cache for the whole solver. - - - Working around it via ``@qd.data_oriented(stable_members=True)`` opt-in only swapped one brittleness for another: - a new tensor-like type added later but missed in args_hasher's recognised set would be silently skipped, serving - stale cached results. - -Both fundamentally stem from the wide walk *blindly* visiting paths the kernel never reads. The pre-refactor design -had no way to know which paths actually mattered before compile. - Two-level cache --------------- The fastcache now exposes pruning information (already produced during compile) as a first-class lookup so the args From f6c68d8cee2d441a6ea6974534beb9cdb9516d94 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 00:03:42 -0700 Subject: [PATCH 41/46] [Docs] src_hasher: correct safety-implication paragraph Unrecognised types at kernel-read paths fail the call's fastcache (loudly, via _fail_unknown_type returning _FAIL_FASTCACHE); they are not hashed via a qualname-based fallback string. The qualname appears only in the [UNKNOWN_TYPE] warning identifying the offending type. --- python/quadrants/lang/_fast_caching/src_hasher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 233c2c345d..3a7222acfe 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -20,9 +20,9 @@ ------------------ A kernel-unused path's contents (any type, including unrecognised tensor-likes) is *guaranteed* not to affect kernel codegen, so dropping it from the hash is correct by construction. Paths the kernel *does* read still go through -``args_hasher.stringify_obj_type`` which falls back to a ``type(v).__qualname__``-based string for unrecognised types -and emits a one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning, so a missed type registration is impossible to miss but -doesn't silently disable fastcache. +``args_hasher.stringify_obj_type``; if it encounters an unrecognised type at such a path it fails the call's fastcache +loudly (one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning identifying the offending ``type(v).__qualname__``), so a missed +type registration is impossible to miss and cannot serve stale cached results. """ import json From ae36b11fc14442266c7fe4f62dc9a6e362d61cc0 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 06:05:17 -0700 Subject: [PATCH 42/46] [Fix] Per-instance ndarray-path cache for @qd.data_oriented args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The TemplateMapper's args_hash walk used a per-class cache of attribute paths populated from the first instance ever seen of each class. That cache was wrong for @qd.data_oriented classes whose attribute structure varies across instances (motivating case: Genesis ``DataManager``, which only allocates ``*_adjoint_cache`` members when ``requires_grad=True``). Two failure modes existed: - Forward direction (first instance has the attr, second misses it): the walk crashed with ``AttributeError: 'DataManager' object has no attribute 'dofs_state_adjoint_cache'`` when launching kernels on the second instance. Observed on Genesis ``test_rigid_mpm_legacy_coupling`` (macos-15 GPU job in PR genesis-world#2799). - Inverse direction (first instance lacks the attr, second has it): silently miscached — the new ndarray's id never made it into args_hash, so a later reassignment of that attribute wouldn't trigger spec re-derivation. Fix: stash the walked path list on the *instance* (``arg._qd_nd_paths``) via ``object.__setattr__`` (compatible with frozen dataclasses, mirroring the existing ``_qd_dc_repr`` pattern in ``args_hasher.dataclass_to_repr``). Each instance is walked once on first kernel call; subsequent calls fetch the cached list via instance ``__dict__`` lookup (~30 ns, same order as the previous class-level ``dict.get``). Steady-state perf: unchanged on franka cpu single env (one solver instance, walked once at scene build, fetched per-call thereafter). Startup pays one walk per instance lifetime — ~10us per scene build for Genesis-shaped workloads. ``__slots__`` classes that can't accept the instance stash fall back to per-class caching and retain the legacy polymorphic-instance limitation; Genesis data_oriented containers don't use ``__slots__``. ``_classify_for_args_hash`` is split into a per-class disposition (``_SKIP`` / ``_PER_INSTANCE``) plus a per-instance ``_struct_nd_paths_for`` call. The ``_qd_stable_members`` flag still short-circuits the entire walk for users who opt into the "no ndarray reassignment, ever" promise. Test ``test_data_oriented_polymorphic_attribute_set_across_instances`` covers both forward and inverse directions on a ``DataManager``-shaped class. --- python/quadrants/lang/_template_mapper.py | 61 +++++++-------- .../lang/_template_mapper_hotpath.py | 78 +++++++++++++------ tests/python/test_data_oriented_ndarray.py | 61 +++++++++++++-- 3 files changed, 137 insertions(+), 63 deletions(-) diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index 80fc320a6d..07d32892c1 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, TypeAlias, cast +from typing import Any, TypeAlias from weakref import ReferenceType from quadrants.lang import impl @@ -15,31 +15,28 @@ _struct_nd_paths_for, ) -# Per-``type(arg)`` precomputed dispatch for the args_hash ndarray-id walk in ``TemplateMapper.lookup``. Each entry is -# either the cached attribute path list (when the class is data_oriented, opted into ndarray tracking, and actually -# holds ndarrays) or ``None`` (when the per-call walk is a no-op — covers the common case of typed-dataclass args, -# non-data_oriented composite args, primitives, AND data_oriented classes with ``_qd_stable_members = True`` or with -# no ndarray members). One dict lookup per arg per call, ~30 ns, replacing the previous unconditional -# ``is_data_oriented(arg)`` + ``type(arg).__dict__.get`` chain. -_arg_nd_paths_or_none: dict[type, "list[tuple] | None"] = {} -_UNCLASSIFIED = object() +# Per-class disposition for the args_hash ndarray-id walk in ``TemplateMapper.lookup``: one of ``_SKIP`` (this class +# never contributes — non-data_oriented, or ``@qd.data_oriented(stable_members=True)``) or ``_PER_INSTANCE`` (delegate +# to ``_struct_nd_paths_for`` for a per-instance walk). The disposition depends only on type (data_oriented? +# stable_members?), so caching by class is correct. The *actual* path list is per-instance because @qd.data_oriented +# classes can have polymorphic attribute structure across instances (Genesis ``DataManager`` is the motivating case). +_arg_disposition: dict[type, object] = {} +_SKIP = object() +_PER_INSTANCE = object() -def _classify_for_args_hash(arg: Any) -> "list[tuple] | None": - """First-sighting classification for ``type(arg)`` in the args_hash walk. Returns the path list to walk (when the - arg is a data_oriented container without ``_qd_stable_members`` that actually contains ndarrays), or ``None`` to - skip subsequent per-call work for this type. +def _classify_disposition(arg: Any) -> object: + """First-sighting per-class disposition for the args_hash walk. Returns ``_SKIP`` (no per-call walk for this + class) or ``_PER_INSTANCE`` (delegate to ``_struct_nd_paths_for`` for a per-instance walk). ``_qd_stable_members`` here is a *launch-time perf hint only* (see ``@qd.data_oriented(stable_members=...)``). - It does not affect fastcache key derivation.""" + It promises that ndarray members are never reassigned, which lets us skip the per-call walk entirely. It does + not affect fastcache key derivation.""" if not is_data_oriented(arg): - return None + return _SKIP if type(arg).__dict__.get("_qd_stable_members"): - return None - paths = _struct_nd_paths_for(arg) - if not paths: - return None - return paths + return _SKIP + return _PER_INSTANCE Key: TypeAlias = tuple[Any, ...] @@ -114,21 +111,23 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl # iterate ``template_slot_locations`` instead of all args (Genesis main kernel_step_1: 4 template positions # of 16 args; Genesis branch step_1/step_2: 4 of 4). # - # For each candidate position, a per-class cache in ``_arg_nd_paths_or_none`` maps ``type(arg)`` to either the - # cached ndarray-path list to walk or ``None`` to skip (typical for primitive template-args, stable_members - # data_oriented, and data_oriented with zero ndarrays). One dict.get per candidate per call after warmup. + # For each candidate, ``_arg_disposition`` caches the per-class decision (skip vs walk-per-instance) and the + # actual paths come from ``_struct_nd_paths_for`` (per-instance, stashed on ``arg._qd_nd_paths``). Per-instance + # path caching is load-bearing for correctness — @qd.data_oriented classes can have polymorphic attribute + # structure across instances (Genesis ``DataManager`` only allocates adjoint-cache members when + # ``requires_grad=True``); a per-class cache populated from one instance can't safely be reused for another. nd_ids: list = [] for i in self.template_slot_locations: arg = args[i] cls = type(arg) - cached = _arg_nd_paths_or_none.get(cls, _UNCLASSIFIED) - if cached is _UNCLASSIFIED: - paths = _classify_for_args_hash(arg) - _arg_nd_paths_or_none[cls] = paths - else: - # Narrow the ``object`` sentinel union back to the actual cached value type. - paths = cast("list[tuple] | None", cached) - if paths is None: + disposition = _arg_disposition.get(cls) + if disposition is None: + disposition = _classify_disposition(arg) + _arg_disposition[cls] = disposition + if disposition is _SKIP: + continue + paths = _struct_nd_paths_for(arg) + if not paths: continue for chain in paths: v = arg diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index cc38bb5b1a..3a77f6702e 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -76,13 +76,22 @@ _primitive_types = {int, float, bool} -# Per-class cache: ``type(arg) -> list[tuple[str, ...]]`` of attribute paths whose values are ``Ndarray`` instances at -# first observation. Populated lazily by ``_struct_nd_paths_for`` on the first call with each new data_oriented (or -# nested dataclass) class. Empty list means "this class holds no ndarrays anywhere", in which case subsequent calls -# pay only a dict-lookup per arg. Non-empty list short-circuits the full ``vars()`` recursion and just resolves each -# cached path via ``getattr`` chains. Critical for the genesis field-backend hot path: the ``@qd.data_oriented`` -# Solver is passed as ``self`` to most kernels and holds dozens of attributes, so a full per-call ``vars()`` walk -# costs >100ns per kernel and trashed FPS until this cache was added. +# Per-instance cache of ndarray attribute paths, stashed on the instance via ``object.__setattr__`` (compatible with +# frozen dataclasses). Used by both ``TemplateMapper.lookup``'s args_hash walk and the ``_extract_arg`` data_oriented +# descriptor walk. Per-instance caching is necessary because @qd.data_oriented classes can have *different attribute +# structures across instances of the same class* — Genesis ``DataManager``, for instance, only allocates +# ``*_adjoint_cache`` members when ``requires_grad=True``. A class-level cache populated from the first-ever instance +# would either crash on missing attributes (forward direction, "first instance has, second misses") or silently miss +# new ones (inverse direction), both of which produce wrong-shape kernel reuse. +# +# Steady-state cost: one ``__dict__`` lookup per arg per call (~30ns), same order as the previous class-level +# ``dict.get``. The walk itself (``_build_struct_nd_paths``) is paid once per instance lifetime at first kernel +# launch with that instance — typically O(10) instances per Genesis scene, so ~10us total at scene build. +# +# ``_struct_nd_paths_cache`` (below) is a fallback for ``__slots__`` classes that have no ``__dict__`` and so can't +# accept the ``object.__setattr__`` stash. Such classes inherit the legacy per-class-cache behaviour (and its +# polymorphic-instance limitations). Genesis data_oriented containers don't use ``__slots__``, so this branch is +# unreachable in practice. _struct_nd_paths_cache: dict[type, list[tuple]] = {} @@ -120,21 +129,41 @@ def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: "set[int] def _struct_nd_paths_for(arg: Any) -> list[tuple]: - """Return the cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` instances are - reachable from ``arg`` of type ``type(arg)``. First call for a class walks ``arg`` once via - ``_build_struct_nd_paths``; subsequent calls are dict-lookups. + """Return the per-instance cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` + instances are reachable from ``arg``. First call walks ``arg`` once via ``_build_struct_nd_paths`` and stashes + the result on the instance as ``_qd_nd_paths`` (via ``object.__setattr__`` so it works for frozen dataclasses + and ``@qd.data_oriented`` containers alike); subsequent calls fetch it via instance ``__dict__`` lookup. - Trades freshness for speed: assumes the *set* of ndarray-holding attribute paths is stable across instances of - the same class. The genesis Solver and similar ``@qd.data_oriented`` containers satisfy this — their ndarray - members are declared in ``__init__`` and not added later. If you need to add an ndarray attribute after the first - kernel launch on an instance of a given class, the new attribute won't be tracked. Call ``invalidate_struct_nd_ - paths_for`` (below) or restart the program. + Per-instance caching is correctness-load-bearing: ``@qd.data_oriented`` classes can have different attribute + sets across instances of the same class (e.g. Genesis ``DataManager`` with vs without ``requires_grad``), so a + per-class cache populated from one instance can't be reused for another. ``__slots__`` classes without a + ``__dict__`` fall back to per-class caching (see ``_struct_nd_paths_cache``) and retain the legacy limitation. + + Limitation: the path list is recorded once per instance. If a new ndarray attribute is attached to an instance + *after* its first kernel call (uncommon — Genesis containers declare all ndarrays in ``__init__``), it won't be + tracked until the cache is invalidated. Workaround: ``del arg.__dict__['_qd_nd_paths']`` (or restart the + process). """ + # Fast path: instance already walked. ``__dict__["…"]`` skips descriptor / ``__getattr__`` machinery (some + # third-party metaclasses, e.g. Pydantic, recurse infinitely on probe-style ``getattr`` for unknown names — + # see ``is_data_oriented`` for the same defensiveness). + try: + return arg.__dict__["_qd_nd_paths"] + except (AttributeError, KeyError): + pass + # ``__slots__`` fallback or first-sighting of this instance: check the class-level cache too, so that a + # ``__slots__`` class doesn't re-walk on every call. cls = type(arg) paths = _struct_nd_paths_cache.get(cls) - if paths is None: - paths = [] - _build_struct_nd_paths(arg, (), paths) + if paths is not None: + return paths + paths = [] + _build_struct_nd_paths(arg, (), paths) + try: + object.__setattr__(arg, "_qd_nd_paths", paths) + except AttributeError: + # ``__slots__`` class without a ``_qd_nd_paths`` slot — degrade to per-class caching. Loses correctness + # under polymorphic-instance attribute structure, but Genesis data_oriented containers don't use slots. _struct_nd_paths_cache[cls] = paths return paths @@ -144,13 +173,12 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding ndarrays — see the data_oriented branch in ``_extract_arg``. """ - # The path cache is keyed on ``type(arg)`` and assumes the *set* of ndarray-reachable attribute chains is stable - # across instances of the same class. That holds for the typical ``@qd.data_oriented`` container, but Genesis - # ``FEMSolver`` / ``MPMSolver`` / ``SPHSolver`` and similar can hold polymorphic children (e.g. ``self.material`` - # of a different concrete subclass) or swap a ``qd.Tensor``'s underlying impl between an ``Ndarray`` and a - # ``MatrixField``. When the leaf at a cached path is no longer an ``Ndarray`` we silently skip it: - # ``v.element_type`` / ``v.shape`` / ``v._qd_layout`` are Ndarray-only accessors. The per-instance ``weakref(arg)`` - # part of the spec key still ensures correct cache discrimination across instances. + # The path cache is per-instance (see ``_struct_nd_paths_for``) so polymorphic-instance attribute structure is + # handled correctly. Within a single instance's lifetime, a cached path's leaf may still cease to be an + # ``Ndarray`` (e.g. ``qd.Tensor``'s underlying impl swapped between an ``Ndarray`` and a ``MatrixField``); when + # that happens we silently skip the descriptor — ``v.element_type`` / ``v.shape`` / ``v._qd_layout`` are + # Ndarray-only accessors. The per-instance ``weakref(arg)`` part of the spec key still ensures correct cache + # discrimination across instances. for chain in _struct_nd_paths_for(arg): v = arg for a in chain: diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index 5b8283e8b8..cc0b2d5206 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -940,13 +940,12 @@ def run(s: qd.template()): @test_utils.test(arch=qd.cpu) def test_data_oriented_polymorphic_attr_across_instances(): - """The path cache in ``_struct_nd_paths_cache`` is keyed on ``type(arg)`` and assumes the set of ndarray- - reachable attribute chains is stable across instances. Some real-world ``@qd.data_oriented`` containers (Genesis - FEMSolver / MPMSolver / SPHSolver, etc.) hold polymorphic children whose types differ between instances — e.g. - ``self.material.x`` is an ``Ndarray`` on instance A and a ``qd.field`` (``MatrixField``) on instance B. - ``_collect_struct_nd_descriptors`` walks cached paths verbatim and must not crash with ``'MatrixField' object has - no attribute 'element_type'`` when a path's leaf is no longer an ``Ndarray``; it should silently skip the stale - entry.""" + """Some real-world ``@qd.data_oriented`` containers (Genesis FEMSolver / MPMSolver / SPHSolver, etc.) hold + polymorphic children whose types differ between instances — e.g. ``self.material.x`` is an ``Ndarray`` on + instance A and a ``qd.field`` (``MatrixField``) on instance B. The per-instance path cache walks each instance + fresh, but ``_collect_struct_nd_descriptors`` must additionally tolerate a path's leaf no longer being an + ``Ndarray`` *within a single instance's lifetime* (e.g. ``qd.Tensor`` impl swap), and silently skip the stale + entry rather than crash on ``v.element_type``.""" N = 4 @qd.data_oriented @@ -980,6 +979,54 @@ def run_field(s: qd.template()): run_field(state_b) +@test_utils.test(arch=qd.cpu) +def test_data_oriented_polymorphic_attribute_set_across_instances(): + """Models the Genesis ``DataManager`` failure mode: a ``@qd.data_oriented`` class whose ``__init__`` conditionally + allocates attributes based on a construction flag. Different instances of the same class then have different + attribute *sets* (not just different value types at the same paths). + + With a per-class path cache populated from the first instance walked, this would either AttributeError when the + second instance lacks an attribute the first had (forward direction) or silently miss an ndarray the second + instance has but the first didn't (inverse direction). Per-instance caching walks each instance fresh so both + directions work.""" + N = 4 + + @qd.data_oriented + class PolyState: + def __init__(self, with_extra: bool): + self.x = qd.ndarray(qd.i32, shape=(N,)) + if with_extra: + self.extra = qd.ndarray(qd.i32, shape=(N,)) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + # Forward direction: first instance has 'extra', second doesn't. Used to AttributeError on the cached + # ('extra',) path when running with state_lean. + state_full = PolyState(with_extra=True) + run(state_full) + state_lean = PolyState(with_extra=False) + run(state_lean) + np.testing.assert_array_equal(state_lean.x.to_numpy(), np.arange(1, N + 1)) + + # Inverse direction: a different class so per-class cache (if used by __slots__ fallback) starts fresh; first + # instance lacks 'extra', second has it. Verifies the second instance's 'extra' ndarray is correctly walked. + @qd.data_oriented + class PolyState2: + def __init__(self, with_extra: bool): + self.x = qd.ndarray(qd.i32, shape=(N,)) + if with_extra: + self.extra = qd.ndarray(qd.i32, shape=(N,)) + + state_lean2 = PolyState2(with_extra=False) + run(state_lean2) + state_full2 = PolyState2(with_extra=True) + run(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.arange(1, N + 1)) + + @test_utils.test(arch=qd.cpu) def test_data_oriented_with_cyclic_attr_graph(): """A ``@qd.data_oriented`` class whose attribute graph contains a cycle (``parent.child.parent is parent``). From c61d32c8d4d6b46a17921a3070aa9a008e0792ee Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 06:54:06 -0700 Subject: [PATCH 43/46] [Test] Strengthen polymorphism + add cache-hit predeclare ndarray test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ``test_data_oriented_polymorphic_attribute_set_across_instances``: the inverse-direction case now uses a kernel that *reads* ``s.extra`` (the conditional attribute) — without the per-instance walk this would silently miss ``('extra',)`` from the kernel-used path list. Adds a reassignment step that verifies same-shape ndarray swaps go through the per-call ``id(v)`` folding cleanly. - ``test_src_ll_cache_hit_predeclare_struct_ndarrays_pruned``: pins ``710ee4705``. A data_oriented arg with three ndarrays (``a``/``b``/``c``) but a kernel that only writes ``b``. Cold compile populates the fastcache with the flat-name pruning set; ``qd.reset()`` + ``qd.init()`` reloads it; cache-hit branch in ``_predeclare_struct_ndarrays`` must reproduce the same single-ndarray registration set, otherwise insertion-order registration would scramble slots and the write would land in ``state.a`` instead of ``state.b``. --- .../lang/fast_caching/test_src_ll_cache.py | 57 +++++++++++++++++++ tests/python/test_data_oriented_ndarray.py | 29 +++++++++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 217c66ac30..79ac5316b6 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -441,6 +441,63 @@ def k1(self) -> tuple[qd.i32, qd.i32]: assert my_do.k1._primal.src_ll_cache_observations.cache_validated +@test_utils.test() +def test_src_ll_cache_hit_predeclare_struct_ndarrays_pruned(tmp_path: pathlib.Path) -> None: + """Pin the cache-hit fix for ``_predeclare_struct_ndarrays``: on a fastcache hit pass 0 is skipped so the + ``id(nd)``-keyed used-ndarray set is empty; without flat-name fallback pruning every reachable ndarray gets + registered, scrambling the kernel's arg-slot bindings (e.g. a kernel compiled to write ``state.b`` ends up + writing ``state.a`` at launch). The fix uses the cached ``used_vars_by_func_id[KERNEL_FUNC_ID]`` flat-name + set to gate registration on the cache-hit branch, reproducing the exact ndarray set the originating compile + produced. + + The test exercises both the cold (cache-store) and hot (cache-load) paths in the same process via + ``qd.reset()`` cycles, and asserts both that the ndarray the kernel writes to is the *correct* one and that + the other ndarrays are untouched — without the fix the value would land in ``state.a`` (the first + insertion-order ndarray) instead of ``state.b``. + """ + import numpy as np # local import keeps the test module's top-level deps unchanged + + arch = getattr(qd, qd.lang.impl.current_cfg().arch.name) + N = 4 + + @qd.data_oriented + class State: + def __init__(self) -> None: + self.a = qd.ndarray(qd.i32, shape=(N,)) + self.b = qd.ndarray(qd.i32, shape=(N,)) + self.c = qd.ndarray(qd.i32, shape=(N,)) + + @qd.pure + @qd.kernel + def write_b(s: qd.template()) -> None: + for i in range(N): + s.b[i] = (i + 1) * 17 + + # Cold: cache-miss path populates the fastcache (including the kernel-used flat-name set folded in by + # ``_fold_struct_nd_paths_into_pruning``). + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + state = State() + write_b(state) + assert write_b._primal.src_ll_cache_observations.cache_key_generated + assert not write_b._primal.src_ll_cache_observations.cache_loaded + np.testing.assert_array_equal(state.b.to_numpy(), np.array([17, 34, 51, 68], dtype=np.int32)) + np.testing.assert_array_equal(state.a.to_numpy(), np.zeros(N, dtype=np.int32)) + np.testing.assert_array_equal(state.c.to_numpy(), np.zeros(N, dtype=np.int32)) + + # Hot: cache-hit path skips pass 0; this is the branch the fix protects. Without flat-name pruning all three + # ndarrays would be registered in insertion order, displacing ``state.b`` from the slot the kernel was + # compiled to write — and the write would land in ``state.a`` instead. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + state = State() + write_b(state) + assert write_b._primal.src_ll_cache_observations.cache_loaded, "expected a fastcache hit on the second run" + np.testing.assert_array_equal(state.b.to_numpy(), np.array([17, 34, 51, 68], dtype=np.int32)) + np.testing.assert_array_equal(state.a.to_numpy(), np.zeros(N, dtype=np.int32)) + np.testing.assert_array_equal(state.c.to_numpy(), np.zeros(N, dtype=np.int32)) + + class ModifySubFuncKernelArgs(pydantic.BaseModel): arch: str offline_cache_file_path: str diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index cc0b2d5206..46a0943c3b 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -1012,7 +1012,8 @@ def run(s: qd.template()): np.testing.assert_array_equal(state_lean.x.to_numpy(), np.arange(1, N + 1)) # Inverse direction: a different class so per-class cache (if used by __slots__ fallback) starts fresh; first - # instance lacks 'extra', second has it. Verifies the second instance's 'extra' ndarray is correctly walked. + # instance lacks 'extra', second has it. The kernel actually *reads* ``s.extra`` so the inverse-direction + # silent miscache (which only manifests when the kernel touches the conditional attr) is exercised end-to-end. @qd.data_oriented class PolyState2: def __init__(self, with_extra: bool): @@ -1020,11 +1021,33 @@ def __init__(self, with_extra: bool): if with_extra: self.extra = qd.ndarray(qd.i32, shape=(N,)) + @qd.kernel + def run_using_extra(s: qd.template()): + for i in range(N): + s.x[i] = s.extra[i] * 10 + + # Walk the lean instance first (no 'extra'), populating any per-class state with the *narrow* attribute set. + # With the old per-class cache, this would lock in paths = [('x',)] for the class — and the next instance's + # ``extra`` would be silently absent from args_hash and from the kernel spec, leading to a wrong-shape kernel + # or a stale-cache hit when ``extra`` is later reassigned. state_lean2 = PolyState2(with_extra=False) run(state_lean2) + np.testing.assert_array_equal(state_lean2.x.to_numpy(), np.arange(1, N + 1)) + + # Now the polymorphic-attr-bearing instance. The per-instance walk must include ``('extra',)`` so that + # ``state_full2.extra``'s shape/id participates in the spec and the kernel compiles correctly. state_full2 = PolyState2(with_extra=True) - run(state_full2) - np.testing.assert_array_equal(state_full2.x.to_numpy(), np.arange(1, N + 1)) + state_full2.extra.from_numpy(np.array([2, 3, 5, 7], dtype=np.int32)) + run_using_extra(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.array([20, 30, 50, 70], dtype=np.int32)) + + # Reassignment-detection check: swap ``state_full2.extra`` to a different ndarray. The per-instance walk caches + # the *path list* ([('x',), ('extra',)]) on the instance, but the per-call args_hash still folds in + # ``id(getattr(state_full2, 'extra'))`` — so a swap should miss the spec-key cache and re-specialise. + state_full2.extra = qd.ndarray(qd.i32, shape=(N,)) + state_full2.extra.from_numpy(np.array([11, 13, 17, 19], dtype=np.int32)) + run_using_extra(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.array([110, 130, 170, 190], dtype=np.int32)) @test_utils.test(arch=qd.cpu) From 706f9b51cedfb2aa7299f9a8a13f6023ff3b066f Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 09:06:58 -0700 Subject: [PATCH 44/46] [Test] Add bug reproducer: needs_grad not folded into fastcache args_hash MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pins the L2 collision between needs_grad=False (cold) and needs_grad=True (hot) scenes that differ only on the .grad-present flag. ``args_hasher.stringify_obj_type`` stringifies ndarray leaves by (dtype, ndim) only, so the narrow args_hash is the same and the second scene loads the without-grad artifact — the kernel's compiled parameter slot has needs_grad=False baked in but the launch routes the with-grad ndarray through the _QD_ARRAY_WITH_GRAD bucket, mis-aligning the parameter struct (silent wrong results or runtime OOB depending on slot layout). Test FAILS on this commit (asserts cache_loaded is False after the with-grad launch; observed True with the unfixed args_hasher). Fix to follow in next commit. --- .../lang/fast_caching/test_src_ll_cache.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 79ac5316b6..341a005b1d 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -441,6 +441,79 @@ def k1(self) -> tuple[qd.i32, qd.i32]: assert my_do.k1._primal.src_ll_cache_observations.cache_validated +@test_utils.test() +def test_src_ll_cache_needs_grad_distinguishes_args_hash(tmp_path: pathlib.Path) -> None: + """Pin: fastcache narrow args_hash MUST fold in ``needs_grad`` for every ndarray leaf. Without this, two scenes + that differ only by whether their ndarrays carry ``.grad`` (e.g. Genesis ``requires_grad=True`` vs ``False``) + collide on the L2 key, and the second scene loads the artifact compiled with the first scene's needs_grad + flag. The kernel's compiled parameter slots have a fixed needs_grad (``insert_ndarray_param`` bakes it into + the struct type), and the launch path branches on ``v.grad is not None`` to pick between ``_QD_ARRAY`` and + ``_QD_ARRAY_WITH_GRAD`` buckets — bind a needs_grad=True ndarray to a slot declared without grad and the + parameter struct's primal pointer ends up at the wrong offset, producing silent wrong results or runtime OOB. + + Reproduces the Genesis pattern (``kernel_init_link_fields`` taking a frozen-dataclass ``LinksState`` whose + members carry ``needs_grad`` from the scene's ``requires_grad``) with the smallest possible surface: a frozen + dataclass with two ``qd.f32`` ndarray members, a kernel that writes only the second one. First process compiles + without grad and stores L1+L2; second process (via ``qd.reset()`` + ``qd.init()``) runs the same kernel with + ``needs_grad=True`` members and asserts the second result is correct *and* that the L2 entry was a miss + (so the per-call needs_grad is correctly part of the cache key). + """ + import dataclasses + import numpy as np + + arch = getattr(qd, qd.lang.impl.current_cfg().arch.name) + N = 4 + + @dataclasses.dataclass(frozen=True) + class State: + a: qd.types.NDArray[qd.f32, 1] + b: qd.types.NDArray[qd.f32, 1] + + @qd.pure + @qd.kernel + def write_b(s: State) -> None: + for i in range(N): + s.b[i] = qd.cast(i + 1, qd.f32) * 7.0 + + # Cold run: needs_grad=False (default). Populates L1 (pruning info) + L2 (artifact compiled with the slot for + # ``s.b`` declared needs_grad=False) using the narrow args_hash from ``stringify_obj_type`` on the without-grad + # ndarray ``[nd-f32-1]``. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + a1 = qd.ndarray(qd.f32, shape=(N,)) + b1 = qd.ndarray(qd.f32, shape=(N,)) + state1 = State(a=a1, b=b1) + write_b(state1) + assert write_b._primal.src_ll_cache_observations.cache_key_generated + assert not write_b._primal.src_ll_cache_observations.cache_loaded + expected = np.array([7, 14, 21, 28], dtype=np.float32) + np.testing.assert_allclose(b1.to_numpy(), expected) + + # Hot run: needs_grad=True. With the bug, ``stringify_obj_type`` yields the same ``[nd-f32-1]`` string for the + # with-grad ndarray, the narrow args_hash collides, and L2 returns the without-grad artifact. The launch path + # then routes ``b2`` through ``_QD_ARRAY_WITH_GRAD`` because ``b2.grad`` is not None, against a slot the + # cached kernel declared as plain ``_QD_ARRAY`` — silent miscomputation or OOB. + # + # After the fix, the args_hash differs (needs_grad folded into the ndarray descriptor), L2 misses, the kernel + # is recompiled with the correct needs_grad=True slot, and the launch is well-typed. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + a2 = qd.ndarray(qd.f32, shape=(N,), needs_grad=True) + b2 = qd.ndarray(qd.f32, shape=(N,), needs_grad=True) + state2 = State(a=a2, b=b2) + write_b(state2) + # Diagnostic: the L2 must NOT load the no-grad artifact. After the fix this is a cache miss. + assert not write_b._primal.src_ll_cache_observations.cache_loaded, ( + "fastcache hit between needs_grad=False (cold) and needs_grad=True (hot) — narrow args_hash is " + "missing needs_grad, the without-grad artifact will be launched against with-grad ndarrays" + ) + # Correctness: the kernel writes the expected values, regardless of cache state. + np.testing.assert_allclose(b2.to_numpy(), expected) + # ``b2.grad`` is allocated but not written by this kernel — sanity check it survived as zero (i.e. the + # launch didn't smear primal data into the grad slot via a misaligned param struct). + np.testing.assert_allclose(b2.grad.to_numpy(), np.zeros(N, dtype=np.float32)) + + @test_utils.test() def test_src_ll_cache_hit_predeclare_struct_ndarrays_pruned(tmp_path: pathlib.Path) -> None: """Pin the cache-hit fix for ``_predeclare_struct_ndarrays``: on a fastcache hit pass 0 is skipped so the From 4398af7d604d0ca5b90aa6c92ec514842fc1412c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 09:13:25 -0700 Subject: [PATCH 45/46] [Fix] Fold needs_grad into fastcache narrow args_hash for ndarray leaves MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``ScalarNdarray``/``VectorNdarray``/``MatrixNdarray`` instances now stringify with an extra ``-g`` tag when their grad buffer is present. needs_grad is part of the compiled parameter-struct layout (``insert_ndarray_param`` bakes the grad pointer into the slot iff needs_grad=True), and the launch path picks between ``_QD_ARRAY`` and ``_QD_ARRAY_WITH_GRAD`` buckets off ``v.grad is not None`` — so two scenes that differ only by needs_grad MUST hash distinctly, otherwise L2 returns an artifact whose slots are mismatched at launch (silent miscomputation or runtime OOB depending on slot offset alignment). This is the root cause of the Genesis ``test_diff_*`` autodiff failures: the non-grad ``kernel_init_link_fields`` artifact landed in L2 first; the ``requires_grad=True`` run loaded that artifact and routed ``links_state.quat`` through ``_QD_ARRAY_WITH_GRAD`` against a slot declared without grad, producing the "Out of bound access to ndarray at arg 44 with indices [0,0,0]" assertion. Reproducer test was added in the previous commit; it now passes on x64, vulkan and cuda. Full fast_caching + test_data_oriented_ndarray + test_ad_dataclass suite: 257 passed, 6 skipped. --- .../quadrants/lang/_fast_caching/args_hasher.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 1792f678ee..b8967412df 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -290,10 +290,18 @@ def stringify_obj_type( arg_type = type(obj) _layout = getattr(obj, "_qd_layout", None) _layout_tag = "" if _layout is None else f"-L{_layout!r}" + # needs_grad is part of the parameter struct layout that ``insert_ndarray_param`` bakes into the compiled + # artifact (the slot includes a grad pointer iff needs_grad=True). Two ndarrays with identical dtype + ndim + # but differing needs_grad MUST hash distinctly, otherwise the L2 narrow args_hash collides and the cached + # artifact's slot is mis-matched at launch (the launch picks the _QD_ARRAY vs _QD_ARRAY_WITH_GRAD bucket + # off ``v.grad is not None``, against a slot whose grad-presence was fixed at compile time) — yielding + # silent miscomputation or runtime OOB depending on slot offset alignment. if isinstance(obj, ScalarNdarray): - return f"[nd-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[nd-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, VectorNdarray): - return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): # disabled for now, because we need to think about how to handle field offset # etc @@ -301,7 +309,8 @@ def stringify_obj_type( _mark_warn_if_not_tensor_annotation(arg_meta) return _FAIL_FASTCACHE if isinstance(obj, MatrixNdarray): - return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, torch_type): return f"[pt-{obj.dtype}-{obj.ndim}]" # type: ignore if isinstance(obj, np.ndarray): From 8a7ead44eab70cf965f1b5f9b5856e9ee8c10763 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 19 May 2026 11:35:18 -0700 Subject: [PATCH 46/46] [Lint] Reorder imports in needs_grad reproducer test --- tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 341a005b1d..819f9e3701 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -459,6 +459,7 @@ def test_src_ll_cache_needs_grad_distinguishes_args_hash(tmp_path: pathlib.Path) (so the per-call needs_grad is correctly part of the cache key). """ import dataclasses + import numpy as np arch = getattr(qd, qd.lang.impl.current_cfg().arch.name)