diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 7b942e4cab..9c6a8aed88 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -196,7 +196,9 @@ state.step() ### Fastcache -`@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes, but is disabled for fields; see [Advanced — compound-type cache keying](fastcache.md#compound-type-cache-keying) for more information. +`@qd.kernel(fastcache=True)` is supported on methods of `@qd.data_oriented` classes. Cache keying is *pruning-driven*: only the members the kernel actually reads contribute to the cache key. Opaque metadata members (e.g. UUIDs, Pydantic config objects, back-pointers to parent solvers) are skipped by the args-hasher's narrow walk as long as the kernel doesn't read them — they cannot affect compiled code and cannot cause spurious cache misses. See [Pruning-driven argument hashing](fastcache.md#pruning-driven-argument-hashing) for the full keying rules. + +If the kernel *does* read a member of an unrecognised type, fastcache is disabled for the call with a `[FASTCACHE][UNKNOWN_TYPE]` diagnostic — there is no qualname-fallback. `qd.field` / `MatrixField` members at a kernel-read path likewise disable fastcache (recognised-but-unsupported tensor-like types). 
### Under the hood diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index 5d4e9381c8..b5dee08bad 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -96,13 +96,16 @@ Fastcache supports the following parameter types: | `torch.Tensor` | Yes | dtype, ndim | | `numpy.ndarray` | Yes | dtype, ndim | | `dataclasses.dataclass` | Yes | member types recursively; member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | -| `@qd.data_oriented` objects | Yes | member types recursively; primitive member types and values baked into kernel (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | +| `@qd.data_oriented` objects | Yes | member types recursively, narrowed by pruning (see [Pruning-driven argument hashing](#pruning-driven-argument-hashing)); primitive member types and values baked into kernel (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | | `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) | | Non-template primitives (int, float, bool) | Yes | type only | | `enum.Enum` | Yes | name and value | -| `qd.field` / `ScalarField` / `MatrixField` | **No** | — | +| `qd.field` / `ScalarField` / `MatrixField` at a kernel-read path | **No** | — | +| Anything else at a kernel-read path | **No** | — | -If any parameter is of an unsupported type, fastcache is disabled for that call and the kernel falls back to normal compilation. For `qd.field` / `ScalarField` / `MatrixField` arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted. For other unsupported types, a warning is logged at the `warn` level identifying the offending parameter. +If any kernel-used parameter is of an unsupported type, fastcache is disabled for that call and the kernel falls back to normal compilation. 
For `qd.field` / `ScalarField` / `MatrixField` arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted. For other unsupported types, a warning is logged at the `warn` level identifying the offending parameter. + +Kernel-unused members of any type — including unrecognised ones — do **not** disable fastcache. The pruning narrowing in the args hasher skips them entirely, so opaque metadata (UUIDs, Pydantic configs, parent back-pointers) attached to a `@qd.data_oriented` instance is harmless as long as the kernel doesn't read it. ### 3. Source code must be available @@ -120,6 +123,34 @@ Each compiled artifact is stored under a key derived from all of the following: When any of these change, the resulting key is different, so a new compilation occurs and a new entry is stored. Previous entries remain on disk — multiple cached versions coexist. You do not need to manually clear the cache when making code changes — the hash mismatch causes a transparent recompilation. +### Pruning-driven argument hashing + +Fastcache uses a **two-level cache**: + +- **L1** (source + config only): stores the set of *flat names* the kernel actually reads — e.g. `__qd_state__qd_x` for a kernel that reads `state.x`. This is the kernel's pruning info, computed at compile time by the AST builder. +- **L2** (L1 + narrow argument hash): stores the compiled artifact under a key that only hashes the arg paths in the L1 pruning set. + +#### Two rules + +The args hasher enforces two strict invariants: + +1. **The cache key may only include contributions from kernel-pruned paths.** A path is "kernel-pruned" (that is, it appears in the kernel's pruning info) if the Quadrants compiler recorded the kernel reading it. Pruning info covers: + - Dataclass-flattened param accesses (`__qd_some_dc__qd_field` Names produced by `FlattenAttributeNameTransformer`, marked via `build_Name`). 
+ - Ndarray accesses on data_oriented / template args (`struct_ndarray_launch_info`, folded into the pruning set by `Kernel._fold_struct_nd_paths_into_pruning`). + - Any other attribute-chain access on a kernel arg (`self.dofs.x`, `cfg.n`, …) recorded by `ASTTransformer.build_Attribute`'s `_qd_arg_chain` tracking and folded by `Kernel._fold_kernel_arg_chain_paths_into_pruning`. This covers primitive members baked into the kernel, nested struct paths, and accesses through `@qd.func` callees (propagated by `Pruning.record_after_call`). + + Paths *not* in the pruning set are skipped by the args hasher — they are guaranteed not to affect kernel codegen because the kernel cannot read them. + +2. **Unrecognised types at kernel-read paths must not be silently dropped or hashed by type-name.** If pruning says the kernel reads a path and the value at that path is a type the args hasher doesn't explicitly handle (Pydantic models, UUIDs, third-party tensor wrappers, …), fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning identifying the offending type plus an `[INVALID_FUNC]` log line confirming the cache is off. Capturing type identity without type parameters (dtype/shape on a hypothetical tensor type) would silently mask a value-affecting change. + +#### Practical implications + +- **Kernel-unused members do not affect the cache key.** If a `@qd.data_oriented` container has an opaque metadata member (UUID, Pydantic config, parent-solver back-pointer), and the kernel never reads it, the member is *not* hashed. Changes to `self._uid` or `self.cfg` don't disturb the cache key of a kernel that reads `self.dofs_state.x`. +- **Kernel-unused members of unrecognised types are also fine.** Pruning narrowing skips them before the type-recognition check runs. 
+- **Kernel-read members of unrecognised types fail fastcache loudly.** Either add explicit handling in `quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type` (for new tensor-like types whose dtype/shape matter), or move the access out of the kernel-read path (for opaque metadata that shouldn't be there in the first place). + +`qd.field` / `ScalarField` / `MatrixField` are *recognised-but-unsupported*: encountering one at a kernel-read path disables fastcache for the call (with a warn-level diagnostic). + ## Advanced ### Diagnostics @@ -147,7 +178,7 @@ On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules: -**`@qd.data_oriented`:** the walker descends into `vars(obj)`. For each child: +**`@qd.data_oriented`:** the walker descends into `vars(obj)`, narrowed by pruning info. For each walked child: - `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not. - Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. 
diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 1a949d3007..b8967412df 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -11,11 +11,13 @@ from quadrants._tensor_wrapper import Tensor as _TensorWrapper from quadrants.types.annotations import Template +from .._dataclass_util import create_flat_name from .._ndarray import ScalarNdarray +from .._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from ..field import ScalarField from ..kernel_arguments import ArgMetadata from ..matrix import MatrixField, MatrixNdarray, VectorNdarray -from ..util import is_data_oriented +from ..util import is_data_oriented, is_dataclass_instance from .hash_utils import hash_iterable_strings _FIELD_TYPES = (ScalarField, MatrixField) @@ -40,6 +42,26 @@ _DC_REPR_NONE = object() +# Sentinel returned by ``stringify_obj_type`` whenever fastcache cannot safely hash a value: +# - Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). +# - Unrecognised type at a kernel-read path (no qualname fallback — see rules in fastcache.md). +# +# Containers (``dataclass_to_repr``, ``data_oriented`` branch, top-level ``hash_args`` loop) must propagate it upward +# — fastcache is disabled for the whole call and the caller writes the appropriate diagnostic. +class _FailFastcache: + """Singleton sentinel; identity-compared.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +_FAIL_FASTCACHE = _FailFastcache() + + class FastcacheSkip(enum.Enum): """Why fastcache does not apply to this call.""" @@ -47,11 +69,23 @@ class FastcacheSkip(enum.Enum): WARN = "warn" -# Set when the fastcache skip is something callers should warn about (as opposed to a Field arriving through a -# qd.Tensor annotation, which is a normal silent path). 
Reset at the start of each hash_args call. +# Set when the fastcache skip is something callers should warn about (as opposed to a ``Field`` arriving through a +# ``qd.Tensor`` annotation, which is a normal silent path). Reset at the start of each ``hash_args`` call. _should_warn = False +# Set of fully-qualified type names (``module.qualname``) we've already emitted the "unknown type at a kernel-read path" +# warning for. Lets the loop run thousands of times without spamming the log while still telling the user once +# that fastcache encountered an unrecognised type. Cleared by ``reset_unknown_type_warn_state`` (called from +# ``qd.init``) so each new test sees a clean log. +_warned_unknown_types: set[str] = set() + + +def reset_unknown_type_warn_state() -> None: + """Clear the once-per-process warned-unknown-types set. Called from test setup / ``qd.init``.""" + _warned_unknown_types.clear() + + def _mark_warn_if_not_tensor_annotation(arg_meta: ArgMetadata | None) -> None: """Flag that a warning is needed if the Field didn't arrive through a qd.Tensor annotation.""" global _should_warn # pylint: disable=global-statement @@ -64,40 +98,131 @@ def _mark_should_warn() -> None: _should_warn = True -def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | None: - # PERF: For frozen dataclasses, the repr never changes. Cache it on the instance to avoid repeated +def _fail_unknown_type(obj: object, path: tuple[str, ...]) -> _FailFastcache: + """Disable fastcache for the call when an unrecognised type appears at a kernel-read path. + + Two rules at work here (see ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing"): + + 1. The fastcache key may *only* contain contributions from kernel-pruned paths — never a + ``type(v).__qualname__`` fallback for an unrecognised type, because that hash captures type identity + only and would silently mask a value-affecting change (e.g. a new tensor-like type whose dtype matters). + + 2. 
We may not silently *discard* something at a kernel-read path on the basis that it's unrecognised — + that would let unrecognised but codegen-affecting values escape the cache key and serve stale results. + + The only way to honour both rules is to fail the call's fastcache loudly, with a one-shot warning per type + so the user can add explicit handling in ``stringify_obj_type``. + """ + t = type(obj) + qualname = f"{getattr(t, '__module__', '')}.{getattr(t, '__qualname__', t.__name__)}" + if qualname not in _warned_unknown_types: + _warned_unknown_types.add(qualname) + _logging.warn( + f"[FASTCACHE][UNKNOWN_TYPE] Unrecognised type {qualname} reached at kernel-read path {path}. " + f"Fastcache is disabled for this call. Add explicit handling for this type to " + f"``quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type``, or refactor the kernel " + f"so it does not read this member." + ) + _mark_should_warn() + return _FAIL_FASTCACHE + + +def _child_flat(parent_flat: str | None, child_name: str) -> str | None: + """Compute the flat name a kernel parameter would have if it pointed at this container's child. + + For a top-level arg ``state`` with child ``x``: ``__qd_state__qd_x``. + For a deeper child ``state.dofs.x``: ``__qd_state__qd_dofs__qd_x`` (built incrementally). + + ``parent_flat`` is the *kernel-side* representation of this container's root: + - top-level arg of a kernel: ``arg_meta.name`` (e.g. ``"state"``, ``"self"``) — no ``__qd_`` prefix. + - any nested level: the already-computed ``__qd_…`` flat name. + + Returns ``None`` when ``parent_flat`` itself is ``None``, indicating "no path info available" — the caller + must walk the child unconditionally (i.e. ignore ``pruning_paths`` for this branch). + """ + if parent_flat is None: + return None + return create_flat_name(parent_flat, child_name) + + +def _is_path_used(pruning_paths: set[str] | None, child_flat: str | None) -> bool: + """Return True if a child at ``child_flat`` should be hashed. 
+ + - ``pruning_paths is None``: pre-pruning-info compile — hash everything. + - ``child_flat is None``: caller could not compute a flat-name path (no parent_flat available) — hash + everything as well, so we never accidentally drop a child we couldn't classify. + - both non-None: only hash children whose flat name is in the set. Pruning's prefix-expansion step in + ``Kernel.materialize`` guarantees that if any descendant of ``__qd_a__qd_b`` is used, ``__qd_a__qd_b`` + itself is also in the set, so this single membership check is sufficient to decide whether to descend. + """ + if pruning_paths is None or child_flat is None: + return True + return child_flat in pruning_paths + + +def dataclass_to_repr( + raise_on_templated_floats: bool, + path: tuple[str, ...], + arg: Any, + pruning_paths: set[str] | None = None, + parent_flat: str | None = None, +) -> str | _FailFastcache: + """Hash a dataclass instance, optionally narrowed by pruning information. + + Returns ``_FAIL_FASTCACHE`` if any walked field's subtree cannot be hashed — a recognised-but-unsupported tensor type + (``ScalarField`` / ``MatrixField``) or an unrecognised type at a kernel-read path; otherwise a string. + + Pruning: if ``pruning_paths`` is non-None, only descend into fields whose flat name is in the set. Pruning's + prefix-expansion step ensures the set already contains all ancestors of used leaves, so checking the immediate + child's flat name is sufficient. + """ + # PERF: For frozen dataclasses the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py - # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. - # A cached ``None`` is stored as the sentinel ``_DC_REPR_NONE`` to distinguish "not yet computed" from - # "computed but not fast-cacheable". + # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. 
A + # cached ``_DC_REPR_NONE`` sentinel distinguishes "computed but not fast-cacheable" from "not yet computed". + # + # The cache is keyed by ``(is_frozen, pruning_paths is None)`` because a frozen dataclass's pruned repr depends on + # the pruning_paths set — we use separate cache slots for pruned vs unpruned to avoid serving the wrong narrowing. + cache_attr = "_qd_dc_repr" if pruning_paths is None else "_qd_dc_repr_narrow" is_frozen = type(arg).__hash__ is not None if is_frozen: - cached = getattr(arg, "_qd_dc_repr", None) + cached = getattr(arg, cache_attr, None) if cached is _DC_REPR_NONE: - return None - if cached is not None: + return _FAIL_FASTCACHE + if cached is not None and pruning_paths is None: + # Narrow cache may be stale if pruning_paths set changed; only reuse the unpruned cache. return cached repr_l = [] for field in dataclasses.fields(arg): + child_flat = _child_flat(parent_flat, field.name) + if not _is_path_used(pruning_paths, child_flat): + continue child_value = getattr(arg, field.name) - _repr = stringify_obj_type(raise_on_templated_floats, path + (field.name,), child_value, arg_meta=None) - if _repr is None: + _repr = stringify_obj_type( + raise_on_templated_floats, + path + (field.name,), + child_value, + arg_meta=None, + pruning_paths=pruning_paths, + parent_flat=child_flat, + ) + if _repr is _FAIL_FASTCACHE: if isinstance(child_value, _FIELD_TYPES) and field.type is not _TensorWrapper: _mark_should_warn() if is_frozen: try: - object.__setattr__(arg, "_qd_dc_repr", _DC_REPR_NONE) + object.__setattr__(arg, cache_attr, _DC_REPR_NONE) except AttributeError: pass - return None + return _FAIL_FASTCACHE full_repr = f"{field.name}: ({_repr})" if field.metadata.get(FIELD_METADATA_CACHE_VALUE, False): full_repr += f" = {child_value}" repr_l.append(full_repr) result = "[" + ",".join(repr_l) + "]" - if is_frozen: + if is_frozen and pruning_paths is None: try: - object.__setattr__(arg, "_qd_dc_repr", result) + object.__setattr__(arg, cache_attr, 
result) except AttributeError: pass return result @@ -111,36 +236,53 @@ def _is_template(arg_meta: ArgMetadata | None) -> bool: def stringify_obj_type( - raise_on_templated_floats: bool, path: tuple[str, ...], obj: object, arg_meta: ArgMetadata | None -) -> str | None: - """ - Convert an object into a string representation that only depends on its type. - - String should somehow represent the type of obj. Doesnt have to be hashed, nor does it have - to be the actual python type string, just a string that is representative of the type, and won't collide - with different (allowed) types. String should be non-empty. - - Note that fields are not included in fast cache. - - arg_meta should only be non-None for the top level arguments and for data oriented objects. It is - used currently to determine whether a value is added to the cache key, as well as the name. eg - - at the top level, primitive types have their values added to the cache key if their annotation is qd.Template, - since they are baked into the kernel - - in data oriented objects, the values of all primitive types are added to the cache key, since they are baked - into the kernel, and require a kernel recompilation, when they change + raise_on_templated_floats: bool, + path: tuple[str, ...], + obj: object, + arg_meta: ArgMetadata | None, + pruning_paths: set[str] | None = None, + parent_flat: str | None = None, +) -> str | _FailFastcache: + """Convert ``obj`` into a deterministic string that contributes to the fastcache key. + + Return contract: + - ``str``: hashable; the returned string contributes to the cache key. + - ``_FAIL_FASTCACHE``: fastcache cannot safely hash this value — caller must propagate upward and + disable fastcache for the whole call. Triggered by: + * Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). + * Unrecognised type at this kernel-read path (see ``_fail_unknown_type``). 
+ + Two rules from ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing" govern this function: + + 1. The cache key may *only* include contributions from paths that pruning has marked kernel-accessed + (``pruning_paths``). Container walkers (dataclass + data_oriented) check ``_is_path_used`` per child and + skip non-pruned subtrees — kernel-unread paths are *guaranteed* not to affect codegen so this is safe by + construction. + + 2. At paths the kernel *does* read, unrecognised types must not be silently dropped or hashed by type-name — + fastcache fails the call (loudly, with a one-shot warning) so the gap can be closed. + + Parameters: + - ``arg_meta``: non-``None`` only for top-level kernel args and for ``@qd.data_oriented`` members. Determines + whether primitive values are baked into the cache key (template-position primitives and all primitive members + of data-oriented containers). + - ``pruning_paths``: optional set of kernel-accessed flat names from L1 cache. When provided, + ``dataclass_to_repr`` and the ``data_oriented`` branch below descend only into children whose flat name is in + the set. Pruning info is populated by ``ASTTransformer.build_Name`` / ``build_Attribute`` (kernel-arg-rooted + chains) plus ``Pruning.fold_struct_nd_paths`` (ndarray accesses through data_oriented containers). + - ``parent_flat``: the flat-name prefix for ``obj``'s children (e.g. ``__qd_self`` if ``obj`` is the ``self`` + arg of a data_oriented kernel). Used together with ``pruning_paths`` to compute each child's flat name for + the narrow-walk lookup. """ - # ``qd.Tensor`` wrappers passed as struct fields. 
The top-level kernel-arg unwrap hook in ``Kernel.__call__`` strips - # wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / data-oriented - # walkers below (``dataclass_to_repr`` and the ``is_data_oriented`` branch) do raw ``getattr`` to fetch struct - # fields, so a wrapper stored as a struct field arrives here un-stripped. Without this branch the hasher falls - # through to the ``[FASTCACHE][PARAM_INVALID]`` warning and disables the fast path for the whole call. See - # ``perso_hugh/doc/quadrants-tensor.md`` §8.14. - # ``qd.Tensor`` wrappers: unwrap to the bare impl so the type checks below match. After unwrap, ``_qd_layout`` (if - # any) is on the impl. + # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` + # strips wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / + # data-oriented walkers below do raw ``getattr`` to fetch struct fields, so a wrapper stored as a struct field + # arrives here un-stripped. Without this branch the hasher would hash the wrapper as an unknown type instead of + # unwrapping to the recognised impl. See ``perso_hugh/doc/quadrants-tensor.md`` §8.14. # - # PERF-CRITICAL: The _any_tensor_constructed guard makes this check zero-cost when no qd.Tensor has been created. - # ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer comparison (~10 - # ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. + # PERF-CRITICAL: the ``_any_tensor_constructed`` guard makes this check zero-cost when no ``qd.Tensor`` has been + # created. ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer + # comparison (~10 ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. 
if ( _tensor_wrapper._any_tensor_constructed and type(obj) in _TENSOR_WRAPPER_TYPES ): # pyright: ignore[reportOptionalMemberAccess] @@ -148,18 +290,27 @@ def stringify_obj_type( arg_type = type(obj) _layout = getattr(obj, "_qd_layout", None) _layout_tag = "" if _layout is None else f"-L{_layout!r}" + # needs_grad is part of the parameter struct layout that ``insert_ndarray_param`` bakes into the compiled + # artifact (the slot includes a grad pointer iff needs_grad=True). Two ndarrays with identical dtype + ndim + # but differing needs_grad MUST hash distinctly, otherwise the L2 narrow args_hash collides and the cached + # artifact's slot is mis-matched at launch (the launch picks the _QD_ARRAY vs _QD_ARRAY_WITH_GRAD bucket + # off ``v.grad is not None``, against a slot whose grad-presence was fixed at compile time) — yielding + # silent miscomputation or runtime OOB depending on slot offset alignment. if isinstance(obj, ScalarNdarray): - return f"[nd-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[nd-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, VectorNdarray): - return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): # disabled for now, because we need to think about how to handle field offset # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - return None + return _FAIL_FASTCACHE if isinstance(obj, MatrixNdarray): - return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return 
f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, torch_type): return f"[pt-{obj.dtype}-{obj.ndim}]" # type: ignore if isinstance(obj, np.ndarray): @@ -169,30 +320,44 @@ def stringify_obj_type( # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - return None - if dataclasses.is_dataclass(obj): - return dataclass_to_repr(raise_on_templated_floats, path, obj) + return _FAIL_FASTCACHE + if is_dataclass_instance(obj): + return dataclass_to_repr( + raise_on_templated_floats, path, obj, pruning_paths=pruning_paths, parent_flat=parent_flat + ) if is_data_oriented(obj): + # Walk the data_oriented container's members, narrowed by pruning info — the kernel-compile path records + # every kernel-accessed attribute chain (ndarrays via ``_promote_ndarray_if_declared`` + + # ``Pruning.fold_struct_nd_paths``; primitives, opaque members, nested structs via + # ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` propagation calling ``pruning.mark_used``). Members + # not in ``pruning_paths`` are *guaranteed* not to affect kernel codegen because the kernel cannot read them. + # Dropping them from the hash satisfies rule 1 (cache only pruned paths). child_repr_l = ["da"] - _dict = {} try: - # pyright is ok with this approach _asdict = getattr(obj, "_asdict") _dict = _asdict() except AttributeError: _dict = obj.__dict__ for k, v in _dict.items(): - _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) - if _child_repr is None: - if _should_warn: - _logging.warn( - f"A kernel that has been marked as eligible for fast cache was passed 1 or more parameters " - f"that are not, in fact, eligible for fast cache: one of the parameters was a " - f"@qd.data_oriented object, and one of its children was not eligible. 
The data oriented " - f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " - f"information, the path of the value was {path}." - ) - return None + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the per-instance + # ``BoundQuadrantsCallable`` on ``instance.__dict__`` so subsequent ``instance.method`` lookups skip the + # descriptor allocation; those entries are not data and must not invalidate the fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue + child_flat = _child_flat(parent_flat, k) + if not _is_path_used(pruning_paths, child_flat): + continue + _child_repr = stringify_obj_type( + raise_on_templated_floats, + (*path, k), + v, + ArgMetadata(Template, ""), + pruning_paths=pruning_paths, + parent_flat=child_flat, + ) + if _child_repr is _FAIL_FASTCACHE: + return _FAIL_FASTCACHE child_repr_l.append(f"{k}: {_child_repr}") return ", ".join(child_repr_l) if issubclass(arg_type, (numbers.Number, np.number)): @@ -210,21 +375,33 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" - _mark_should_warn() - # The bit in caps should not be modified without updating corresponding test - # The rest of free text can be freely modified - # (will probably formalize this in more general doc / contributor guidelines at some point) - _logging.warn( - f"[FASTCACHE][PARAM_INVALID] Parameter with path {path} and type {arg_type} not allowed by fast cache." - ) - return None + # Unrecognised type at a kernel-read path — fail fastcache loudly. See ``_fail_unknown_type``. 
+ return _fail_unknown_type(obj, path) def hash_args( - raise_on_templated_floats: bool, args: Sequence[Any], arg_metas: Sequence[ArgMetadata | None] + raise_on_templated_floats: bool, + args: Sequence[Any], + arg_metas: Sequence[ArgMetadata | None], + pruning_paths: set[str] | None = None, ) -> str | FastcacheSkip: - """Return the args hash string, or a HashFailure explaining why hashing failed.""" - global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn + """Return the args hash string, or a ``FastcacheSkip`` explaining why hashing failed. + + Parameters: + - ``pruning_paths``: optional set of kernel-accessed flat names from the L1 cache (or freshly populated + after a cold compile). When provided, the container walkers skip children whose flat name is not in + the set; this is what keeps the cache key narrow and brittleness-free (no opaque-typed member can + affect the key unless the kernel actually reads it). + + Fastcache is disabled (``FastcacheSkip`` returned) when either: + - a recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``) is encountered at a + kernel-read path, OR + - an unrecognised type is encountered at a kernel-read path (see ``_fail_unknown_type``). + + Both cases are loud: ``FastcacheSkip.WARN`` triggers an ``[INVALID_FUNC]`` log line and the unknown-type + branch additionally emits a one-shot ``[UNKNOWN_TYPE]`` warning identifying the offending type. + """ + global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn # pylint: disable=global-statement _should_warn = False g_num_calls += 1 g_num_args += len(args) @@ -235,11 +412,23 @@ def hash_args( ) for i_arg, arg in enumerate(args): start = time.time() - _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg]) + arg_meta = arg_metas[i_arg] + # Top-level arg flat name: matches the kernel-side ``arg_meta.name`` (no ``__qd_`` prefix at the root). 
+ # Used by the narrow walk to construct child flat names compatible with ``pruning.used_vars_by_func_id``. + top_flat = arg_meta.name if arg_meta is not None else None + _hash = stringify_obj_type( + raise_on_templated_floats, + (str(i_arg),), + arg, + arg_meta, + pruning_paths=pruning_paths, + parent_flat=top_flat, + ) g_repr_time += time.time() - start - if not _hash: + if _hash is _FAIL_FASTCACHE: g_num_ignored_calls += 1 return FastcacheSkip.WARN if _should_warn else FastcacheSkip.FIELD_VIA_TENSOR + # All other return values are valid strings: unrecognised types return ``_FAIL_FASTCACHE`` above (no qualname fallback). hash_l.append(_hash) start = time.time() res = hash_iterable_strings(hash_l) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 1c03bf737b..3a7222acfe 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -1,3 +1,30 @@ +"""Two-level fastcache key derivation and persistence. + +Two-level cache +--------------- +The fastcache now exposes pruning information (already produced during compile) as a first-class lookup so the args +hash can walk *only* paths the kernel reads: + + - L1 (this module's ``make_source_config_key`` + ``load_pruning_info`` / ``store_pruning_info``): keyed by + source+config only (no args). Stores ``PruningInfo`` — the set of kernel-accessed flat names (e.g. + ``__qd_state__qd_x``) plus the ``graph_do_while_arg`` (also a kernel-source property). + + - L2 (``make_full_cache_key`` + ``load_full`` / ``store_full``): keyed by L1 key + the *narrow* args hash computed + with pruning info from L1. Stores the C++ ``frontend_cache_key`` that names the compiled artifact. + +Lookup flow on a warm call: L1 lookup → narrow args hash (paths from L1) → L2 lookup → load artifact. + +Cold compile flow: L1 miss → cold compile (pass 0 + pass 1) → store L1 → compute narrow args hash → store L2. 
+ +Safety implication +------------------ +A kernel-unused path's contents (any type, including unrecognised tensor-likes) is *guaranteed* not to affect kernel +codegen, so dropping it from the hash is correct by construction. Paths the kernel *does* read still go through +``args_hasher.stringify_obj_type``; if it encounters an unrecognised type at such a path it fails the call's fastcache +loudly (one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning identifying the offending ``type(v).__qualname__``), so a missed +type registration is impossible to miss and cannot serve stale cached results. +""" + import json import os import warnings @@ -17,21 +44,55 @@ from .hash_utils import hash_iterable_strings from .python_side_cache import PythonSideCache +# Prefix bytes mixed into L1 / L2 keys so they cannot collide even if the underlying inputs happen to hash to the +# same string. The original single-level cache key (kept for backward-compat reads via ``load`` below) had no such +# prefix; the new two-level scheme uses ``l1:`` and ``l2:`` markers so old single-level entries from prior Quadrants +# installs are simply ignored rather than mis-served. +_L1_MARKER = "l1" +_L2_MARKER = "l2" + + +def make_source_config_key(kernel_source_info: FunctionSourceInfo) -> str: + """Build the L1 cache key: source + config + version, with no dependence on args. + + Used by ``_try_load_fastcache`` before any args walking. The same key drives ``load_pruning_info`` / + ``store_pruning_info``; the matching ``make_full_cache_key`` derives the L2 key from this plus the narrow args + hash. 
+ """ + kernel_hash = function_hasher.hash_kernel(kernel_source_info) + config_hash = config_hasher.hash_compile_config() + return hash_iterable_strings( + ( + _L1_MARKER, + quadrants.__version_str__, + kernel_hash, + config_hash, + kernel_source_info.filepath, + str(kernel_source_info.start_lineno), + "pruned", + "kcov" if os.environ.get("QD_KERNEL_COVERAGE") == "1" else "", + ) + ) + -def create_cache_key( +def make_full_cache_key(source_config_key: str, narrow_args_hash: str) -> str: + """Build the L2 cache key from the L1 key + narrow args hash. See module docstring.""" + return hash_iterable_strings((_L2_MARKER, source_config_key, narrow_args_hash)) + + +def compute_narrow_args_hash( raise_on_templated_floats: bool, kernel_source_info: FunctionSourceInfo, args: Sequence[Any], arg_metas: Sequence[ArgMetadata], + pruning_paths: set[str] | None, ) -> str | None: + """Compute the args hash narrowed by ``pruning_paths`` (or wide if ``pruning_paths is None``). + + Returns ``None`` if a recognised-but-unsupported tensor-like type forces fastcache off — the caller emits + the appropriate user-visible diagnostic via the ``FastcacheSkip.WARN`` branch. 
""" - cache key takes into account: - - arg types - - cache value arg values - - kernel function (but not sub functions) - - compilation config (which includes arch, and debug) - """ - args_hash = args_hasher.hash_args(raise_on_templated_floats, args, arg_metas) + args_hash = args_hasher.hash_args(raise_on_templated_floats, args, arg_metas, pruning_paths=pruning_paths) if isinstance(args_hash, FastcacheSkip): if args_hash is FastcacheSkip.WARN: # the bit in caps at start should not be modified without modifying corresponding text @@ -41,24 +102,78 @@ def create_cache_key( "fast cached, because one or more parameter types were invalid" ) return None - kernel_hash = function_hasher.hash_kernel(kernel_source_info) - config_hash = config_hasher.hash_compile_config() - cache_key = hash_iterable_strings( - ( - quadrants.__version_str__, - kernel_hash, - args_hash, - config_hash, - kernel_source_info.filepath, - str(kernel_source_info.start_lineno), - "pruned", - "kcov" if os.environ.get("QD_KERNEL_COVERAGE") == "1" else "", - ) + return args_hash + + +class L1CacheValue(BaseModel): + """Persisted L1 entry — pruning info that's source-and-config-deterministic (not args-dependent). + + Pruning info is the set of *flat names* (``__qd___qd___qd_…``) that the kernel actually reads. + Computed during compile (``Pruning.used_vars_by_func_id``); persisted here so subsequent calls can build + a narrow args hash without having to recompile. + + ``graph_do_while_arg`` is also stored here because it's a property of the kernel source (not of any + particular arg value). + + ``hashed_function_source_infos`` is the same content-hash list used for L2 validation; an L1 hit is + rejected if any helper source has changed since the L1 entry was written, even if the kernel source + itself hasn't (kernel_hash only covers the entry point). 
+ """ + + used_py_dataclass_parameters: set[str] + hashed_function_source_infos: list[HashedFunctionSourceInfo] + graph_do_while_arg: str | None = None + + +def store_pruning_info( + source_config_key: str, + function_source_infos: Iterable[FunctionSourceInfo], + used_py_dataclass_parameters: set[str], + graph_do_while_arg: str | None = None, +) -> None: + """Persist the L1 entry after a cold compile. See ``L1CacheValue`` for what's stored / why.""" + if not source_config_key: + return + cache = PythonSideCache() + hashed_function_source_infos = function_hasher.hash_functions(function_source_infos) + cache_value = L1CacheValue( + used_py_dataclass_parameters=used_py_dataclass_parameters, + hashed_function_source_infos=list(hashed_function_source_infos), + graph_do_while_arg=graph_do_while_arg, ) - return cache_key + cache.store(source_config_key, cache_value.model_dump_json()) + + +def load_pruning_info( + source_config_key: str, +) -> tuple[set[str], str | None] | tuple[None, None]: + """Look up L1 cache. Returns (pruning_paths, graph_do_while_arg) on hit, (None, None) on miss / invalid. + + Validates ``hashed_function_source_infos`` against the current on-disk source; if any helper has changed + since the entry was written, the entry is invalid and we treat the lookup as a miss so the caller does a + cold compile (which will overwrite the stale L1 entry). 
+ """ + cache = PythonSideCache() + maybe_value_json = cache.try_load(source_config_key) + if maybe_value_json is None: + return None, None + try: + cache_value = L1CacheValue.model_validate_json(maybe_value_json) + except (pydantic.ValidationError, json.JSONDecodeError, UnicodeDecodeError) as e: + warnings.warn(f"Failed to parse L1 cache entry: {e}") + return None, None + if not function_hasher.validate_hashed_function_infos(cache_value.hashed_function_source_infos): + return None, None + return cache_value.used_py_dataclass_parameters, cache_value.graph_do_while_arg class CacheValue(BaseModel): + """Persisted L2 entry — frontend cache key for the compiled artifact + source-validation metadata. + + The full pruning info is duplicated here for backward-compat with existing on-disk caches; it's the same + set that L1 also stores. The L1 set is the source of truth for narrowing the args hash on warm calls. + """ + frontend_cache_key: str hashed_function_source_infos: list[HashedFunctionSourceInfo] used_py_dataclass_parameters: set[str] @@ -72,22 +187,10 @@ def store( used_py_dataclass_parameters: set[str], graph_do_while_arg: str | None = None, ) -> None: - """ - Note that unlike other caches, this cache is not going to store the actual value we want. - This cache is only used for verification that our cache key is valid. Big picture: - - we have a cache key, based on args and top level kernel function - - we want to use this to look up LLVM IR, in C++ side cache - - however, before doing that, we first want to validate that the source code didn't change - - i.e. is our cache key still valid? - - the python side cache contains information we will use to verify that our cache key is valid - - ie the list of function source infos - - Update! 
We are now going to store parameter pruning infomation, which is: - - used_py_dataclass_parameters: set[str] - - Update 2: we are going to store the cache key used by the c++ kernel cache, so that we can use that - to retrieve the immutable cached c++ kernel later, rather than, before, we were storing the c++ - cached kernel using the fast cache key, leading to bugs, when cached kernel file then had to be mutable. + """Persist the L2 entry — the C++ frontend cache key that names the compiled artifact for this call. + + ``fast_cache_key`` is the L2 key from ``make_full_cache_key``. The L1 entry has typically been stored + earlier by ``store_pruning_info`` during the same materialize. """ if not fast_cache_key: return @@ -117,9 +220,9 @@ def _try_load(cache_key: str) -> CacheValue | None: def load(cache_key: str) -> tuple[set[str], str, str | None] | tuple[None, None, None]: - """ - loads function source infos from cache, if available - checks the hashes against the current source code + """Look up L2 cache. Returns (used_pruning_paths, frontend_cache_key, graph_do_while_arg) on hit. + + Validates helper-source hashes against the live source; an L2 entry is invalidated if any helper changed. 
""" cache_value = _try_load(cache_key) if cache_value is None: diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index 3289365767..aaa71620ce 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -1,13 +1,38 @@ -from ast import Name, Starred, expr, keyword +from ast import Attribute, Name, Starred, expr, keyword from collections import defaultdict from typing import TYPE_CHECKING, Any +from ._dataclass_util import create_flat_name from ._exceptions import raise_exception from ._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from .exception import QuadrantsSyntaxError from .func import Func from .kernel_arguments import ArgMetadata + +def _flatten_arg_node(node: expr) -> tuple[str, str] | None: + """Flatten an AST arg node into ``(flat_name, root_name_id)`` (or ``None`` if the node isn't a recognisable + name/attribute chain rooted at a plain Name). + + Returns both the full flat name (e.g. ``__qd_self__qd_dofs`` for ``self.dofs``) and the root Name's id (``self``). + Callers use the root id to distinguish kernel-arg-rooted chains (``self.dofs`` → root ``self``) from already- + flattened dataclass-arg references (``__qd_self__qd_dofs`` → root ``__qd_self__qd_dofs``). The flat path alone is + ambiguous because ``__qd_self__qd_dofs`` could be either an attribute chain *or* a single flattened Name. + + Mirrors ``FlattenAttributeNameTransformer._flatten_attribute_name`` but on the raw call-arg AST. + Used by ``record_after_call`` to handle ``f(self.dofs)`` etc. 
— without this the callee's pruning + info for attribute-chain args is dropped at the call boundary.""" + if isinstance(node, Name): + return node.id, node.id + if isinstance(node, Attribute): + parent = _flatten_arg_node(node.value) + if parent is None: + return None + parent_flat, root_id = parent + return create_flat_name(parent_flat, node.attr), root_id + return None + + if TYPE_CHECKING: import ast @@ -39,11 +64,123 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_used_parameters) # only needed for args, not kwargs self.callee_param_by_caller_arg_name_by_func_id: dict[int, dict[str, str]] = defaultdict(dict) + # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. Populated by the + # AST builder when a chain like ``self.x.y`` resolves to an ndarray that was pre-declared by + # ``_predeclare_struct_ndarrays``. On the second (enforcing) pass, ``_predeclare_struct_ndarrays`` only + # registers ndarrays whose id is in this set — dropping every reachable-but-unused ndarray from the kernel's + # parameter list. + self.used_struct_ndarray_ids: set[int] = set() + # Whether the non-enforcing first pass actually ran for this kernel materialize. When fastcache hits, we skip + # pass 0 entirely and ``used_struct_ndarray_ids`` is therefore unreliable — in that case + # ``_predeclare_struct_ndarrays`` falls back to registering every reachable ndarray (same as historical + # behavior). + self.pass_0_ran: bool = False + # Kernel-arg-rooted attribute chains used by each func, in flat-name form (``__qd_self__qd_dofs__qd_x``). + # Populated by ``ASTTransformer.build_Attribute`` for non-flattened kernel args (data_oriented / + # qd.template). 
Kept *separate* from ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on + # the enforcing pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite ``s.x`` + # → ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a ``QuadrantsNameError: Name + # "__qd_s__qd_x" is not defined``. ``record_after_call`` propagates entries from callee to caller (so + # ``f(self.dofs)`` where ``f`` reads ``s.x`` ends up with ``__qd_self__qd_dofs__qd_x`` in the kernel's set). + # After both compile passes, ``Pruning.fold_kernel_arg_chain_paths`` merges the kernel's set into + # ``used_vars_by_func_id[KERNEL_FUNC_ID]`` so fastcache stores them in L1 and the args_hasher narrow walk + # picks them up. + self.kernel_arg_chain_paths_by_func_id: dict[int, set[str]] = defaultdict(set) def mark_used(self, func_id: int, parameter_flat_name: str) -> None: assert not self.enforcing self.used_vars_by_func_id[func_id].add(parameter_flat_name) + def mark_kernel_arg_chain_used(self, func_id: int, chain_flat_name: str) -> None: + """Record a kernel-arg-rooted attribute chain (e.g. ``__qd_self__qd_dofs__qd_x``). + + Stored separately from ``used_vars_by_func_id`` — see the docstring on ``kernel_arg_chain_paths_by_func_id`` + for why.""" + assert not self.enforcing + self.kernel_arg_chain_paths_by_func_id[func_id].add(chain_flat_name) + + def fold_struct_nd_paths( + self, struct_ndarray_launch_info: list[tuple[Any, int, tuple[str, ...]]], arg_metas: list[ArgMetadata] + ) -> None: + """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat name set so + ``args_hasher.hash_args`` narrow-walks them correctly. + + Background: ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat names produced by + ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* args. 
+ ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` in the AST and + don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths *are* recorded — in + ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — but only for ndarray members. + + Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd_{arg_name}__qd_{attr}__qd_…`` and union + all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the same convention used + for dataclass args. + + Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at compile, + opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk cannot distinguish + "kernel reads this primitive" from "kernel does not read this primitive". The + ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking *all* attrs of + a data_oriented container — narrowing only suppresses subtrees explicitly absent from the pruning set. So for + a data_oriented arg with mostly-ndarray members, the cache key correctly depends on the ndarray paths it + uses; for one with primitive members whose values matter, those members are still folded into the hash + (qualname-fallback / value paths). + """ + if not struct_ndarray_launch_info: + return + kernel_used: set[str] = self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID] + for _arg_id_cpp, arg_idx, attr_chain in struct_ndarray_launch_info: + if arg_idx < 0 or arg_idx >= len(arg_metas): + continue + arg_name = arg_metas[arg_idx].name + if not arg_name: + continue + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + kernel_used.add(flat) + + def fold_kernel_arg_chain_paths(self) -> None: + """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both compile + passes have completed.
+ + Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain (e.g. + ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into ``kernel_arg_chain_paths_by_func_id`` rather than + ``used_vars_by_func_id``, because the latter is read on the enforcing pass to build ``struct_locals`` for + ``FlattenAttributeNameTransformer``. If chain names appeared there, the transformer would rewrite ``self.n`` + into ``Name('__qd_self__qd_n')`` and ``build_Name`` would fail to find such a variable. + + Doing the merge here — after pass 1, just like ``fold_struct_nd_paths`` — avoids that interaction while + still making the chain paths available to the fastcache args-hash narrow walk. The set on + ``used_py_dataclass_parameters_by_key_enforcing[key]`` is the *same* object as + ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so updating one updates + both. + """ + kernel_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(Pruning.KERNEL_FUNC_ID) + if not kernel_chain_paths: + return + self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_chain_paths) + + @staticmethod + def _propagate_chain_paths( + callee_chain_paths: set[str], + callee_param_name: str, + caller_flat: str, + chain_paths_to_propagate: set[str], + ) -> None: + """When ``f(self.dofs)`` is called and ``f``'s body reads ``s.x`` (callee param ``s`` bound to caller + attribute chain ``self.dofs``), the callee's chain-paths set contains ``__qd_s__qd_x`` but the + caller's chain-paths set must record ``__qd_self__qd_dofs__qd_x``. This helper does that + prefix substitution. 
Only chain paths starting with ``__qd_{callee_param_name}__qd_`` are propagated + (chains rooted in unrelated callee args don't apply to this caller arg).""" + prefix = f"__qd_{callee_param_name}__qd_" + for sub in callee_chain_paths: + if sub.startswith(prefix): + rest = sub[len(prefix) :] + if caller_flat.startswith("__qd_"): + new_flat = f"{caller_flat}__qd_{rest}" + else: + new_flat = f"__qd_{caller_flat}__qd_{rest}" + chain_paths_to_propagate.add(new_flat) + def enforce(self) -> None: self.enforcing = True @@ -70,7 +207,9 @@ def record_after_call( callee_func_id = func.wrapper.func_id # type: ignore # Copy the used parameters from the child function into our own function. callee_used_vars = self.used_vars_by_func_id[callee_func_id] + callee_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(callee_func_id, set()) vars_to_unprune: set[str] = set() + chain_paths_to_propagate: set[str] = set() arg_id = 0 # node.args ordering will match that of the called function's metas_expanded, # because of the way calling with sequential args works. @@ -88,6 +227,20 @@ def record_after_call( callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) + # Propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) AND through + # plain-Name args of non-flattened types (``f(self)``). Gate on the *root* Name id, not the resulting + # flat string: ``self.dofs`` flattens to ``__qd_self__qd_dofs`` (which starts with ``__qd_``) but its + # root is the bare kernel arg ``self`` — we still need to propagate. Already-flattened dataclass refs + # like ``Name('__qd_self__qd_dofs')`` have a ``__qd_*`` root and are handled by the ``vars_to_unprune`` + # path above.
+ flat = _flatten_arg_node(arg) + if flat is not None: + caller_flat, root_id = flat + if not root_id.startswith("__qd_"): + callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 # Note that our own arg_metas ordering will in general NOT match that of the child's. That's # because our ordering is based on the order in which we pass arguments to the function, but the @@ -101,8 +254,20 @@ def record_after_call( callee_param_name = kwarg.arg if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) + flat = _flatten_arg_node(kwarg.value) + if flat is not None: + caller_flat, root_id = flat + if not root_id.startswith("__qd_"): + callee_param_name = kwarg.arg + # ``kwarg.arg`` is ``None`` for double-star unpacking (``**kwargs``); chain propagation requires + # a concrete parameter name so just skip. + if callee_param_name is not None: + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 self.used_vars_by_func_id[my_func_id].update(vars_to_unprune) + self.kernel_arg_chain_paths_by_func_id[my_func_id].update(chain_paths_to_propagate) used_callee_vars = self.used_vars_by_func_id[callee_func_id] child_arg_id = 0 diff --git a/python/quadrants/lang/_quadrants_callable.py b/python/quadrants/lang/_quadrants_callable.py index ba7e7b8217..0c071c6919 100644 --- a/python/quadrants/lang/_quadrants_callable.py +++ b/python/quadrants/lang/_quadrants_callable.py @@ -90,15 +90,31 @@ def __init__(self, fn: Callable, wrapper: Callable) -> None: self._adjoint: "Kernel | None" = None self.grad: "Kernel | None" = None self.is_pure: bool = False + self._attr_name: str | None = None update_wrapper(self, fn) + def __set_name__(self, owner: type, name: str) -> None: + # Captured at class-body time. 
``data_oriented.make_kernel_indirect`` sets this explicitly on its replacement + # callable since setattr-after-class doesn't trigger __set_name__. + self._attr_name = name + def __call__(self, *args, **kwargs): return self.wrapper.__call__(*args, **kwargs) def __get__(self, instance, owner): if instance is None: return self - return BoundQuadrantsCallable(instance, self) + bound = BoundQuadrantsCallable(instance, self) + # Non-data descriptor (no __set__): a __dict__ entry on the instance wins over the descriptor on subsequent + # attribute lookups. Stash the bound callable there so future ``instance.method`` accesses skip __get__ + # allocation entirely (~0.6-1.2 us/call). Skip if the class uses __slots__ (no __dict__) or the attribute name + # wasn't captured. + name = self._attr_name + if name is not None: + inst_dict = getattr(instance, "__dict__", None) + if inst_dict is not None: + inst_dict[name] = bound + return bound class BoundQuadrantsCallable: diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index fd45b9913e..07d32892c1 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -15,18 +15,28 @@ _struct_nd_paths_for, ) - -def _collect_data_oriented_nd_ids(arg: Any, out: list) -> None: - """Append ``id(ndarray)`` for every ndarray reachable from ``arg``, using the per-class path cache in - ``_template_mapper_hotpath._struct_nd_paths_for`` so the first call walks ``vars(arg)`` once and subsequent calls - are just ``getattr`` chains. Empty path list short-circuits with zero work — critical for genesis's - ``@qd.data_oriented`` Solver passed as ``self`` to every kernel. 
- """ - for chain in _struct_nd_paths_for(arg): - v = arg - for a in chain: - v = getattr(v, a) - out.append(id(v)) +# Per-class disposition for the args_hash ndarray-id walk in ``TemplateMapper.lookup``: one of ``_SKIP`` (this class +# never contributes — non-data_oriented, or ``@qd.data_oriented(stable_members=True)``) or ``_PER_INSTANCE`` (delegate +# to ``_struct_nd_paths_for`` for a per-instance walk). The disposition depends only on type (data_oriented? +# stable_members?), so caching by class is correct. The *actual* path list is per-instance because @qd.data_oriented +# classes can have polymorphic attribute structure across instances (Genesis ``DataManager`` is the motivating case). +_arg_disposition: dict[type, object] = {} +_SKIP = object() +_PER_INSTANCE = object() + + +def _classify_disposition(arg: Any) -> object: + """First-sighting per-class disposition for the args_hash walk. Returns ``_SKIP`` (no per-call walk for this + class) or ``_PER_INSTANCE`` (delegate to ``_struct_nd_paths_for`` for a per-instance walk). + + ``_qd_stable_members`` here is a *launch-time perf hint only* (see ``@qd.data_oriented(stable_members=...)``). + It promises that ndarray members are never reassigned, which lets us skip the per-call walk entirely. It does + not affect fastcache key derivation.""" + if not is_data_oriented(arg): + return _SKIP + if type(arg).__dict__.get("_qd_stable_members"): + return _SKIP + return _PER_INSTANCE Key: TypeAlias = tuple[Any, ...] @@ -93,12 +103,37 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl # ``@qd.data_oriented`` containers can have their member ndarrays reassigned between calls on the same instance # (``state.x = other_ndarray``). The id(arg) alone does not capture that, so the spec-key cache below would # serve a stale entry and the new ndarray's dtype/ndim would be wrong. Fold the reachable ndarray ids into the - # hash. 
No-op for data_oriented containers that hold no ndarrays — the walker returns an empty list. See - # ``_collect_data_oriented_nd_ids``. + # hash for the (small) set of arg positions that need it. + # + # The kernel's ``template_slot_locations`` already gives us the subset of arg positions annotated as + # ``qd.template()`` — the only positions where a data_oriented container could appear (typed-dataclass args + # carry a specific dataclass type by construction and a data_oriented class is never a dataclass). So we only + # iterate ``template_slot_locations`` instead of all args (Genesis main kernel_step_1: 4 template positions + # of 16 args; Genesis branch step_1/step_2: 4 of 4). + # + # For each candidate, ``_arg_disposition`` caches the per-class decision (skip vs walk-per-instance) and the + # actual paths come from ``_struct_nd_paths_for`` (per-instance, stashed on ``arg._qd_nd_paths``). Per-instance + # path caching is load-bearing for correctness — @qd.data_oriented classes can have polymorphic attribute + # structure across instances (Genesis ``DataManager`` only allocates adjoint-cache members when + # ``requires_grad=True``); a per-class cache populated from one instance can't safely be reused for another. 
nd_ids: list = [] - for arg in args: - if is_data_oriented(arg): - _collect_data_oriented_nd_ids(arg, nd_ids) + for i in self.template_slot_locations: + arg = args[i] + cls = type(arg) + disposition = _arg_disposition.get(cls) + if disposition is None: + disposition = _classify_disposition(arg) + _arg_disposition[cls] = disposition + if disposition is _SKIP: + continue + paths = _struct_nd_paths_for(arg) + if not paths: + continue + for chain in paths: + v = arg + for a in chain: + v = getattr(v, a) + nd_ids.append(id(v)) if nd_ids: args_hash = args_hash + tuple(nd_ids) try: diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index 6df1b54358..3a77f6702e 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -46,7 +46,11 @@ from quadrants.lang.expr import Expr from quadrants.lang.matrix import MatrixType from quadrants.lang.snode import SNode -from quadrants.lang.util import is_data_oriented, to_quadrants_type +from quadrants.lang.util import ( + is_data_oriented, + is_dataclass_instance, + to_quadrants_type, +) from quadrants.types import ( buffer_view_type, ndarray_type, @@ -72,18 +76,33 @@ _primitive_types = {int, float, bool} -# Per-class cache: ``type(arg) -> list[tuple[str, ...]]`` of attribute paths whose values are ``Ndarray`` instances at -# first observation. Populated lazily by ``_struct_nd_paths_for`` on the first call with each new data_oriented (or -# nested dataclass) class. Empty list means "this class holds no ndarrays anywhere", in which case subsequent calls -# pay only a dict-lookup per arg. Non-empty list short-circuits the full ``vars()`` recursion and just resolves each -# cached path via ``getattr`` chains. 
Critical for the genesis field-backend hot path: the ``@qd.data_oriented`` -# Solver is passed as ``self`` to most kernels and holds dozens of attributes, so a full per-call ``vars()`` walk -# costs >100ns per kernel and trashed FPS until this cache was added. +# Per-instance cache of ndarray attribute paths, stashed on the instance via ``object.__setattr__`` (compatible with +# frozen dataclasses). Used by both ``TemplateMapper.lookup``'s args_hash walk and the ``_extract_arg`` data_oriented +# descriptor walk. Per-instance caching is necessary because @qd.data_oriented classes can have *different attribute +# structures across instances of the same class* — Genesis ``DataManager``, for instance, only allocates +# ``*_adjoint_cache`` members when ``requires_grad=True``. A class-level cache populated from the first-ever instance +# would either crash on missing attributes (forward direction, "first instance has, second misses") or silently miss +# new ones (inverse direction), both of which produce wrong-shape kernel reuse. +# +# Steady-state cost: one ``__dict__`` lookup per arg per call (~30ns), same order as the previous class-level +# ``dict.get``. The walk itself (``_build_struct_nd_paths``) is paid once per instance lifetime at first kernel +# launch with that instance — typically O(10) instances per Genesis scene, so ~10us total at scene build. +# +# ``_struct_nd_paths_cache`` (below) is a fallback for ``__slots__`` classes that have no ``__dict__`` and so can't +# accept the ``object.__setattr__`` stash. Such classes inherit the legacy per-class-cache behaviour (and its +# polymorphic-instance limitations). Genesis data_oriented containers don't use ``__slots__``, so this branch is +# unreachable in practice. 
_struct_nd_paths_cache: dict[type, list[tuple]] = {} -def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None: - if dataclasses.is_dataclass(obj) and not isinstance(obj, type): +def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: "set[int] | None" = None) -> None: + # Cycle-safe walker. Genesis object graphs have cross-references (e.g. ``solver -> scene -> sim -> solver``) and + # Pydantic-options-style children. ``_seen`` tracks ``id(obj)`` for the current traversal to avoid re-entering a + # node we've already expanded. Cheap (one ``set`` op per frame, only allocated when we actually start recursing) + # and bounds the walk to a finite depth regardless of the graph shape. + if _seen is None: + _seen = {id(obj)} + if is_dataclass_instance(obj): children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj)) else: # ``NamedTuple`` (decorated as ``@qd.data_oriented``) has no instance ``__dict__`` — fall back to ``_asdict()`` @@ -101,26 +120,50 @@ def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None: v_type = type(v) if issubclass(v_type, Ndarray): out.append(chain) - elif is_data_oriented(v) or (dataclasses.is_dataclass(v) and not isinstance(v, type)): - _build_struct_nd_paths(v, chain, out) + elif is_data_oriented(v) or is_dataclass_instance(v): + v_id = id(v) + if v_id in _seen: + continue + _seen.add(v_id) + _build_struct_nd_paths(v, chain, out, _seen) def _struct_nd_paths_for(arg: Any) -> list[tuple]: - """Return the cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` instances are - reachable from ``arg`` of type ``type(arg)``. First call for a class walks ``arg`` once via - ``_build_struct_nd_paths``; subsequent calls are dict-lookups. + """Return the per-instance cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` + instances are reachable from ``arg``. 
First call walks ``arg`` once via ``_build_struct_nd_paths`` and stashes + the result on the instance as ``_qd_nd_paths`` (via ``object.__setattr__`` so it works for frozen dataclasses + and ``@qd.data_oriented`` containers alike); subsequent calls fetch it via instance ``__dict__`` lookup. + + Per-instance caching is correctness-load-bearing: ``@qd.data_oriented`` classes can have different attribute + sets across instances of the same class (e.g. Genesis ``DataManager`` with vs without ``requires_grad``), so a + per-class cache populated from one instance can't be reused for another. ``__slots__`` classes without a + ``__dict__`` fall back to per-class caching (see ``_struct_nd_paths_cache``) and retain the legacy limitation. - Trades freshness for speed: assumes the *set* of ndarray-holding attribute paths is stable across instances of - the same class. The genesis Solver and similar ``@qd.data_oriented`` containers satisfy this — their ndarray - members are declared in ``__init__`` and not added later. If you need to add an ndarray attribute after the first - kernel launch on an instance of a given class, the new attribute won't be tracked. Call ``invalidate_struct_nd_ - paths_for`` (below) or restart the program. + Limitation: the path list is recorded once per instance. If a new ndarray attribute is attached to an instance + *after* its first kernel call (uncommon — Genesis containers declare all ndarrays in ``__init__``), it won't be + tracked until the cache is invalidated. Workaround: ``del arg.__dict__['_qd_nd_paths']`` (or restart the + process). """ + # Fast path: instance already walked. ``__dict__["…"]`` skips descriptor / ``__getattr__`` machinery (some + # third-party metaclasses, e.g. Pydantic, recurse infinitely on probe-style ``getattr`` for unknown names — + # see ``is_data_oriented`` for the same defensiveness). 
+ try: + return arg.__dict__["_qd_nd_paths"] + except (AttributeError, KeyError): + pass + # ``__slots__`` fallback or first-sighting of this instance: check the class-level cache too, so that a + # ``__slots__`` class doesn't re-walk on every call. cls = type(arg) paths = _struct_nd_paths_cache.get(cls) - if paths is None: - paths = [] - _build_struct_nd_paths(arg, (), paths) + if paths is not None: + return paths + paths = [] + _build_struct_nd_paths(arg, (), paths) + try: + object.__setattr__(arg, "_qd_nd_paths", paths) + except AttributeError: + # ``__slots__`` class without a ``_qd_nd_paths`` slot — degrade to per-class caching. Loses correctness + # under polymorphic-instance attribute structure, but Genesis data_oriented containers don't use slots. _struct_nd_paths_cache[cls] = paths return paths @@ -130,15 +173,28 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding ndarrays — see the data_oriented branch in ``_extract_arg``. """ + # The path cache is per-instance (see ``_struct_nd_paths_for``) so polymorphic-instance attribute structure is + # handled correctly. Within a single instance's lifetime, a cached path's leaf may still cease to be an + # ``Ndarray`` (e.g. ``qd.Tensor``'s underlying impl swapped between an ``Ndarray`` and a ``MatrixField``); when + # that happens we silently skip the descriptor — ``v.element_type`` / ``v.shape`` / ``v._qd_layout`` are + # Ndarray-only accessors. The per-instance ``weakref(arg)`` part of the spec key still ensures correct cache + # discrimination across instances. 
for chain in _struct_nd_paths_for(arg): v = arg for a in chain: v = getattr(v, a) if type(v) in _TENSOR_WRAPPER_TYPES: v = v._unwrap() + if not isinstance(v, Ndarray): + continue + # ``Ndarray.shape`` can legitimately be ``None`` (uninitialised ``_physical_shape``); such an instance + # has no meaningful spec contribution, so skip it rather than crashing on ``len(None)``. + shape = v.shape + if shape is None: + continue type_id = id(v.element_type) element_type = type_id if type_id in primitive_types.type_ids else v.element_type - out.append((".".join(chain), element_type, len(v.shape), v.grad is not None, v._qd_layout)) + out.append((".".join(chain), element_type, len(shape), v.grad is not None, v._qd_layout)) def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: AnnotationType, arg_name: str) -> Any: @@ -214,6 +270,11 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati # # Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is # a no-op for the existing data_oriented + qd.field workloads (genesis field-backend). + # + # Opt-out: ``_qd_stable_members = True`` on the class (or ``@qd.data_oriented(stable_members=True)``) + # skips the per-call descriptor walk. 
+ if type(arg).__dict__.get("_qd_stable_members"): + return weakref.ref(arg) nd_descriptors: list = [] _collect_struct_nd_descriptors(arg, nd_descriptors) if nd_descriptors: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 263a4a11a3..75c3f88ef8 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -16,6 +16,7 @@ from quadrants._lib import core as _qd_core from quadrants.lang import exception, expr, impl, matrix, mesh from quadrants.lang import ops as qd_ops +from quadrants.lang._dataclass_util import create_flat_name from quadrants.lang._ndrange import _Ndrange from quadrants.lang.ast.ast_transformer_utils import ( ASTTransformerFuncContext, @@ -85,6 +86,21 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name): pruning = ctx.global_context.pruning if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters and node.id.startswith("__qd_"): ctx.global_context.pruning.mark_used(ctx.func.func_id, node.id) + # Track chains rooted at non-flattened parameter names: top-level ``@qd.kernel`` args + # (``ctx.kernel_args``) and ``@qd.func`` params (``ctx.fn_param_names``). Both appear in the AST as bare + # names (``self`` for a data_oriented kernel arg; ``static_rigid_sim_config`` for a ``qd.template()`` func + # arg bound to a ``@qd.data_oriented`` instance). ``build_Attribute`` propagates this annotation through + # ``state.dofs.x`` chains and ``mark_kernel_arg_chain_used``-s the flat name. The kernel's pruning narrow + # walk picks them up directly (kernel case) or after ``record_after_call`` propagates the callee's func-arg + # chains back through the call boundary (func case): e.g. ``func(s=self._sub)`` where ``func`` reads ``s.x`` + # ends up with ``__qd_self__qd__sub__qd_x`` recorded in the kernel's pruning, so the args-hasher hashes that + # primitive value into the fastcache key. 
+ # Dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as already-flat + # ``__qd_…`` Names, handled by the block above via ``mark_used``. + if not node.id.startswith("__qd_") and (node.id in ctx.kernel_args or node.id in ctx.fn_param_names): + node._qd_arg_chain = node.id # type: ignore[attr-defined] + else: + node._qd_arg_chain = None # type: ignore[attr-defined] node.violates_pure, node.ptr, node.violates_pure_reason = ctx.get_var_by_name(node.id) # Flattened struct fields (``__qd_foo__qd_bar``) injected by ``populate_global_vars_from_dataclass`` are raw # ``Ndarray`` instances. ``build_Attribute`` already promotes these via ``_promote_ndarray_if_declared`` but @@ -656,14 +672,37 @@ def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerFuncContext, n @staticmethod def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> Any: """If *value* is a bare ``Ndarray`` that was pre-declared as a kernel arg (in ``_predeclare_struct_ndarrays``), - return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged.""" + return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged. + + Also records the source ndarray id in ``pruning.used_struct_ndarray_ids`` on the non-enforcing first pass, so + that the enforcing second-pass ``_predeclare_struct_ndarrays`` can skip ndarrays that the kernel never actually + accesses. Both ``Ndarray`` instances and pre-existing ``AnyArray`` proxies (tagged with + ``_qd_source_ndarray_id``) are handled — the latter is the case for accesses in inlined ``@qd.func`` bodies + whose params were bound to already-promoted proxies by Option A in ``call_transformer``. 
+ """ from quadrants.lang._ndarray import Ndarray # pylint: disable=C0415 - if not isinstance(value, Ndarray): + pruning = ctx.global_context.pruning + # Mirror ``build_Name``'s mark_used gate: only mark on the non-enforcing first pass and not during synthetic + # per-leaf argument expansion for ``@qd.func`` calls. The callee body's own accesses (which run with + # ``expanding_dataclass_call_parameters = False``) are what we want to count. + should_mark = not pruning.enforcing and not ctx.expanding_dataclass_call_parameters + if isinstance(value, Ndarray): + cache = ctx.global_context.ndarray_to_any_array + key = id(value) + arr = cache.get(key) + if arr is not None: + if should_mark: + pruning.used_struct_ndarray_ids.add(key) + return arr return value - cache = ctx.global_context.ndarray_to_any_array - arr = cache.get(id(value)) - return arr if arr is not None else value + # Pre-promoted ``AnyArray`` flowing through an inlined ``@qd.func`` body. Mark the underlying ndarray as used + # so it survives the enforcing-pass pruning. + if should_mark: + src_id = getattr(value, "_qd_source_ndarray_id", None) + if src_id is not None: + pruning.used_struct_ndarray_ids.add(src_id) + return value @staticmethod def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): @@ -772,6 +811,26 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): warnings.warn(message) else: raise exception.QuadrantsCompilationError(message) + # Propagate the kernel-arg-rooted chain annotation and record this access in pruning's *separate* chain-paths + # set. ``build_Name`` sets ``_qd_arg_chain`` on non-flattened kernel args (e.g. data_oriented ``self``); each + # Attribute access in the chain extends it (``self`` → ``__qd_self__qd_x`` → ``__qd_self__qd_x__qd_y``). + # + # Why not ``mark_used``? 
On the enforcing pass, ``Kernel.materialize`` uses ``pruning.used_vars_by_func_id`` as + # ``struct_locals``, which drives ``FlattenAttributeNameTransformer`` — adding ``__qd_self__qd_x`` there would + # make the transformer rewrite ``self.x`` into ``Name('__qd_self__qd_x')``, and ``build_Name`` would then fail + # to find such a variable. ``mark_kernel_arg_chain_used`` puts the chain into a *separate* per-func set that's + # merged into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` only *after* both compile passes, by + # ``Pruning.fold_kernel_arg_chain_paths`` — so the fastcache args-hash narrow walk picks them up without + # breaking codegen. + parent_chain = getattr(node.value, "_qd_arg_chain", None) + if parent_chain is not None: + flat = create_flat_name(parent_chain, node.attr) + node._qd_arg_chain = flat # type: ignore[attr-defined] + pruning = ctx.global_context.pruning + if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters: + pruning.mark_kernel_arg_chain_used(ctx.func.func_id, flat) + else: + node._qd_arg_chain = None # type: ignore[attr-defined] return node.ptr @staticmethod diff --git a/python/quadrants/lang/ast/ast_transformer_utils.py b/python/quadrants/lang/ast/ast_transformer_utils.py index 506778c683..fa784a3522 100644 --- a/python/quadrants/lang/ast/ast_transformer_utils.py +++ b/python/quadrants/lang/ast/ast_transformer_utils.py @@ -247,6 +247,17 @@ def __init__( self.visited_funcdef = False self.is_real_function = is_real_function self.kernel_args: list = [] + # Names of the bare (non-flattened) parameters of a ``@qd.func`` being processed. Used by ``build_Name`` to + # seed ``_qd_arg_chain`` for attribute accesses rooted at a func param (e.g. + # ``static_rigid_sim_config.para_level`` where ``static_rigid_sim_config`` is a ``qd.template()`` arg bound to + # a ``@qd.data_oriented`` instance). 
Without this, chains rooted at func params would not be recorded in + # pruning, and the args-hasher would skip over kernel-read primitive members of nested data_oriented + # containers — leading to stale fastcache hits when those members change between calls. + # ``kernel_args`` only tracks top-level ``@qd.kernel`` args; ``_transform_func_arg`` for a ``@qd.func`` does + # not append to it (see function_def_transformer.py). This separate set avoids piggy-backing on + # ``kernel_args`` so the existing "kernel arg is immutable" diagnostic in ``build_assign_annotated`` doesn't + # start firing for func params. + self.fn_param_names: set[str] = set() self.only_parse_function_def: bool = False self.autodiff_mode = autodiff_mode self.loop_depth: int = 0 diff --git a/python/quadrants/lang/ast/ast_transformers/call_transformer.py b/python/quadrants/lang/ast/ast_transformers/call_transformer.py index 0d709ebd01..2bc22e8650 100644 --- a/python/quadrants/lang/ast/ast_transformers/call_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/call_transformer.py @@ -166,17 +166,23 @@ def _canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywo @staticmethod def _expand_Call_dataclass_args( - ctx: ASTTransformerFuncContext, args: tuple[ast.stmt, ...] + ctx: ASTTransformerFuncContext, + args: tuple[ast.stmt, ...], + called_needed: set[str] | None = None, + callee_arg_names: list[str] | None = None, ) -> tuple[tuple[ast.stmt, ...], tuple[ast.stmt, ...]]: """ - We require that each node has a .ptr attribute added to it, that contains - the associated Python object + We require that each node has a .ptr attribute added to it, that contains the associated Python object. 
+ + ``called_needed`` and ``callee_arg_names`` are used only for the attribute-accessed-instance branch (Option A + for data_oriented @qd.func calls): the caller cannot construct a flat name from its own ``arg.id`` (the arg is + an ast.Attribute), so we look up pruning against the callee's parameter name at the same positional index. """ args_new = [] added_args = [] pruning = ctx.global_context.pruning func_id = ctx.func.func_id - for arg in args: + for arg_idx, arg in enumerate(args): val = arg.ptr if dataclasses.is_dataclass(val) and isinstance(val, type): dataclass_type = val @@ -204,6 +210,54 @@ def _expand_Call_dataclass_args( else: args_new.append(arg_node) added_args.append(arg_node) + elif dataclasses.is_dataclass(val) and not isinstance(val, type): + # Dataclass *instance* passed positionally (e.g. ``self.state`` inside a @qd.data_oriented kernel + # method). Expand into per-leaf attribute accesses against the same AST node, mirroring the typed-arg + # (instance-of-type) path above but emitting ``ast.Attribute`` children rather than ``ast.Name``. + # ``added_args`` items must not carry ``.ptr`` (build_stmt populates it downstream); only the + # intermediate node used for recursion does. + dataclass_type = type(val) + # For pruning, match the callee's flat name (it may have pruned unused fields). Use the callee's + # parameter name at this positional index. 
+ callee_param = ( + callee_arg_names[arg_idx] + if (called_needed is not None and callee_arg_names is not None and arg_idx < len(callee_arg_names)) + else None + ) + for field in dataclasses.fields(dataclass_type): + if called_needed is not None and callee_param is not None: + callee_flat_name = create_flat_name(callee_param, field.name) + if callee_flat_name not in called_needed: + continue + child_val = getattr(val, field.name) + load_ctx = ast.Load() + child_node = ast.Attribute( + value=arg, + attr=field.name, + ctx=load_ctx, + lineno=arg.lineno, + end_lineno=arg.end_lineno, + col_offset=arg.col_offset, + end_col_offset=arg.end_col_offset, + ) + if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): + child_node.ptr = child_val + # Recurse, threading the renamed scope: the callee's expanded flat name (e.g. + # ``__qd_state__inner``) is the synthetic param name for the nested level. + nested_callee_param = ( + create_flat_name(callee_param, field.name) if callee_param is not None else None + ) + _added_args, _args_new = CallTransformer._expand_Call_dataclass_args( + ctx, + (child_node,), + called_needed=called_needed, + callee_arg_names=[nested_callee_param] if nested_callee_param is not None else None, + ) + args_new.extend(_args_new) + added_args.extend(_added_args) + else: + args_new.append(child_node) + added_args.append(child_node) else: args_new.append(arg) return tuple(added_args), tuple(args_new) @@ -261,6 +315,46 @@ def _expand_Call_dataclass_kwargs( else: kwargs_new.append(kwarg_node) added_kwargs.append(kwarg_node) + elif dataclasses.is_dataclass(val) and not isinstance(val, type): + # Dataclass *instance* passed as a keyword arg (e.g. ``write(state=self.state)`` inside a + # @qd.data_oriented kernel method). Expand into per-leaf keyword args whose values are attribute + # accesses against the original value node (e.g. ``__qd_state__x=self.state.x``). 
+ dataclass_type = type(val) + for field in dataclasses.fields(dataclass_type): + child_name = create_flat_name(kwarg.arg, field.name) + if used_args is not None and child_name not in used_args: + continue + child_val = getattr(val, field.name) + load_ctx = ast.Load() + src_node = ast.Attribute( + value=kwarg.value, + attr=field.name, + ctx=load_ctx, + lineno=kwarg.lineno, + end_lineno=kwarg.end_lineno, + col_offset=kwarg.col_offset, + end_col_offset=kwarg.end_col_offset, + ) + src_node.ptr = child_val + kwarg_node = ast.keyword( + arg=child_name, + value=src_node, + ctx=load_ctx, + lineno=kwarg.lineno, + end_lineno=kwarg.end_lineno, + col_offset=kwarg.col_offset, + end_col_offset=kwarg.end_col_offset, + ) + if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): + kwarg_node.ptr = {child_name: child_val} + _added_kwargs, _kwargs_new = CallTransformer._expand_Call_dataclass_kwargs( + ctx, [kwarg_node], used_args + ) + kwargs_new.extend(_kwargs_new) + added_kwargs.extend(_added_kwargs) + else: + kwargs_new.append(kwarg_node) + added_kwargs.append(kwarg_node) else: kwargs_new.append(kwarg) return added_kwargs, kwargs_new @@ -286,11 +380,21 @@ def build_Call(ctx: ASTTransformerFuncContext, node: ast.Call, build_stmt, build is_func_base_wrapper = func_type in {QuadrantsCallable, BoundQuadrantsCallable} pruning = ctx.global_context.pruning called_needed = None + callee_arg_names: list[str] | None = None if pruning.enforcing and is_func_base_wrapper: called_func_id_ = func.wrapper.func_id # type: ignore called_needed = pruning.used_vars_by_func_id[called_func_id_] + if is_func_base_wrapper: + # callee param names (used by the attribute-instance positional-expansion path so it can match the + # callee's already-pruned flat names). 
+ try: + callee_arg_names = [m.name for m in func.wrapper.arg_metas] # type: ignore[attr-defined] + except AttributeError: + callee_arg_names = None - added_args, node_args = CallTransformer._expand_Call_dataclass_args(ctx, node.args) + added_args, node_args = CallTransformer._expand_Call_dataclass_args( + ctx, node.args, called_needed=called_needed, callee_arg_names=callee_arg_names + ) added_keywords, node_keywords = CallTransformer._expand_Call_dataclass_kwargs(ctx, node.keywords, called_needed) # Create variables for the now-expanded dataclass members. diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 1bdd14dbd8..7668620467 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -34,7 +34,11 @@ from quadrants.lang.matrix import MatrixType from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import StructType -from quadrants.lang.util import is_data_oriented, to_quadrants_type +from quadrants.lang.util import ( + is_data_oriented, + is_dataclass_instance, + to_quadrants_type, +) from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types @@ -152,8 +156,8 @@ def _transform_kernel_arg( elif isinstance(field.type, type) and getattr(field.type, "_data_oriented", False): # ``@qd.data_oriented`` field type inside a typed-dataclass kernel arg. 
The two patterns are # semantically incompatible at this layer: dataclass kernel-arg recursion uses annotations to - # flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers don't - # carry per-attribute type annotations — they need a value-driven walk + # flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers + # don't carry per-attribute type annotations — they need a value-driven walk # (``_predeclare_struct_ndarrays``), which only fires for ``qd.template()`` / ``qd.Tensor`` # annotations. Rather than silently miscompile, raise a clear error pointing users to the # recommended pattern. @@ -227,37 +231,85 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: Also stores ``(arg_id, template_arg_idx, attr_chain)`` tuples in ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the launch context. + + Pruning: in the enforcing (second) compile pass, ``pruning.used_struct_ndarray_ids`` contains the set of + ``id(ndarray)`` values that ``_promote_ndarray_if_declared`` observed being accessed during the first pass + (directly in the kernel body, or transitively through ``@qd.func`` inlining). We register only those, dropping + every unused ndarray from the kernel's parameter list. On the first pass the set is empty / not yet populated, + so we register everything as today (correctness: the first pass needs every reachable ndarray in the cache for + ``build_Attribute`` to resolve the accesses that *will* populate the set). 
""" + from quadrants.lang._pruning import Pruning # pylint: disable=C0415 from quadrants.lang.util import cook_dtype # pylint: disable=C0415 cache = ctx.global_context.ndarray_to_any_array launch_info = ctx.global_context.struct_ndarray_launch_info + pruning = ctx.global_context.pruning + used_ids = getattr(pruning, "used_struct_ndarray_ids", None) + # Only prune on the enforcing pass when we actually ran pass 0 to populate the used-ndarray set. On a + # fastcache hit pass 0 is skipped and the set is empty. + prune = pruning.enforcing and used_ids is not None and getattr(pruning, "pass_0_ran", False) + # On a fastcache hit (enforcing without a pass-0 run), the `id(nd)` set is empty, but the *flat-name* set on + # ``used_vars_by_func_id[KERNEL_FUNC_ID]`` was loaded from cache and already contains every kernel-accessed + # leaf path (folded in by ``Pruning.fold_struct_nd_paths`` during the compile that produced the cache entry). + # Use that to prune the walk so we register the exact same ndarray set as the originating compile produced — + # without this, every reachable ndarray gets registered, the kernel's arg slots get rebound to the wrong + # ndarrays at launch, and physics silently breaks. + prune_from_flat_names = pruning.enforcing and not getattr(pruning, "pass_0_ran", False) + kernel_used_flat_names = ( + pruning.used_vars_by_func_id.get(Pruning.KERNEL_FUNC_ID, set()) if prune_from_flat_names else None + ) - def _walk_obj(obj, arg_idx, path): - if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + # Cycle-safe walker: Genesis object graphs have cross-references (e.g. solver <-> scene <-> sim) so we must + # avoid re-entering the same node. ``seen`` is shared across the whole arg's traversal — ``id(obj)`` is + # stable for the duration of this compile and we never need to revisit a node since the ndarray-set rooted at + # it doesn't depend on the path we took to reach it. 
+ def _walk_obj(obj, arg_idx, path, seen): + if is_dataclass_instance(obj): for field in dataclasses.fields(obj): child = getattr(obj, field.name) if isinstance(child, _TensorClass): child = child._unwrap() if isinstance(child, _ndarray.Ndarray): _register_ndarray(child, arg_idx, (*path, field.name)) - elif (dataclasses.is_dataclass(child) and not isinstance(child, type)) or is_data_oriented(child): - _walk_obj(child, arg_idx, (*path, field.name)) + elif is_dataclass_instance(child) or is_data_oriented(child): + child_id = id(child) + if child_id in seen: + continue + seen.add(child_id) + _walk_obj(child, arg_idx, (*path, field.name), seen) else: for attr_name, attr_val in vars(obj).items(): if isinstance(attr_val, _TensorClass): attr_val = attr_val._unwrap() if isinstance(attr_val, _ndarray.Ndarray): _register_ndarray(attr_val, arg_idx, (*path, attr_name)) - elif (dataclasses.is_dataclass(attr_val) and not isinstance(attr_val, type)) or is_data_oriented( - attr_val - ): - _walk_obj(attr_val, arg_idx, (*path, attr_name)) + elif is_dataclass_instance(attr_val) or is_data_oriented(attr_val): + attr_id = id(attr_val) + if attr_id in seen: + continue + seen.add(attr_id) + _walk_obj(attr_val, arg_idx, (*path, attr_name), seen) def _register_ndarray(nd, arg_idx, attr_chain): key = id(nd) if key in cache: return + if prune and key not in used_ids: + return + if prune_from_flat_names: + # Build the leaf flat name (e.g. ``__qd_self__qd__collider_state__qd_active_buffer``) + # and skip registration when the kernel's cached pruning set doesn't contain it. 
+ if arg_idx < 0 or arg_idx >= len(ctx.func.arg_metas): + return + arg_name = ctx.func.arg_metas[arg_idx].name + if not arg_name: + return + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + if flat not in kernel_used_flat_names: + return from quadrants._lib import core as _qd_core # pylint: disable=C0415 element_type = cook_dtype(nd.element_type) @@ -272,6 +324,10 @@ def _register_ndarray(nd, arg_idx, attr_chain): _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), _qd_layout=layout, ) + # Tag the AnyArray with the source ndarray id so ``_promote_ndarray_if_declared`` can mark this ndarray + # as used even when the access reaches it via an already-promoted AnyArray (e.g. callee bodies bound to + # per-leaf args by Option A). + arr._qd_source_ndarray_id = key cache[key] = arr launch_info.append((arg_id_vec[0], arg_idx, attr_chain)) @@ -287,10 +343,10 @@ def _register_ndarray(nd, arg_idx, attr_chain): val = val._unwrap() if isinstance(val, _ndarray.Ndarray): continue - if dataclasses.is_dataclass(val) and not isinstance(val, type): - _walk_obj(val, i, ()) + if is_dataclass_instance(val): + _walk_obj(val, i, (), {id(val)}) elif hasattr(val, "__dict__"): - _walk_obj(val, i, ()) + _walk_obj(val, i, (), {id(val)}) @staticmethod def _unwrap_tensor(data: Any) -> Any: @@ -306,6 +362,15 @@ def _transform_func_arg( argument_type: Any, data: Any, ) -> None: + # Record the bare (non-flattened) func param name so ``build_Name`` can seed ``_qd_arg_chain`` for attribute + # accesses rooted at this param. Critical for ``qd.template()`` args bound to ``@qd.data_oriented`` instances + # (e.g. ``static_rigid_sim_config.para_level`` inside a ``@qd.func``): without this, the kernel's pruning set + # never learns about ``.para_level``, the args-hasher skips the value, and different ``para_level`` + # configurations collide in the fastcache key. 
Flat names starting with ``__qd_`` arrive here too via the + # dataclass-flatten recursion below; they're harmless to add (``build_Name``'s chain branch gates on + # ``not node.id.startswith("__qd_")``) but the bare-name entries are what enables propagation. + ctx.fn_param_names.add(argument_name) + # Template arguments are passed by reference. if isinstance(argument_type, annotations.template): ctx.create_variable(argument_name, data) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 6b636e717d..b00165337f 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -340,23 +340,60 @@ def reset(self) -> None: self.fe_ll_cache_observations = FeLlCacheObservations() def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType") -> set[str] | None: - frontend_cache_key: str | None = None + """Two-phase fastcache lookup. + + Phase 1 — L1 lookup keyed by source+config only (no args). Returns the set of kernel-accessed flat + names (pruning info). Hit OR miss, this only determines whether we have pruning info for the narrow + args walk; it never on its own justifies skipping pass 0 — that requires the C++ artifact to load. + + Phase 2 — narrow args walk + L2 lookup + artifact load. Only when *all three* succeed do we return + non-None and let ``materialize`` skip pass 0. The reason: pass 0 is what populates pruning info for + *every called ``@qd.func``* (not just the kernel itself). Skipping pass 0 is only safe when pass 1 + runs in ``only_parse_function_def`` mode (i.e. the C++ artifact is already loaded so the AST walker + never enters any callee body); otherwise callee variables can't be found in their func's empty + ``used_vars_by_func_id`` set and the build fails with "Name __qd_… is not defined". 
+ + Side effects: populates ``self._l1_key`` (always when fastcache is active), ``self._pruning_paths_from_l1`` + (the L1 pruning info, or None if L1 miss — used by ``materialize`` for L1-store skipping and for + post-compile narrow-hash construction), and ``self.fast_checksum`` (the L2 key, when phase 2 computed + the narrow args hash). All three are read by the post-compile path in ``_maybe_persist_l1_and_set_l2_key``. + """ + self._l1_key = None # type: ignore[attr-defined] + self._pruning_paths_from_l1 = None # type: ignore[attr-defined] + self.fast_checksum = None if self.runtime.src_ll_cache and self.quadrants_callable and self.quadrants_callable.is_pure: kernel_source_info, _src = get_source_info_and_src(self.func) - self.fast_checksum = src_hasher.create_cache_key( - self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas + self._kernel_source_info_cached = kernel_source_info # reused by materialize / launch_kernel + self._l1_key = src_hasher.make_source_config_key(kernel_source_info) + + # Phase 1: L1 lookup — pruning info only, no args walk yet. + pruning_paths, cached_graph_do_while_arg = src_hasher.load_pruning_info(self._l1_key) + if pruning_paths is None: + # Cold L1. ``materialize`` will compile pass 0 + pass 1 to populate pruning info, then we + # store L1 + L2 after compile. ``cache_key_generated`` is intentionally NOT flipped to True + # here: it tracks "fastcache produced a valid L2 args hash" (the pre-refactor semantic), and + # we don't know yet whether the narrow args walk will succeed. + return None + self._pruning_paths_from_l1 = pruning_paths + + # Phase 2: narrow args hash + L2 lookup. + narrow_args_hash = src_hasher.compute_narrow_args_hash( + self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas, pruning_paths + ) + if narrow_args_hash is None: + # Recognised-but-unsupported tensor-like (Field / MatrixField) — fastcache off for this call. 
+ # ``self.fast_checksum`` stays None so no L2 entry is written; ``cache_key_generated`` stays + # False to match the pre-refactor "Field disables fastcache key generation" contract. + return None + self.fast_checksum = src_hasher.make_full_cache_key(self._l1_key, narrow_args_hash) + self.src_ll_cache_observations.cache_key_generated = True + + used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_arg_l2 = src_hasher.load( + self.fast_checksum ) - used_py_dataclass_parameters = None - cached_graph_do_while_arg: str | None = None - if self.fast_checksum: - self.src_ll_cache_observations.cache_key_generated = True - used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_arg = src_hasher.load( # type: ignore[reportAssignmentType] - self.fast_checksum - ) if used_py_dataclass_parameters is not None and frontend_cache_key is not None: self.src_ll_cache_observations.cache_validated = True prog = impl.get_runtime().prog - assert self.fast_checksum is not None self.compiled_kernel_data_by_key[key] = prog.load_fast_cache( frontend_cache_key, self.func.__name__, @@ -366,11 +403,15 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType if self.compiled_kernel_data_by_key[key]: self.src_ll_cache_observations.cache_loaded = True self.used_py_dataclass_parameters_by_key_enforcing[key] = used_py_dataclass_parameters - if cached_graph_do_while_arg is not None: - self.graph_do_while_arg = cached_graph_do_while_arg + self.graph_do_while_arg = cached_graph_do_while_arg_l2 or cached_graph_do_while_arg return used_py_dataclass_parameters + # L2 miss or artifact load failed: report cold so ``materialize`` does pass 0 + pass 1 (needed + # to populate per-callee pruning info). ``self.fast_checksum`` is still set so the post-compile + # ``src_hasher.store`` will write a fresh L2 entry under the narrow-args key. 
+ self.graph_do_while_arg = cached_graph_do_while_arg or self.graph_do_while_arg + return None - elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: + if self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: # The bit in caps should not be modified without updating corresponding test # freetext can be freely modified. # As for why we are using `print` rather than eg logger.info, it is because this is only printed when @@ -381,7 +422,6 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, ...], arg_features=None): if key is None: key = (self.func, 0, self.autodiff_mode) - self.fast_checksum = None if key in self.materialized_kernels: return @@ -403,6 +443,8 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . range_begin = 0 if used_py_dataclass_parameters is None else 1 runtime = impl.get_runtime() for _pass in range(range_begin, 2): + if _pass == 0: + pruning.pass_0_ran = True if _pass >= 1: pruning.enforce() tree, ctx = self.get_tree_and_ctx( @@ -436,6 +478,23 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . self._struct_ndarray_launch_info_by_key[key] = getattr( ctx.global_context, "struct_ndarray_launch_info", [] ) + # Fold data_oriented ndarray attribute chains into the kernel's used-flat-names set so + # ``args_hasher.hash_args`` can narrow data_oriented walks too. ``used_vars_by_func_id`` only + # contains flat names from dataclass-arg expansion in ``extract_struct_locals_from_context``; + # data_oriented args don't go through that expansion, so accesses like ``self.x`` on an ndarray + # member are only tracked via ``struct_ndarray_launch_info``. 
Without this fold, narrow hashing + # for data_oriented args walks nothing — every (arg_idx, attr_chain) pair gets the same hash + # regardless of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the + # ``test_data_oriented_ndarray_fastcache_dtype_key_distinct`` pin caught this). + pruning.fold_struct_nd_paths(self._struct_ndarray_launch_info_by_key.get(key, []), self.arg_metas) + # Fold non-ndarray kernel-arg-rooted chain paths (primitives, opaque members, nested struct + # paths) collected by ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` tracking. Kept + # separate from ``used_vars_by_func_id`` during compile (would otherwise poison ``struct_locals`` + # and break codegen) — see the field-level docstring on + # ``Pruning.kernel_arg_chain_paths_by_func_id``. This fold + the existing ``used_vars`` assignment + # to ``used_py_dataclass_parameters_by_key_enforcing`` share the same set by reference, so the + # final fastcache L1 entry sees all kernel-accessed paths. + pruning.fold_kernel_arg_chain_paths() else: for used_parameters in pruning.used_vars_by_func_id.values(): new_used_parameters = set() @@ -453,6 +512,57 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . ] runtime._current_global_context = None + # Post-compile fastcache bookkeeping. See ``_maybe_persist_l1_and_set_l2_key`` docstring. + self._maybe_persist_l1_and_set_l2_key(key, py_args) + + def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...]) -> None: + """After a successful materialize, persist L1 (if missing) and set ``fast_checksum`` to the L2 key. + + Called at the end of ``materialize`` once both passes have completed (or once pass 1 has completed + with a loaded artifact). Two responsibilities: + + 1. If L1 was missing (``self._pruning_paths_from_l1 is None``), write the freshly-computed pruning info so + the next call from a new process can skip the args-walk warm-up. + + 2. 
If ``fast_checksum`` is still None (which means either L1 was missing, or L1 hit but phase 2 of + ``_try_load_fastcache`` saw a FIELD-related FastcacheSkip — in which case we keep ``None`` and the + post-compile ``src_hasher.store`` is skipped), compute the narrow args hash *now* using the just- + populated pruning info and derive the L2 key. The post-launch ``src_hasher.store`` call uses + ``self.fast_checksum`` as the L2 key. + + Side-effect helper; split out from ``materialize`` to keep the compile loop readable. + """ + l1_key = getattr(self, "_l1_key", None) + if not l1_key: + return # fastcache inactive for this kernel (not pure / no runtime.src_ll_cache) + kernel_source_info = getattr(self, "_kernel_source_info_cached", None) + if kernel_source_info is None: + return + used_params = self.used_py_dataclass_parameters_by_key_enforcing.get(key) + if used_params is None: + return + if getattr(self, "_pruning_paths_from_l1", None) is None: + src_hasher.store_pruning_info( + l1_key, + self.visited_functions, + used_params, + graph_do_while_arg=self.graph_do_while_arg, + ) + # If phase 2 didn't run (L1 cold) or returned None (FIELD encountered earlier — but in that case + # post-compile narrow hashing would also see the FIELD and produce None, which is fine: we want + # fast_checksum to stay None so no L2 entry is stored), compute the narrow args hash now. 
+ if self.fast_checksum is None: + narrow_args_hash = src_hasher.compute_narrow_args_hash( + self.raise_on_templated_floats, + kernel_source_info, + py_args, + self.arg_metas, + used_params, + ) + if narrow_args_hash is not None: + self.fast_checksum = src_hasher.make_full_cache_key(l1_key, narrow_args_hash) + self.src_ll_cache_observations.cache_key_generated = True + def launch_kernel( self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args, qd_stream=None ) -> Any: @@ -476,10 +586,16 @@ def launch_kernel( if self._struct_ndarray_launch_info_by_key: struct_nd_info = self._struct_ndarray_launch_info_by_key.get(key) if struct_nd_info: + # Data_oriented containers marked ``_qd_stable_members = True`` (or decorated with + # ``@qd.data_oriented(stable_members=True)``) promise their ndarray members are never reassigned, + # so we exclude them from the per-call ``_resolve_struct_ndarray`` walk that builds ``args_hash``. + # This is a *launch-time perf hint only* and has no fastcache role — fastcache derives its key + # from kernel-pruning info regardless of this flag. self._mutable_nd_cached_val = [ (idx, chain) for _, idx, chain in struct_nd_info - if type(args[idx]).__hash__ is None or is_data_oriented(args[idx]) + if type(args[idx]).__hash__ is None + or (is_data_oriented(args[idx]) and not type(args[idx]).__dict__.get("_qd_stable_members")) ] else: self._mutable_nd_cached_val = [] diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 01c74a256d..f3dedca01e 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -275,7 +275,7 @@ def grad(self, *args, **kwargs) -> "Kernel": return self._adjoint(self._kernel_owner, *args, **kwargs) -def data_oriented(cls): +def data_oriented(cls=None, *, stable_members: bool = False): """Marks a class as Quadrants compatible. 
To allow for modularized code, Quadrants provides this decorator so that @@ -299,21 +299,45 @@ def data_oriented(cls): >>> a.inc() Args: - cls (Class): the class to be decorated + cls (Class): the class to be decorated. + stable_members (bool): launch-context perf hint — if ``True``, declares that the class's ndarray-typed members + are allocated once and never reassigned between kernel calls. Quadrants will skip the per-call ndarray- + reference walk that ``Kernel.launch_kernel`` uses to detect ndarray reassignment on mutable containers + (~1-2 us/call savings on Genesis-style containers with dozens of ndarray attrs). Reassigning a member on + a ``stable_members`` class is undefined behaviour — the previously-compiled kernel will be reused even if + the new ndarray has different dtype/ndim/layout. May also be set as a class-level attribute + ``_qd_stable_members = True`` (equivalent). + + Note: this flag is *purely* a launch-time perf hint. It no longer affects fastcache argument hashing — the + fastcache key is derived from pruning info (the set of flat names the kernel actually reads), and + unrecognised types at kernel-read paths fail fastcache loudly with a one-shot ``[UNKNOWN_TYPE]`` + + ``[INVALID_FUNC]`` diagnostic (no qualname fallback). See ``docs/source/user_guide/fastcache.md``. Returns: - The decorated class. + The decorated class (or, when called with arguments, a decorator). """ + if cls is None: + return lambda c: data_oriented(c, stable_members=stable_members) + + def make_kernel_indirect(fun, is_property, attr_name): + # Capture the primal at decoration time so the per-call path skips the ``_BoundedDifferentiableMethod`` + # allocation. The class itself is validated when ``_BoundedDifferentiableMethod`` is invoked via the + # ``.grad()`` path; the primal call here skips that check and only re-wraps compile/runtime errors inline.
+ primal = fun._primal - def make_kernel_indirect(fun, is_property): @wraps(fun) def _kernel_indirect(self, *args, **kwargs): - nonlocal fun - ret = _BoundedDifferentiableMethod(self, fun) - ret.__name__ = fun.__name__ # type: ignore - return ret(*args, **kwargs) + try: + return primal(self, *args, **kwargs) + except (QuadrantsCompilationError, QuadrantsRuntimeError) as e: + if impl.get_runtime().print_full_traceback: + raise e + raise type(e)("\n" + str(e)) from None ret = QuadrantsCallable(fun, _kernel_indirect) + # setattr-after-class doesn't trigger __set_name__; set the name explicitly so QuadrantsCallable.__get__ can + # cache the BoundQuadrantsCallable on instance.__dict__. + ret._attr_name = attr_name if is_property: ret = property(ret) return ret @@ -331,8 +355,10 @@ def _kernel_indirect(self, *args, **kwargs): if isinstance(fun, (BoundQuadrantsCallable, QuadrantsCallable)): if fun._is_wrapped_kernel: if fun._is_classkernel and attr_type is not staticmethod: - setattr(cls, name, make_kernel_indirect(fun, is_property)) + setattr(cls, name, make_kernel_indirect(fun, is_property, name)) cls._data_oriented = True + if stable_members: + cls._qd_stable_members = True return cls diff --git a/python/quadrants/lang/util.py b/python/quadrants/lang/util.py index a9f2f4bc07..0fb153c684 100644 --- a/python/quadrants/lang/util.py +++ b/python/quadrants/lang/util.py @@ -350,9 +350,32 @@ def get_traceback(stacklevel=1): def is_data_oriented(obj: Any) -> bool: - # Use getattr on class instead of object to bypass custom __getattr__ method that is - # overwritten at instance level and very slow. - return getattr(type(obj), "_data_oriented", False) + # Look up ``_data_oriented`` directly via ``__dict__`` on each class in the MRO, never through ``getattr``. 
Some + # third-party metaclasses (notably Pydantic's ``ModelMetaclass``) override ``__getattr__`` and recurse infinitely + # on missing attributes when probed for arbitrary names — ``getattr(type(obj), "_data_oriented", False)`` blows + # the stack on a Genesis ``RigidOptions`` instance. The MRO walk via ``__dict__`` skips any descriptor / + # ``__getattr__`` machinery; ``@qd.data_oriented`` always sets the flag directly on the decorated class so this + # finds it via ``cls.__dict__["_data_oriented"]`` without ever touching the metaclass attribute protocol. + for klass in type(obj).__mro__: + flag = klass.__dict__.get("_data_oriented") + if flag is not None: + return flag + return False + + +def is_dataclass_instance(obj: Any) -> bool: + # Metaclass-safe replacement for ``dataclasses.is_dataclass(obj) and not isinstance(obj, type)``. The stdlib + # implementation calls ``hasattr(type(obj), '__dataclass_fields__')``, which delegates to the metaclass + # ``__getattr__`` for missing names. Pathological metaclasses (Pydantic's ``ModelMetaclass``) recurse infinitely + # on arbitrary attribute lookups and blow the stack. Walking the MRO and probing ``__dict__`` directly avoids + # any descriptor / ``__getattr__`` machinery, mirroring ``is_data_oriented`` above. Also folds in the + # ``not isinstance(obj, type)`` guard since callers always pair the two. 
+ if isinstance(obj, type): + return False + for klass in type(obj).__mro__: + if "__dataclass_fields__" in klass.__dict__: + return True + return False def is_qd_template(annotation: Any) -> bool: diff --git a/tests/python/quadrants/lang/ast/test_function_def_transformer.py b/tests/python/quadrants/lang/ast/test_function_def_transformer.py index a46d5e2cbc..20b422b20b 100644 --- a/tests/python/quadrants/lang/ast/test_function_def_transformer.py +++ b/tests/python/quadrants/lang/ast/test_function_def_transformer.py @@ -81,6 +81,9 @@ def test_process_func_arg(argument_name: str, argument_type: Any, expected_varia class MockContext: def __init__(self) -> None: self.variables: dict[str, Any] = {} + # Mirror the real ``ASTTransformerFuncContext.fn_param_names`` so ``_transform_func_arg`` can record bare + # param names without crashing. + self.fn_param_names: set[str] = set() def create_variable(self, name: str, data: Any) -> None: assert name not in self.variables diff --git a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py index a4bda2a1b1..f6f4b010aa 100644 --- a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py +++ b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py @@ -160,7 +160,14 @@ def k(x: qd.Template): @test_utils.test(arch=qd.cpu) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows stderr not working with capfd") def test_fastcache_field_warnings_warn_struct_template_field(tmp_path, capfd): - """Struct with qd.Template-annotated field containing a Field — warning should fire.""" + """Struct with qd.Template-annotated field containing a Field — warning should fire when the field is + actually read by the kernel. 
+ + Pruning-driven narrowing of args hashing only walks members the kernel reads; an unused dataclass field cannot + affect kernel codegen so it's correctly omitted from the hash (and from the Field-disables-fastcache check). For + the warning path to fire, the kernel must reference the Field — that matches the user-visible contract that + fastcache fails iff a "live" Field argument prevents safe parametrisation. + """ qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) @dataclasses.dataclass(frozen=True) @@ -173,7 +180,7 @@ class S: @qd.pure @qd.kernel def k(x: S): - pass + x.a[0] = 1 capfd.readouterr() k(s) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index e7a2d9952b..9a0fcfd271 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -23,6 +23,13 @@ @test_utils.test() def test_src_hasher_create_cache_key_vary_config() -> None: + """Source+config key (L1) is stable across re-init with identical config, changes when the config changes. + + Updated from the pre-refactor ``create_cache_key`` API (single-level, args-dependent) to the two-level + ``make_source_config_key`` (L1 — source+config only, no args). The L1 key is the right level to test + because config changes only affect the L1 layer; L2 adds the args-narrow hash on top. 
+ """ + @qd.kernel def f1() -> None: pass @@ -31,15 +38,15 @@ def f1() -> None: # so we are forcing it to false each initialization for now qd_init_same_arch(print_ir_dbg_info=False) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_base = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_base = src_hasher.make_source_config_key(kernel_info) qd_init_same_arch(print_ir_dbg_info=False) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_same = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_same = src_hasher.make_source_config_key(kernel_info) qd_init_same_arch(print_ir_dbg_info=False, random_seed=123) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_diff = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_diff = src_hasher.make_source_config_key(kernel_info) assert cache_key_base == cache_key_same assert cache_key_same != cache_key_diff @@ -103,7 +110,9 @@ def get_fileinfos(functions: list[Callable]) -> list[_wrap_inspect.FunctionSourc mod = temporary_module("child_diff_test_src_hasher_store_validate") kernel_info = get_fileinfos([mod.f1.fn])[0] fileinfos = get_fileinfos([mod.f1.fn, mod.f2.fn]) - fast_cache_key = src_hasher.create_cache_key(False, kernel_info, [], []) + # L2 key: source+config (L1) + narrow-args-hash. Use an empty narrow-args-hash since the test isn't + # exercising args at all — it tests the helper-source-change invalidation logic, which lives in L2. 
+ fast_cache_key = src_hasher.make_full_cache_key(src_hasher.make_source_config_key(kernel_info), narrow_args_hash="") assert fast_cache_key is not None @@ -202,7 +211,9 @@ def src_hasher_vary_kernel_func_child(args: list[str]) -> None: sys.path.append(args_obj.module_file_path) mod = importlib.import_module(args_obj.module_name) info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) - cache_key = src_hasher.create_cache_key(False, info, [], []) + # Source+config key (L1) — varies with the *kernel source* (the property this test exercises) and is + # the same level as the pre-refactor ``create_cache_key`` call site, just without the args-dependent tail. + cache_key = src_hasher.make_source_config_key(info) print(f"CACHE_KEY={cache_key}") print(TEST_RAN) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 711839cf5d..819f9e3701 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -142,10 +142,16 @@ def k1(foo: qd.Template) -> None: k1(foo=RandomClass()) _out, err = capfd.readouterr() - assert "[FASTCACHE][PARAM_INVALID]" in err + # Unrecognised types at a (top-level) kernel-read path now fail fastcache loudly: a one-shot ``[UNKNOWN_TYPE]`` + # warning identifies the offending type, and ``[INVALID_FUNC]`` then reports the disabled cache. The old silent + # ``[PARAM_INVALID]`` dead-end is gone — the two rules driving this are documented in + # ``args_hasher.py::_fail_unknown_type`` and ``fastcache.md`` "Pruning-driven argument hashing": (1) only pruned + # paths may contribute to the cache key (so no qualname fallback), (2) unrecognised types at pruned paths must + # not be silently dropped. 
+ assert "[FASTCACHE][UNKNOWN_TYPE]" in err assert RandomClass.__name__ in err assert "[FASTCACHE][INVALID_FUNC]" in err - assert k1.__name__ in err + assert "[FASTCACHE][PARAM_INVALID]" not in err @qd.kernel def not_pure_k1(foo: qd.Template) -> None: @@ -153,8 +159,10 @@ def not_pure_k1(foo: qd.Template) -> None: not_pure_k1(foo=RandomClass()) _out, err = capfd.readouterr() + # Without ``@qd.pure``, fastcache is not active at all — neither the new UNKNOWN_TYPE nor the old + # PARAM_INVALID / INVALID_FUNC warnings should fire. + assert "[FASTCACHE][UNKNOWN_TYPE]" not in err assert "[FASTCACHE][PARAM_INVALID]" not in err - assert RandomClass.__name__ not in err assert "[FASTCACHE][INVALID_FUNC]" not in err assert k1.__name__ not in err @@ -433,6 +441,137 @@ def k1(self) -> tuple[qd.i32, qd.i32]: assert my_do.k1._primal.src_ll_cache_observations.cache_validated +@test_utils.test() +def test_src_ll_cache_needs_grad_distinguishes_args_hash(tmp_path: pathlib.Path) -> None: + """Pin: fastcache narrow args_hash MUST fold in ``needs_grad`` for every ndarray leaf. Without this, two scenes + that differ only by whether their ndarrays carry ``.grad`` (e.g. Genesis ``requires_grad=True`` vs ``False``) + collide on the L2 key, and the second scene loads the artifact compiled with the first scene's needs_grad + flag. The kernel's compiled parameter slots have a fixed needs_grad (``insert_ndarray_param`` bakes it into + the struct type), and the launch path branches on ``v.grad is not None`` to pick between ``_QD_ARRAY`` and + ``_QD_ARRAY_WITH_GRAD`` buckets — bind a needs_grad=True ndarray to a slot declared without grad and the + parameter struct's primal pointer ends up at the wrong offset, producing silent wrong results or runtime OOB. 
+ + Reproduces the Genesis pattern (``kernel_init_link_fields`` taking a frozen-dataclass ``LinksState`` whose + members carry ``needs_grad`` from the scene's ``requires_grad``) with the smallest possible surface: a frozen + dataclass with two ``qd.f32`` ndarray members, a kernel that writes only the second one. The first run compiles + without grad and stores L1+L2; the second run (after ``qd.reset()`` + ``qd.init()``) runs the same kernel with + ``needs_grad=True`` members and asserts the second result is correct *and* that the L2 entry was a miss + (so the per-call needs_grad is correctly part of the cache key). + """ + import dataclasses + + import numpy as np + + arch = getattr(qd, qd.lang.impl.current_cfg().arch.name) + N = 4 + + @dataclasses.dataclass(frozen=True) + class State: + a: qd.types.NDArray[qd.f32, 1] + b: qd.types.NDArray[qd.f32, 1] + + @qd.pure + @qd.kernel + def write_b(s: State) -> None: + for i in range(N): + s.b[i] = qd.cast(i + 1, qd.f32) * 7.0 + + # Cold run: needs_grad=False (default). Populates L1 (pruning info) + L2 (artifact compiled with the slot for + # ``s.b`` declared needs_grad=False) using the narrow args_hash from ``stringify_obj_type`` on the without-grad + # ndarray ``[nd-f32-1]``. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + a1 = qd.ndarray(qd.f32, shape=(N,)) + b1 = qd.ndarray(qd.f32, shape=(N,)) + state1 = State(a=a1, b=b1) + write_b(state1) + assert write_b._primal.src_ll_cache_observations.cache_key_generated + assert not write_b._primal.src_ll_cache_observations.cache_loaded + expected = np.array([7, 14, 21, 28], dtype=np.float32) + np.testing.assert_allclose(b1.to_numpy(), expected) + + # Hot run: needs_grad=True. With the bug, ``stringify_obj_type`` yields the same ``[nd-f32-1]`` string for the + # with-grad ndarray, the narrow args_hash collides, and L2 returns the without-grad artifact.
The launch path + # then routes ``b2`` through ``_QD_ARRAY_WITH_GRAD`` because ``b2.grad`` is not None, against a slot the + # cached kernel declared as plain ``_QD_ARRAY`` — silent miscomputation or OOB. + # + # After the fix, the args_hash differs (needs_grad folded into the ndarray descriptor), L2 misses, the kernel + # is recompiled with the correct needs_grad=True slot, and the launch is well-typed. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + a2 = qd.ndarray(qd.f32, shape=(N,), needs_grad=True) + b2 = qd.ndarray(qd.f32, shape=(N,), needs_grad=True) + state2 = State(a=a2, b=b2) + write_b(state2) + # Diagnostic: the L2 must NOT load the no-grad artifact. After the fix this is a cache miss. + assert not write_b._primal.src_ll_cache_observations.cache_loaded, ( + "fastcache hit between needs_grad=False (cold) and needs_grad=True (hot) — narrow args_hash is " + "missing needs_grad, the without-grad artifact will be launched against with-grad ndarrays" + ) + # Correctness: the kernel writes the expected values, regardless of cache state. + np.testing.assert_allclose(b2.to_numpy(), expected) + # ``b2.grad`` is allocated but not written by this kernel — sanity check it survived as zero (i.e. the + # launch didn't smear primal data into the grad slot via a misaligned param struct). + np.testing.assert_allclose(b2.grad.to_numpy(), np.zeros(N, dtype=np.float32)) + + +@test_utils.test() +def test_src_ll_cache_hit_predeclare_struct_ndarrays_pruned(tmp_path: pathlib.Path) -> None: + """Pin the cache-hit fix for ``_predeclare_struct_ndarrays``: on a fastcache hit pass 0 is skipped so the + ``id(nd)``-keyed used-ndarray set is empty; without flat-name fallback pruning every reachable ndarray gets + registered, scrambling the kernel's arg-slot bindings (e.g. a kernel compiled to write ``state.b`` ends up + writing ``state.a`` at launch). 
The fix uses the cached ``used_vars_by_func_id[KERNEL_FUNC_ID]`` flat-name + set to gate registration on the cache-hit branch, reproducing the exact ndarray set the originating compile + produced. + + The test exercises both the cold (cache-store) and hot (cache-load) paths in the same process via + ``qd.reset()`` cycles, and asserts both that the ndarray the kernel writes to is the *correct* one and that + the other ndarrays are untouched — without the fix the value would land in ``state.a`` (the first + insertion-order ndarray) instead of ``state.b``. + """ + import numpy as np # local import keeps the test module's top-level deps unchanged + + arch = getattr(qd, qd.lang.impl.current_cfg().arch.name) + N = 4 + + @qd.data_oriented + class State: + def __init__(self) -> None: + self.a = qd.ndarray(qd.i32, shape=(N,)) + self.b = qd.ndarray(qd.i32, shape=(N,)) + self.c = qd.ndarray(qd.i32, shape=(N,)) + + @qd.pure + @qd.kernel + def write_b(s: qd.template()) -> None: + for i in range(N): + s.b[i] = (i + 1) * 17 + + # Cold: cache-miss path populates the fastcache (including the kernel-used flat-name set folded in by + # ``pruning.fold_struct_nd_paths``). + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + state = State() + write_b(state) + assert write_b._primal.src_ll_cache_observations.cache_key_generated + assert not write_b._primal.src_ll_cache_observations.cache_loaded + np.testing.assert_array_equal(state.b.to_numpy(), np.array([17, 34, 51, 68], dtype=np.int32)) + np.testing.assert_array_equal(state.a.to_numpy(), np.zeros(N, dtype=np.int32)) + np.testing.assert_array_equal(state.c.to_numpy(), np.zeros(N, dtype=np.int32)) + + # Hot: cache-hit path skips pass 0; this is the branch the fix protects. Without flat-name pruning all three + # ndarrays would be registered in insertion order, displacing ``state.b`` from the slot the kernel was + # compiled to write — and the write would land in ``state.a`` instead.
+ qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + state = State() + write_b(state) + assert write_b._primal.src_ll_cache_observations.cache_loaded, "expected a fastcache hit on the second run" + np.testing.assert_array_equal(state.b.to_numpy(), np.array([17, 34, 51, 68], dtype=np.int32)) + np.testing.assert_array_equal(state.a.to_numpy(), np.zeros(N, dtype=np.int32)) + np.testing.assert_array_equal(state.c.to_numpy(), np.zeros(N, dtype=np.int32)) + + class ModifySubFuncKernelArgs(pydantic.BaseModel): arch: str offline_cache_file_path: str diff --git a/tests/python/test_ad_dataclass.py b/tests/python/test_ad_dataclass.py index d82b1523f9..ef1267ea5f 100644 --- a/tests/python/test_ad_dataclass.py +++ b/tests/python/test_ad_dataclass.py @@ -7,14 +7,13 @@ * ``qd.field`` — ``qd.template()`` path; gradient via ``qd.ad.Tape``. * ``qd.tensor(backend=NDARRAY)`` — same path as ``qd.ndarray``; the dispatcher returns a wrapper whose ndarray ``_impl`` is unwrapped by the dataclass-annotation infrastructure. -* ``qd.tensor(backend=FIELD)`` — works when the dataclass member is annotated ``qd.Tensor`` - (or ``qd.template()``). With ``object`` / no annotation the wrapper survives into kernel scope - and host-side ``__getitem__`` asserts. +* ``qd.tensor(backend=FIELD)`` — works when the dataclass member is annotated ``qd.Tensor`` (or ``qd.template()``). + With ``object`` / no annotation the wrapper survives into kernel scope and host-side ``__getitem__`` asserts. * mixed — single dataclass holding both a ``qd.ndarray`` and a ``qd.field`` member. Pattern mirrors ``test_ad_ndarray.py`` (ndarray) and ``test_ad_basics.py`` (field). See -``docs/source/user_guide/compound_types.md`` overview table — column "supports differentiation?" -for ``dataclasses.dataclass``. +``docs/source/user_guide/compound_types.md`` overview table — column "supports differentiation?" for +``dataclasses.dataclass``. 
""" import dataclasses @@ -185,8 +184,8 @@ def test_ad_dataclass_tensor_field_backend_tape(): Note: members must be annotated as ``qd.Tensor`` (not ``object``) when the value is a ``qd.tensor(...)`` wrapper. The typed-dataclass / template machinery uses the member annotation to decide whether to unwrap the wrapper into - its underlying impl before the kernel sees ``s.x[i]``. With ``object`` annotation the wrapper survives into kernel - scope and its host-side ``__getitem__`` asserts. + its underlying impl before the kernel sees ``s.x[i]``. With ``object`` annotation the wrapper survives into + kernel scope and its host-side ``__getitem__`` asserts. """ N = 5 diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index df10c958f5..46a0943c3b 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -1,16 +1,16 @@ """Tests for ``@qd.data_oriented`` classes whose members are raw ``qd.ndarray`` (not ``qd.field``, not ``qd.Tensor`` wrappers). -The user-guide doc ``docs/source/user_guide/compound_types.md`` claims this pattern is not supported -("can contain ndarray? no" for ``@qd.data_oriented``). But the in-tree error message in -``python/quadrants/lang/impl.py`` lists ``@qd.data_oriented / frozen-dataclass template`` as a -*supported* route, and the ndarray-in-struct infrastructure added by ``#561 [Type] Tensor 24`` -(2026-04-28) — specifically ``_predeclare_struct_ndarrays`` in +The user-guide doc ``docs/source/user_guide/compound_types.md`` claims this pattern is not supported ("can contain +ndarray? no" for ``@qd.data_oriented``). 
But the in-tree error message in ``python/quadrants/lang/impl.py`` lists +``@qd.data_oriented / frozen-dataclass template`` as a *supported* route, and the ndarray-in-struct infrastructure +added by ``#561 [Type] Tensor 24`` (2026-04-28) — specifically ``_predeclare_struct_ndarrays`` in ``python/quadrants/lang/ast/ast_transformers/function_def_transformer.py`` — explicitly walks both -``dataclasses.is_dataclass(val)`` and ``hasattr(val, "__dict__")`` containers, the latter being the data_oriented case. +``dataclasses.is_dataclass(val)`` and ``hasattr(val, "__dict__")`` containers, the latter being the data_oriented +case. -This file pins what actually works, and documents the gaps. See -``perso_hugh/doc/data_oriented_ndarray.md`` for the design analysis. +This file pins what actually works, and documents the gaps. See ``perso_hugh/doc/data_oriented_ndarray.md`` for the +design analysis. """ import dataclasses @@ -168,10 +168,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 6. Mutation: same instance, reassign ndarray attribute to a *same-shape* ndarray between calls. -# The launch-time stale-cache guard (``_mutable_nd_cached_val`` in kernel.py) is supposed to fold the -# live ndarray id into args_hash so the launch context is not served stale. We pin that behaviour -# here for the data_oriented case. +# 6. Mutation: same instance, reassign ndarray attribute to a *same-shape* ndarray between calls. The launch-time +# stale-cache guard (``_mutable_nd_cached_val`` in kernel.py) is supposed to fold the live ndarray id into +# args_hash so the launch context is not served stale. We pin that behaviour here for the data_oriented case. # --------------------------------------------------------------------------- @@ -203,11 +202,11 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 7. 
Mutation cross-shape: reassign ndarray attribute to a *different-dtype* ndarray. -# The template-mapper specialisation key (in ``_template_mapper_hotpath._extract_arg``) returns -# ``weakref.ref(arg)`` for ``is_data_oriented(arg)``; it does NOT descend into ndarray children to -# compute a dtype/ndim-dependent spec key. So if the data_oriented instance's id is unchanged but -# its ndarray attribute is reassigned to a different dtype, we expect either: +# 7. Mutation cross-shape: reassign ndarray attribute to a *different-dtype* ndarray. The template-mapper +# specialisation key (in ``_template_mapper_hotpath._extract_arg``) returns ``weakref.ref(arg)`` for +# ``is_data_oriented(arg)``; it does NOT descend into ndarray children to compute a dtype/ndim-dependent spec key. +# So if the data_oriented instance's id is unchanged but its ndarray attribute is reassigned to a different dtype, +# we expect either: # - a graceful recompile/raise, or # - silent miscompilation (the bug case — current expected outcome per static analysis). # Mark xfail with strict=False so we record the actual outcome without breaking CI. @@ -241,10 +240,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 8. Distinct instances of same class -> spec-key behaviour. Documents that today each fresh instance -# triggers a recompile (because the spec key is ``weakref.ref(arg)`` identity). This is a perf -# concern, not a correctness one. We assert correctness here; the recompile count is documented as -# a perf note. +# 8. Distinct instances of same class -> spec-key behaviour. Documents that today each fresh instance triggers a +# recompile (because the spec key is ``weakref.ref(arg)`` identity). This is a perf concern, not a correctness +# one. We assert correctness here; the recompile count is documented as a perf note. 
# --------------------------------------------------------------------------- @@ -275,19 +273,17 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9. Fastcache cold then warm. Per the fastcache doc (``user_guide/fastcache.md`` line 129), -# ``@qd.data_oriented`` objects are supported in the cache key. We don't assert cross-process here -# (that requires a fresh interpreter); we assert that ``cache_stored`` becomes True on the first -# call and ``cache_key_generated`` is True (i.e. no PARAM_INVALID fallthrough due to the ndarray -# member). +# ``@qd.data_oriented`` objects are supported in the cache key. We don't assert cross-process here (that requires +# a fresh interpreter); we assert that ``cache_stored`` becomes True on the first call and +# ``cache_key_generated`` is True (i.e. no PARAM_INVALID fallthrough due to the ndarray member). # --------------------------------------------------------------------------- # --------------------------------------------------------------------------- # 9b. Fastcache end-to-end with ``@qd.data_oriented`` holding ndarrays. Pattern adapted from -# ``test_cache.test_fastcache``: call ``qd_init_same_arch`` twice with the same cache directory -# to simulate two processes, monkeypatch ``launch_kernel`` to capture whether -# ``compiled_kernel_data`` was loaded from disk. On the second init the data_oriented + ndarray -# kernel should be served from the on-disk fastcache. +# ``test_cache.test_fastcache``: call ``qd_init_same_arch`` twice with the same cache directory to simulate two +# processes, monkeypatch ``launch_kernel`` to capture whether ``compiled_kernel_data`` was loaded from disk. On +# the second init the data_oriented + ndarray kernel should be served from the on-disk fastcache. 
# --------------------------------------------------------------------------- @@ -299,10 +295,9 @@ def test_data_oriented_ndarray_fastcache_cross_init(tmp_path, monkeypatch): captured_compiled_kernel_data = [] def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): - # Filter to the user kernel only; .to_numpy() launches an internal - # ``ndarray_to_ext_arr`` kernel that is not fastcache-eligible - # (is_pure=False) and would always make compiled_kernel_data=None, - # masking the actual fastcache behaviour of ``run``. + # Filter to the user kernel only; .to_numpy() launches an internal ``ndarray_to_ext_arr`` kernel that is not + # fastcache-eligible (is_pure=False) and would always make compiled_kernel_data=None, masking the actual + # fastcache behaviour of ``run``. if self.func.__name__ == "run": captured_compiled_kernel_data.append(compiled_kernel_data) return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) @@ -333,9 +328,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9c. Same as 9b but with a *nested* ``@qd.data_oriented`` holding an ndarray. Pins that the -# fastcache args_hasher recursion handles nested data_oriented containers correctly across -# processes. +# 9c. Same as 9b but with a *nested* ``@qd.data_oriented`` holding an ndarray. Pins that the fastcache args_hasher +# recursion handles nested data_oriented containers correctly across processes. # --------------------------------------------------------------------------- @@ -382,9 +376,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9d. Fastcache key is dtype-sensitive: same kernel source, different ndarray dtype in the -# data_oriented member -> two distinct disk cache entries. Pins the args_hasher's -# ``[nd-{dtype}-{ndim}{layout}]`` repr. +# 9d. 
Fastcache key is dtype-sensitive: same kernel source, different ndarray dtype in the data_oriented member -> +# two distinct disk cache entries. Pins the args_hasher's ``[nd-{dtype}-{ndim}{layout}]`` repr. # --------------------------------------------------------------------------- @@ -432,10 +425,10 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9e. Documented fallback: a @qd.data_oriented containing a qd.field disables fastcache for the -# whole call (args_hasher returns None for ScalarField). The kernel still runs correctly via -# non-fastcache compilation. This test pins the documented fallback so a future "support -# fields in fastcache" change explicitly chooses to update this test. +# 9e. Documented fallback: a @qd.data_oriented containing a qd.field disables fastcache for the whole call +# (args_hasher returns None for ScalarField). The kernel still runs correctly via non-fastcache compilation. This +# test pins the documented fallback so a future "support fields in fastcache" change explicitly chooses to update +# this test. # --------------------------------------------------------------------------- @@ -487,9 +480,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 10. Pure validation: a @qd.pure @qd.kernel taking a data_oriented arg with an ndarray member should -# compile and run, mirroring the existing ``test_pure_validation_data_oriented_as_param`` test -# which only covers ``qd.field``. +# 10. Pure validation: a @qd.pure @qd.kernel taking a data_oriented arg with an ndarray member should compile and +# run, mirroring the existing ``test_pure_validation_data_oriented_as_param`` test which only covers +# ``qd.field``. # --------------------------------------------------------------------------- @@ -516,9 +509,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 11. 
Counter-test: confirm a dataclass-of-NDArray works (sanity check that the existing supported -# route still works; if this fails, the test environment itself is broken, not the data_oriented -# path). +# 11. Counter-test: confirm a dataclass-of-NDArray works (sanity check that the existing supported route still +# works; if this fails, the test environment itself is broken, not the data_oriented path). # --------------------------------------------------------------------------- @@ -543,9 +535,8 @@ def run(s: State): # --------------------------------------------------------------------------- -# 12. data_oriented holding a (frozen) dataclass that holds an ndarray. -# Exercises the ``else`` branch of ``_walk_obj`` recursing through a dataclass child — added by -# the Bug 1 fix. +# 12. data_oriented holding a (frozen) dataclass that holds an ndarray. Exercises the ``else`` branch of +# ``_walk_obj`` recursing through a dataclass child — added by the Bug 1 fix. # --------------------------------------------------------------------------- @@ -575,13 +566,12 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 13. Frozen dataclass holding a data_oriented holding an ndarray, kernel-arg via ``qd.template()``. -# Exercises the dataclass branch of ``_walk_obj`` recursing through a data_oriented child — added -# by the Bug 1 fix. The outer dataclass must be frozen because (i) non-frozen dataclasses are -# unhashable in Python (``__hash__ is None``) and the template-mapper key tuple needs the value -# to be hashable, and (ii) the typed-dataclass-arg form (``def run(s: Outer):``) goes through -# ``_transform_kernel_arg`` which does not currently recurse on data_oriented field *types* (as -# opposed to values) — that's a separate follow-up. +# 13. Frozen dataclass holding a data_oriented holding an ndarray, kernel-arg via ``qd.template()``. 
Exercises the +# dataclass branch of ``_walk_obj`` recursing through a data_oriented child — added by the Bug 1 fix. The outer +# dataclass must be frozen because (i) non-frozen dataclasses are unhashable in Python (``__hash__ is None``) and +# the template-mapper key tuple needs the value to be hashable, and (ii) the typed-dataclass-arg form (``def +# run(s: Outer):``) goes through ``_transform_kernel_arg`` which does not currently recurse on data_oriented +# field *types* (as opposed to values) — that's a separate follow-up. # --------------------------------------------------------------------------- @@ -720,9 +710,9 @@ def fill_y_from_x(s: qd.template()): # --------------------------------------------------------------------------- -# 17. data_oriented + ndarray + @qd.func sub-call. Pins that the AST-time attribute resolution in -# ``build_Attribute`` (which uses the predeclared AnyArray cache) works when the access happens -# inside a func, not just the top-level kernel. +# 17. data_oriented + ndarray + @qd.func sub-call. Pins that the AST-time attribute resolution in ``build_Attribute`` +# (which uses the predeclared AnyArray cache) works when the access happens inside a func, not just the top-level +# kernel. # --------------------------------------------------------------------------- @@ -835,10 +825,9 @@ def run_f32(s: qd.template()): # --------------------------------------------------------------------------- -# 21. Typed-dataclass kernel arg with a ``@qd.data_oriented`` field type — should error clearly -# pointing the user to ``qd.template()``. The two patterns are incompatible at the kernel-arg -# layer: dataclass kernel args are flattened using annotations, data_oriented containers need a -# value-driven walk. Pins the helpful error message. +# 21. Typed-dataclass kernel arg with a ``@qd.data_oriented`` field type — should error clearly pointing the user to +# ``qd.template()``. 
The two patterns are incompatible at the kernel-arg layer: dataclass kernel args are +# flattened using annotations, data_oriented containers need a value-driven walk. Pins the helpful error message. # --------------------------------------------------------------------------- @@ -887,3 +876,511 @@ def run(s: qd.template()): # Run a second time on the same instance — should reuse the same compiled kernel. run(state) + + +# --------------------------------------------------------------------------- +# 22. Robustness: object graphs with Pydantic-style metaclass ``__getattr__`` recursion, and cyclic attribute +# references. Real-world container classes (notably Genesis's ``RigidOptions`` / ``SimOptions``) inherit from +# ``pydantic.BaseModel`` whose ``ModelMetaclass.__getattr__`` recurses infinitely on missing class attributes. +# Quadrants' walker must not blow the stack when it traverses a ``data_oriented`` arg that contains such an +# object, or that contains a back-reference to itself / its parent (e.g. ``solver.scene.solver``). +# --------------------------------------------------------------------------- + + +def test_is_data_oriented_safe_on_pydantic_like_metaclass(): + """``is_data_oriented`` must not invoke ``__getattr__`` on the class (or metaclass), so it stays safe in the + presence of pathological metaclasses whose ``__getattr__`` blows the Python recursion limit on arbitrary + attribute lookups (e.g. Pydantic's ``ModelMetaclass`` when probed for a name not in its private-attrs cache). + """ + + from quadrants.lang.util import is_data_oriented + + class RecursingMeta(type): + def __getattr__(cls, item): + return cls.__getattr__(item) + + class Pathological(metaclass=RecursingMeta): + pass + + # Pre-fix this raised RecursionError; with the MRO+__dict__ lookup it just returns False. 
+ assert is_data_oriented(Pathological()) is False + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_with_pydantic_like_child(): + """A ``@qd.data_oriented`` class holding a child whose metaclass has the recursing ``__getattr__`` + (Pydantic-style). Walker must classify the child as non-data-oriented and continue without blowing the stack. + """ + N = 4 + + class RecursingMeta(type): + def __getattr__(cls, item): + return cls.__getattr__(item) + + class Options(metaclass=RecursingMeta): + pass + + @qd.data_oriented + class State: + def __init__(self, x, opts): + self.x = x + self.opts = opts + + x = qd.ndarray(qd.i32, shape=(N,)) + state = State(x=x, opts=Options()) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + run(state) + np.testing.assert_array_equal(x.to_numpy(), np.arange(1, N + 1)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_polymorphic_attr_across_instances(): + """Some real-world ``@qd.data_oriented`` containers (Genesis FEMSolver / MPMSolver / SPHSolver, etc.) hold + polymorphic children whose types differ between instances — e.g. ``self.material.x`` is an ``Ndarray`` on + instance A and a ``qd.field`` (``MatrixField``) on instance B. The per-instance path cache walks each instance + fresh, but ``_collect_struct_nd_descriptors`` must additionally tolerate a path's leaf no longer being an + ``Ndarray`` *within a single instance's lifetime* (e.g. ``qd.Tensor`` impl swap), and silently skip the stale + entry rather than crash on ``v.element_type``.""" + N = 4 + + @qd.data_oriented + class State: + def __init__(self, x): + self.x = x + + # First instance: ``self.x`` is an Ndarray. The walker emits path ``('x',)`` and caches it. 
+ x_nd = qd.ndarray(qd.i32, shape=(N,)) + state_a = State(x=x_nd) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + run(state_a) + np.testing.assert_array_equal(x_nd.to_numpy(), np.arange(1, N + 1)) + + # Second instance of the SAME class, ``self.x`` is now a ``qd.field`` (MatrixField via Vector.field). + # The cached path ``('x',)`` from instance A points to a non-Ndarray on this instance — the descriptor + # walk must skip it cleanly rather than crash on ``v.element_type``. + f = qd.Vector.field(2, qd.i32, shape=(N,)) + state_b = State(x=f) + + @qd.kernel + def run_field(s: qd.template()): + for i in range(N): + s.x[i] = [i, i + 1] + + run_field(state_b) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_polymorphic_attribute_set_across_instances(): + """Models the Genesis ``DataManager`` failure mode: a ``@qd.data_oriented`` class whose ``__init__`` conditionally + allocates attributes based on a construction flag. Different instances of the same class then have different + attribute *sets* (not just different value types at the same paths). + + With a per-class path cache populated from the first instance walked, this would either AttributeError when the + second instance lacks an attribute the first had (forward direction) or silently miss an ndarray the second + instance has but the first didn't (inverse direction). Per-instance caching walks each instance fresh so both + directions work.""" + N = 4 + + @qd.data_oriented + class PolyState: + def __init__(self, with_extra: bool): + self.x = qd.ndarray(qd.i32, shape=(N,)) + if with_extra: + self.extra = qd.ndarray(qd.i32, shape=(N,)) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + # Forward direction: first instance has 'extra', second doesn't. Used to AttributeError on the cached + # ('extra',) path when running with state_lean. 
+ state_full = PolyState(with_extra=True) + run(state_full) + state_lean = PolyState(with_extra=False) + run(state_lean) + np.testing.assert_array_equal(state_lean.x.to_numpy(), np.arange(1, N + 1)) + + # Inverse direction: a different class so per-class cache (if used by __slots__ fallback) starts fresh; first + # instance lacks 'extra', second has it. The kernel actually *reads* ``s.extra`` so the inverse-direction + # silent miscache (which only manifests when the kernel touches the conditional attr) is exercised end-to-end. + @qd.data_oriented + class PolyState2: + def __init__(self, with_extra: bool): + self.x = qd.ndarray(qd.i32, shape=(N,)) + if with_extra: + self.extra = qd.ndarray(qd.i32, shape=(N,)) + + @qd.kernel + def run_using_extra(s: qd.template()): + for i in range(N): + s.x[i] = s.extra[i] * 10 + + # Walk the lean instance first (no 'extra'), populating any per-class state with the *narrow* attribute set. + # With the old per-class cache, this would lock in paths = [('x',)] for the class — and the next instance's + # ``extra`` would be silently absent from args_hash and from the kernel spec, leading to a wrong-shape kernel + # or a stale-cache hit when ``extra`` is later reassigned. + state_lean2 = PolyState2(with_extra=False) + run(state_lean2) + np.testing.assert_array_equal(state_lean2.x.to_numpy(), np.arange(1, N + 1)) + + # Now the polymorphic-attr-bearing instance. The per-instance walk must include ``('extra',)`` so that + # ``state_full2.extra``'s shape/id participates in the spec and the kernel compiles correctly. + state_full2 = PolyState2(with_extra=True) + state_full2.extra.from_numpy(np.array([2, 3, 5, 7], dtype=np.int32)) + run_using_extra(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.array([20, 30, 50, 70], dtype=np.int32)) + + # Reassignment-detection check: swap ``state_full2.extra`` to a different ndarray. 
The per-instance walk caches + # the *path list* ([('x',), ('extra',)]) on the instance, but the per-call args_hash still folds in + # ``id(getattr(state_full2, 'extra'))`` — so a swap should miss the spec-key cache and re-specialise. + state_full2.extra = qd.ndarray(qd.i32, shape=(N,)) + state_full2.extra.from_numpy(np.array([11, 13, 17, 19], dtype=np.int32)) + run_using_extra(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.array([110, 130, 170, 190], dtype=np.int32)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_with_cyclic_attr_graph(): + """A ``@qd.data_oriented`` class whose attribute graph contains a cycle (``parent.child.parent is parent``). + Walker must not re-enter the cycle.""" + N = 4 + + @qd.data_oriented + class Child: + def __init__(self): + self.parent = None + + @qd.data_oriented + class Parent: + def __init__(self, x): + self.x = x + self.child = Child() + self.child.parent = self # cycle + + x = qd.ndarray(qd.i32, shape=(N,)) + p = Parent(x=x) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 10 + + run(p) + np.testing.assert_array_equal(x.to_numpy(), np.arange(10, 10 + N)) + + +# --------------------------------------------------------------------------- +# Pruning-driven fastcache behaviour for @qd.data_oriented containers. +# +# These pin the three rules enforced by the args hasher (see fastcache.md "Pruning-driven argument hashing"): +# 1. The cache key may only include contributions from kernel-pruned paths. +# 2. Unrecognised types at kernel-read paths must not be silently dropped. +# 3. Fastcache works for @qd.data_oriented kernel args end-to-end. +# --------------------------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_unused_opaque_member_does_not_affect_cache(tmp_path, monkeypatch): + """Rule 1: kernel-unused opaque members do not affect the fastcache key. 
+ + Two ``State`` instances differ only in an opaque ``uuid`` member that the kernel never reads. Both must hit the + same compiled artifact on the second process — proof that the args hasher's pruning narrow walk skips the opaque + attribute (no qualname-fallback, no spurious miss).""" + import uuid + + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x): + self.x = x + self.uuid = uuid.uuid4() # opaque member, kernel does not read it + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(4): + s.x[i] = s.x[i] + 1 + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,))) + b = State(x=qd.ndarray(qd.i32, shape=(4,))) + run(a) + run(b) + + # Second process: cold-start, must load from disk. If the uuid had leaked into the cache key, different uuid → + # different L2 key → no artifact would load. 
+ qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,))) + b = State(x=qd.ndarray(qd.i32, shape=(4,))) + run(a) + run(b) + assert captured[-2] is not None, "first instance should load from disk" + assert captured[-1] is not None, "second instance (different uuid) should ALSO load from disk" + assert run._primal.src_ll_cache_observations.cache_loaded + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_read_opaque_member_fails_fastcache(tmp_path, capfd) -> None: + """Rule 2: when the kernel actually reads an unrecognised-type member, fastcache fails loudly with [UNKNOWN_TYPE] + + [INVALID_FUNC] — no silent drop, no qualname fallback. The kernel still runs via normal compilation.""" + from quadrants._test_tools import qd_init_same_arch + from quadrants.lang._fast_caching.args_hasher import reset_unknown_type_warn_state + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + reset_unknown_type_warn_state() + + class CustomConfig: + def __init__(self, scale: int) -> None: + self.scale = scale + + @qd.data_oriented + class State: + def __init__(self, x, cfg): + self.x = x + self.cfg = cfg + + x = qd.ndarray(qd.i32, shape=(4,)) + state = State(x=x, cfg=CustomConfig(scale=3)) + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + scale = s.cfg.scale # makes ``__qd_s__qd_cfg`` and ``__qd_s__qd_cfg__qd_scale`` live + for i in range(4): + s.x[i] = i * scale + + run(state) + _out, err = capfd.readouterr() + np.testing.assert_array_equal(x.to_numpy(), np.arange(4) * 3) + + obs = run._primal.src_ll_cache_observations + assert obs.cache_key_generated is False, "unrecognised type at kernel-read path must disable fastcache" + assert "[FASTCACHE][UNKNOWN_TYPE]" in err + assert CustomConfig.__name__ in err + assert "[FASTCACHE][INVALID_FUNC]" in err + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_read_primitive_distinguishes_cache_key(tmp_path, 
monkeypatch) -> None: + """Rule 3 (data_oriented works) + pruning correctness: when the kernel reads a primitive member, its value is + baked into the kernel and must drive a distinct cache entry per value. Two State instances differing only in + ``n`` (read by the kernel) cold-compile separately and both load from disk on the second process.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x, n): + self.x = x + self.n = n # primitive, baked into kernel via ``for i in range(s.n)`` + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(s.n): + s.x[i] = i + s.n + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), n=3) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "different ``n`` → both cold-compile" + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), n=3) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both ``n`` values should load distinct artifacts" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_unread_primitive_does_not_affect_cache(tmp_path, 
monkeypatch) -> None: + """Rule 1: kernel-unused primitive members do not affect the cache key. Mirror of the opaque case for + primitives. Two State instances differing only in ``unused_n`` must share the cache.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x, unused_n): + self.x = x + self.unused_n = unused_n # kernel never reads this + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(4): + s.x[i] = s.x[i] + 1 + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=99) + run(a) + run(b) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=99) + run(a) + run(b) + assert captured[-2] is not None, "first instance should load from disk" + assert captured[-1] is not None, "second instance (different unused_n) should ALSO load from disk" + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_qd_func_chain_propagation_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Pruning chain propagation through ``@qd.func`` calls (``record_after_call`` extension): when the kernel calls + ``f(self.dofs)`` and ``f`` reads ``s.x``, the kernel's pruning set must include ``__qd_self__qd_dofs__qd_x`` so + that changes to the inner ndarray's dtype invalidate the cache. 
Two States differing in ``dofs.x``'s dtype must + cold-compile separately.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class Dofs: + def __init__(self, x): + self.x = x + + @qd.data_oriented + class State: + def __init__(self, dofs): + self.dofs = dofs + + @qd.func + def write_dofs(d: qd.template(), v: qd.i32): + d.x[0] = v + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + write_dofs(s.dofs, 7) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(dofs=Dofs(x=qd.ndarray(qd.i32, shape=(4,)))) + b = State(dofs=Dofs(x=qd.ndarray(qd.f32, shape=(4,)))) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "differing dofs.x dtype → both cold-compile" + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(dofs=Dofs(x=qd.ndarray(qd.i32, shape=(4,)))) + b = State(dofs=Dofs(x=qd.ndarray(qd.f32, shape=(4,)))) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both dtypes load distinct artifacts" + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_nested_primitive_via_qd_func_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested data_oriented containers. 
+ + Regression test for a bug where ``record_after_call`` skipped chain-path propagation whenever the caller-side arg + flattened to a ``__qd_*``-prefixed name (which Attribute chains always do — ``self.cfg`` → + ``__qd_self__qd_cfg``). When that happened, primitive members read inside the callee (``cfg.n`` → + ``__qd_cfg__qd_n`` in the callee's chain set) never made it into the kernel's pruning set, so the args-hasher + walked ``self.cfg`` as data_oriented and found no pruned children, yielding an identical hash for *any* value of + ``cfg.n``. Two configs that should produce different kernels (different ``range(s.cfg.n)`` trip counts baked into + codegen) would then share a fastcache entry — leading to stale-kernel hits and silent miscompiles (e.g. Genesis' + ``test_ndarray_no_compile`` was failing with iter-N kernels reused for iter-N+1 scenes that have a different + ``RigidSimStaticConfig.para_level`` baked into their ``qd.static`` branches). + + The fix in ``_pruning.py`` gates propagation on the *root Name* of the chain (``self``, not the flat result), so + both ``f(self)`` and ``f(self.cfg)`` propagate, while already-flattened dataclass refs + (``Name('__qd_state__qd_x')``) are still skipped.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class Cfg: + def __init__(self, n): + self.n = n # primitive read by ``write_x`` — drives codegen via ``range(c.n)`` + + @qd.data_oriented + class State: + def __init__(self, x, cfg): + self.x = x + self.cfg = cfg + + @qd.func + def write_x(x: qd.template(), c: 
qd.template()): + for i in range(c.n): + x[i] = i + c.n + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + write_x(s.x, s.cfg) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=2)) + b = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=3)) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "different cfg.n → both cold-compile" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=2)) + b = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=3)) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both cfg.n values load distinct artifacts" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) diff --git a/tests/python/test_data_oriented_qd_func_dataclass.py b/tests/python/test_data_oriented_qd_func_dataclass.py new file mode 100644 index 0000000000..b504aae31c --- /dev/null +++ b/tests/python/test_data_oriented_qd_func_dataclass.py @@ -0,0 +1,320 @@ +"""Tests for calling @qd.func that takes a typed-dataclass arg, from a @qd.kernel method of a @qd.data_oriented +class, passing ``self.dataclass_member`` as the arg. + +Genesis's @qd.func helpers declare typed-dataclass parameters (e.g. ``def func(links_state: LinksState, ...):``) and +are designed to be called from kernels that also take typed-dataclass kernel args (so the dataclass is flattened into +per-leaf kernel-locals on both sides of the call boundary). 
+ +When migrating Genesis modules to @qd.data_oriented, we'd like to call the same @qd.func helpers from a data_oriented +kernel method, passing ``self.links_state`` as the arg. Today this fails at AST resolution: + + Missing argument '__qd_links_state__qd_cinr_inertial'. + Unexpected argument 'links_state'. + +These tests pin down the failure modes so we can fix them. +""" + +import dataclasses + +import numpy as np + +import quadrants as qd + +from tests import test_utils + +# ----- typed-dataclass kernel-arg baseline (works) ---------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_baseline_typed_dataclass_kernel_arg_calls_qd_func(): + """Baseline: typed-dataclass kernel arg + qd.func taking same dataclass type — works.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + y: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.kernel + def run(state: State): + for i in range(N): + write_x(state, i, i * 3) + + state = State( + x=qd.ndarray(qd.i32, shape=(N,)), + y=qd.ndarray(qd.i32, shape=(N,)), + ) + run(state) + np.testing.assert_array_equal(state.x.to_numpy(), np.arange(N) * 3) + + +# ----- data_oriented self-method calling qd.func (the broken case) ----------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_dataclass_member(): + """data_oriented holds a dataclass; self-kernel calls a @qd.func taking that dataclass.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + y: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State( + x=qd.ndarray(qd.i32, shape=(N,)), + y=qd.ndarray(qd.i32, shape=(N,)), + ) + + @qd.kernel + def run(self): + for i in range(N): + write_x(self.state, i, i * 5) + + solver = Solver() + solver.run() + 
np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 5) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_keyword_dataclass_member(): + """Same as above but the qd.func arg is passed by keyword (Genesis pattern).""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_x(state=self.state, i=i, v=i * 7) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 7) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_stable_members_method_calls_qd_func_with_dataclass_member(): + """Same as above but with stable_members=True (the FPS-relevant case).""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_x(state=self.state, i=i, v=i * 11) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 11) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_two_dataclass_members(): + """Two dataclass members, qd.func takes both — Genesis-shaped scenario.""" + N = 4 + + @dataclasses.dataclass + class StateA: + a: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class StateB: + b: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_both(sa: StateA, sb: StateB, i: qd.i32, va: qd.i32, vb: qd.i32): + sa.a[i] = va + sb.b[i] = vb + + @qd.data_oriented(stable_members=True) + class Solver: + def 
__init__(self): + self.sa = StateA(a=qd.ndarray(qd.i32, shape=(N,))) + self.sb = StateB(b=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_both(sa=self.sa, sb=self.sb, i=i, va=i * 2, vb=i * 13) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.sa.a.to_numpy(), np.arange(N) * 2) + np.testing.assert_array_equal(solver.sb.b.to_numpy(), np.arange(N) * 13) + + +# ----- nested dataclass -------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_nested_dataclass_member(): + """data_oriented holds an Outer{ Inner{ ndarray } } and passes ``self.outer`` to a @qd.func that expands the + nested dataclass into flat leaves on both sides.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def write_inner_x(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + write_inner_x(self.outer, i, i * 17) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 17) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_nested_dataclass_kwarg(): + """Same as above but the dataclass arg is passed by keyword.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def write_inner_x(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + 
write_inner_x(outer=self.outer, i=i, v=i * 19) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 19) + + +# ----- chained @qd.func calls (qd.func -> qd.func, dataclass threaded through) --- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_qd_func_chain_with_dataclass_member(): + """data_oriented kernel calls outer @qd.func, which in turn calls inner @qd.func, threading the same dataclass + arg through. Both qd.funcs have the typed-dataclass parameter; only the outermost call site (data_oriented method + body) uses self.X. The two inner call sites use the typed-arg path that already worked.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def inner_write(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.func + def outer_write(state: State, i: qd.i32, v: qd.i32): + inner_write(state, i, v) + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + outer_write(self.state, i, i * 23) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 23) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_qd_func_chain_with_nested_dataclass_member(): + """Combination: nested dataclass passed through a chain of two @qd.func calls from a @qd.data_oriented + self-method via self.outer.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def inner_write(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.func + def outer_write(outer: Outer, i: qd.i32, v: qd.i32): + inner_write(outer, i, v) + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + 
@qd.kernel + def run(self): + for i in range(N): + outer_write(self.outer, i, i * 29) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 29) diff --git a/tests/python/test_template_typing.py b/tests/python/test_template_typing.py index 69e9ee990b..11e9d5da72 100644 --- a/tests/python/test_template_typing.py +++ b/tests/python/test_template_typing.py @@ -57,24 +57,35 @@ class DataOrientedWithoutFloat: def __init__(self) -> None: self.an_int = 123 self.a_bool = True + self.scratch = qd.ndarray(qd.i32, shape=(1,)) @qd.data_oriented class DataOrientedWithFloat: def __init__(self) -> None: self.an_int = 123 self.a_float = 1.23 + self.scratch = qd.ndarray(qd.i32, shape=(1,)) + # Read the primitive members so the fastcache narrow walk includes them in the hash. Pre-pruning the args_hasher + # walked every member of every container arg blindly; with pruning the kernel must actually access ``a.a_float`` + # for the raise-on-templated-floats guard to fire (the value being baked-in only matters when the kernel reads + # it). @qd.kernel(fastcache=True) - def k1(a: qd.Template) -> None: ... + def k1f(a: qd.Template) -> None: + a.scratch[0] = qd.cast(a.a_float, qd.i32) + + @qd.kernel(fastcache=True) + def k1i(a: qd.Template) -> None: + a.scratch[0] = a.an_int my_do1 = DataOrientedWithoutFloat() - k1(my_do1) + k1i(my_do1) my_do2 = DataOrientedWithFloat() if raise_on_templated_floats: with pytest.raises(ValueError): - k1(my_do2) + k1f(my_do2) else: - k1(my_do2) + k1f(my_do2) @pytest.mark.parametrize("raise_on_templated_floats", [False, True])