From 59bce14901ae5f27952552a1102426be35870cbd Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 13 Nov 2025 11:29:54 -0600 Subject: [PATCH 1/3] move materialization code to its own module --- .basedpyright/baseline.json | 144 ---------- pytato/transform/__init__.py | 357 +------------------------ pytato/transform/materialize.py | 455 ++++++++++++++++++++++++++++++++ test/test_codegen.py | 9 +- test/test_distributed.py | 3 +- test/test_jax.py | 2 +- 6 files changed, 476 insertions(+), 494 deletions(-) create mode 100644 pytato/transform/materialize.py diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index b50050823..9b4af4d24 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -8147,150 +8147,6 @@ "lineCount": 1 } }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 52, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 52, - "endColumn": 67, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 52, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 19, - "lineCount": 1 - } - 
}, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 30, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 39, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 36, - "endColumn": 40, - "lineCount": 1 - } - }, { "code": "reportIncompatibleMethodOverride", "range": { diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index aa3bd2838..91cd72d1a 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -26,7 +26,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import dataclasses import logging from collections.abc import Hashable, Mapping from typing import ( @@ -77,11 +76,10 @@ from pytato.equality import EqualityComparer from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall, LoopyCallResult -from pytato.tags import ImplStored if TYPE_CHECKING: - from collections.abc import Callable, Iterable + from collections.abc import Callable from pytato.distributed.nodes import ( DistributedRecv, @@ -123,11 +121,11 @@ .. autofunction:: deduplicate .. autofunction:: get_dependencies .. autofunction:: map_and_copy -.. autofunction:: materialize_with_mpms .. 
autofunction:: deduplicate_data_wrappers .. automodule:: pytato.transform.lower_to_index_lambda .. automodule:: pytato.transform.remove_broadcasts_einsum .. automodule:: pytato.transform.einsum_distributive_law +.. automodule:: pytato.transform.materialize .. automodule:: pytato.transform.metadata .. automodule:: pytato.transform.dead_code_elimination @@ -1815,276 +1813,6 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # }}} -# {{{ MPMS materializer - -@dataclasses.dataclass(frozen=True, eq=True) -class MPMSMaterializerAccumulator: - """This class serves as the return value of :class:`MPMSMaterializer`. It - contains the set of materialized predecessors and the rewritten expression - (i.e. the expression with tags for materialization applied). - """ - materialized_predecessors: frozenset[Array] - expr: Array - - -class MPMSMaterializerCache( - CachedMapperCache[ArrayOrNames, MPMSMaterializerAccumulator, []]): - """ - Cache for :class:`MPMSMaterializer`. - - .. automethod:: __init__ - .. automethod:: add - """ - def __init__( - self, - err_on_collision: bool, - err_on_created_duplicate: bool) -> None: - """ - Initialize the cache. - - :arg err_on_collision: Raise an exception if two distinct input expression - instances have the same key. - :arg err_on_created_duplicate: Raise an exception if mapping produces a new - array instance that has the same key as the input array. - """ - super().__init__(err_on_collision=err_on_collision) - - self.err_on_created_duplicate = err_on_created_duplicate - - self._result_key_to_result: dict[ - ArrayOrNames, MPMSMaterializerAccumulator] = {} - - self._equality_comparer: EqualityComparer = EqualityComparer() - - def add( - self, - inputs: CacheInputsWithKey[ArrayOrNames, []], - result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: - """ - Cache a mapping result. - - Returns the cached result (which may not be identical to *result* if a - result was already cached with the same result key). 
- """ - key = inputs.key - - assert key not in self._input_key_to_result, \ - f"Cache entry is already present for key '{key}'." - - try: - # The first encountered instance of each distinct result (in terms of - # "==" of result.expr) gets cached, and subsequent mappings with results - # that are equal to prior cached results are replaced with the original - # instance - result = self._result_key_to_result[result.expr] - except KeyError: - if ( - self.err_on_created_duplicate - and _is_mapper_created_duplicate( - inputs.expr, result.expr, - equality_comparer=self._equality_comparer)): - raise MapperCreatedDuplicateError from None - - self._result_key_to_result[result.expr] = result - - self._input_key_to_result[key] = result - if self.err_on_collision: - self._input_key_to_expr[key] = inputs.expr - - return result - - -def _materialize_if_mpms(expr: Array, - nsuccessors: int, - predecessors: Iterable[MPMSMaterializerAccumulator] - ) -> MPMSMaterializerAccumulator: - """ - Returns an instance of :class:`MPMSMaterializerAccumulator`, that - materializes *expr* if it has more than 1 successor and more than 1 - materialized predecessor. - """ - from functools import reduce - - materialized_predecessors: frozenset[Array] = reduce( - frozenset.union, - (pred.materialized_predecessors - for pred in predecessors), - frozenset()) - if nsuccessors > 1 and len(materialized_predecessors) > 1: - new_expr = expr.tagged(ImplStored()) - return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) - else: - return MPMSMaterializerAccumulator(materialized_predecessors, expr) - - -class MPMSMaterializer( - CachedMapper[MPMSMaterializerAccumulator, Never, []]): - """ - See :func:`materialize_with_mpms` for an explanation. - - .. attribute:: nsuccessors - - A mapping from a node in the expression graph (i.e. an - :class:`~pytato.Array`) to its number of successors. 
- """ - def __init__( - self, - nsuccessors: Mapping[Array, int], - _cache: MPMSMaterializerCache | None = None): - err_on_collision = __debug__ - err_on_created_duplicate = __debug__ - - if _cache is None: - _cache = MPMSMaterializerCache( - err_on_collision=err_on_collision, - err_on_created_duplicate=err_on_created_duplicate) - - # Does not support functions, so function_cache is ignored - super().__init__(err_on_collision=err_on_collision, _cache=_cache) - - self.nsuccessors = nsuccessors - - def _cache_add( - self, - inputs: CacheInputsWithKey[ArrayOrNames, []], - result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: - try: - return self._cache.add(inputs, result) - except MapperCreatedDuplicateError as e: - raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " - f"{type(self)}.") from e - - def clone_for_callee( - self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. 
- """ - raise AssertionError("Control shouldn't reach this point.") - - def _map_input_base(self, expr: InputArgumentBase - ) -> MPMSMaterializerAccumulator: - return MPMSMaterializerAccumulator(frozenset([expr]), expr) - - map_placeholder = _map_input_base - map_data_wrapper = _map_input_base - map_size_param = _map_input_base - - def map_named_array(self, expr: NamedArray) -> MPMSMaterializerAccumulator: - raise NotImplementedError("only LoopyCallResult named array" - " supported for now.") - - def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: - children_rec = {bnd_name: self.rec(bnd) - for bnd_name, bnd in sorted(expr.bindings.items())} - new_children: Mapping[str, Array] = immutabledict({ - bnd_name: bnd.expr - for bnd_name, bnd in children_rec.items()}) - return _materialize_if_mpms( - expr.replace_if_different(bindings=new_children), - self.nsuccessors[expr], - children_rec.values()) - - def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: - rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_arrays = tuple(ary.expr for ary in rec_arrays) - return _materialize_if_mpms( - expr.replace_if_different(arrays=new_arrays), - self.nsuccessors[expr], - rec_arrays) - - def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: - rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_arrays = tuple(ary.expr for ary in rec_arrays) - return _materialize_if_mpms( - expr.replace_if_different(arrays=new_arrays), - self.nsuccessors[expr], - rec_arrays) - - def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: - rec_array = self.rec(expr.array) - return _materialize_if_mpms( - expr.replace_if_different(array=rec_array.expr), - self.nsuccessors[expr], - (rec_array,)) - - def map_axis_permutation(self, expr: AxisPermutation - ) -> MPMSMaterializerAccumulator: - rec_array = self.rec(expr.array) - return _materialize_if_mpms( - expr.replace_if_different(array=rec_array.expr), - self.nsuccessors[expr], - 
(rec_array,)) - - def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: - rec_array = self.rec(expr.array) - rec_indices = {i: self.rec(idx) - for i, idx in enumerate(expr.indices) - if isinstance(idx, Array)} - new_indices = tuple(rec_indices[i].expr - if i in rec_indices - else expr.indices[i] - for i in range( - len(expr.indices))) - new_indices = ( - expr.indices - if _entries_are_identical(new_indices, expr.indices) - else new_indices) - return _materialize_if_mpms( - expr.replace_if_different(array=rec_array.expr, indices=new_indices), - self.nsuccessors[expr], - (rec_array, *tuple(rec_indices.values()))) - - map_basic_index = _map_index_base - map_contiguous_advanced_index = _map_index_base - map_non_contiguous_advanced_index = _map_index_base - - def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: - rec_array = self.rec(expr.array) - return _materialize_if_mpms( - expr.replace_if_different(array=rec_array.expr), - self.nsuccessors[expr], - (rec_array,)) - - def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: - rec_args = [self.rec(ary) for ary in expr.args] - new_args = tuple(ary.expr for ary in rec_args) - return _materialize_if_mpms( - expr.replace_if_different(args=new_args), - self.nsuccessors[expr], - rec_args) - - def map_dict_of_named_arrays(self, expr: DictOfNamedArrays - ) -> MPMSMaterializerAccumulator: - raise NotImplementedError - - def map_loopy_call_result(self, expr: NamedArray) -> MPMSMaterializerAccumulator: - # loopy call result is always materialized - return MPMSMaterializerAccumulator(frozenset([expr]), expr) - - def map_distributed_send_ref_holder(self, - expr: DistributedSendRefHolder - ) -> MPMSMaterializerAccumulator: - rec_send_data = self.rec(expr.send.data) - rec_passthrough = self.rec(expr.passthrough_data) - return MPMSMaterializerAccumulator( - rec_passthrough.materialized_predecessors, - expr.replace_if_different( - 
send=expr.send.replace_if_different(data=rec_send_data.expr), - passthrough_data=rec_passthrough.expr)) - - def map_distributed_recv(self, expr: DistributedRecv - ) -> MPMSMaterializerAccumulator: - return MPMSMaterializerAccumulator(frozenset([expr]), expr) - - def map_named_call_result(self, expr: NamedCallResult - ) -> MPMSMaterializerAccumulator: - raise NotImplementedError("MPMSMaterializer does not support functions.") - -# }}} - - # {{{ mapper frontends def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, @@ -2130,77 +1858,16 @@ def map_and_copy(expr: ArrayOrNamesTc, def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - r""" - Materialize nodes in *expr* with MPMS materialization strategy. - MPMS stands for Multiple-Predecessors, Multiple-Successors. - - .. note:: - - - MPMS materialization strategy is a greedy materialization algorithm in - which any node with more than 1 materialized predecessor and more than - 1 successor is materialized. - - Materializing here corresponds to tagging a node with - :class:`~pytato.tags.ImplStored`. - - Does not attempt to materialize sub-expressions in - :attr:`pytato.Array.shape`. - - .. warning:: - - This is a greedy materialization algorithm and thereby this algorithm - might be too eager to materialize. Consider the graph below: - - :: - - I1 I2 - \ / - \ / - \ / - 🡦 🡧 - T - / \ - / \ - / \ - 🡧 🡦 - O1 O2 - - where, 'I1', 'I2' correspond to instances of - :class:`pytato.array.InputArgumentBase`, and, 'O1' and 'O2' are the outputs - required to be evaluated in the computation graph. MPMS materialization - algorithm will materialize the intermediate node 'T' as it has 2 - predecessors and 2 successors. However, the total number of memory - accesses after applying MPMS goes up as shown by the table below. - - ====== ======== ======= - .. 
Before After - ====== ======== ======= - Reads 4 4 - Writes 2 3 - Total 6 7 - ====== ======== ======= - - """ - from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers - materializer = MPMSMaterializer(get_nusers(expr)) - - if isinstance(expr, Array): - res = materializer(expr).expr - assert isinstance(res, Array) - elif isinstance(expr, DictOfNamedArrays): - res = expr.replace_if_different( - data={ - name: _verify_is_array(materializer(ary).expr) - for name, ary, in expr._data.items()}) - assert isinstance(res, DictOfNamedArrays) - else: - raise NotImplementedError("not implemented for {type(expr).__name__}.") - - from pytato import DEBUG_ENABLED - if DEBUG_ENABLED: - transform_logger.info("materialize_with_mpms: materialized " - f"{get_num_tags_of_type(res, ImplStored)} out of " - f"{get_num_nodes(res)} nodes") - - return res + from warnings import warn + warn( + "pytato.transform.materialize_with_mpms is deprecated and will be removed in " + "2025. Use pytato.transform.materialize.materialize_with_mpms instead.", + DeprecationWarning, stacklevel=2) + + from pytato.transform.materialize import ( + materialize_with_mpms as new_materialize_with_mpms, + ) + return new_materialize_with_mpms(expr) # }}} diff --git a/pytato/transform/materialize.py b/pytato/transform/materialize.py new file mode 100644 index 000000000..a3684b446 --- /dev/null +++ b/pytato/transform/materialize.py @@ -0,0 +1,455 @@ +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2020 Matt Wala +Copyright (C) 2020-21 Kaushik Kulkarni +Copyright (C) 2020-21 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and 
to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +import dataclasses +import logging +from typing import ( + TYPE_CHECKING, + cast, +) + +from immutabledict import immutabledict +from typing_extensions import Never, Self, override + +from pytato.array import ( + AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, + Array, + AxisPermutation, + BasicIndex, + Concatenate, + DataWrapper, + DictOfNamedArrays, + Einsum, + IndexBase, + IndexLambda, + InputArgumentBase, + NamedArray, + Placeholder, + Reshape, + Roll, + SizeParam, + Stack, + _entries_are_identical, +) +from pytato.equality import EqualityComparer +from pytato.tags import ImplStored +from pytato.transform import ( + ArrayOrNames, + ArrayOrNamesTc, + CachedMapper, + CachedMapperCache, + CacheInputsWithKey, + MapperCreatedDuplicateError, + _is_mapper_created_duplicate, + _verify_is_array, +) + + +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + + from pytato.distributed.nodes import ( + DistributedRecv, + DistributedSendRefHolder, + ) + from pytato.function import FunctionDefinition, NamedCallResult + + +__doc__ = """ +.. currentmodule:: pytato.transform.materialize + +.. 
autofunction:: materialize_with_mpms +""" + +# {{{ MPMS + + +@dataclasses.dataclass(frozen=True, eq=True) +class MPMSMaterializerAccumulator: + """This class serves as the return value of :class:`MPMSMaterializer`. It + contains the set of materialized predecessors and the rewritten expression + (i.e. the expression with tags for materialization applied). + """ + materialized_predecessors: frozenset[Array] + expr: Array + + +class MPMSMaterializerCache( + CachedMapperCache[ArrayOrNames, MPMSMaterializerAccumulator, []]): + """ + Cache for :class:`MPMSMaterializer`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate: bool = err_on_created_duplicate + + self._result_key_to_result: dict[ + ArrayOrNames, MPMSMaterializerAccumulator] = {} + + self._equality_comparer: EqualityComparer = EqualityComparer() + + @override + def add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." 
+ + try: + # The first encountered instance of each distinct result (in terms of + # "==" of result.expr) gets cached, and subsequent mappings with results + # that are equal to prior cached results are replaced with the original + # instance + result = self._result_key_to_result[result.expr] + except KeyError: + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate( + inputs.expr, result.expr, + equality_comparer=self._equality_comparer)): + raise MapperCreatedDuplicateError from None + + self._result_key_to_result[result.expr] = result + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result + + +def _materialize_if_mpms(expr: Array, + nsuccessors: int, + predecessors: Iterable[MPMSMaterializerAccumulator] + ) -> MPMSMaterializerAccumulator: + """ + Returns an instance of :class:`MPMSMaterializerAccumulator`, that + materializes *expr* if it has more than 1 successor and more than 1 + materialized predecessor. + """ + from functools import reduce + + materialized_predecessors: frozenset[Array] = reduce( + cast( + "Callable[[frozenset[Array], frozenset[Array]], frozenset[Array]]", + frozenset.union), + (pred.materialized_predecessors for pred in predecessors), + cast("frozenset[Array]", frozenset())) + + if nsuccessors > 1 and len(materialized_predecessors) > 1: + new_expr = expr.tagged(ImplStored()) + return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) + else: + return MPMSMaterializerAccumulator(materialized_predecessors, expr) + + +class MPMSMaterializer( + CachedMapper[MPMSMaterializerAccumulator, Never, []]): + """ + See :func:`materialize_with_mpms` for an explanation. + + .. attribute:: nsuccessors + + A mapping from a node in the expression graph (i.e. an + :class:`~pytato.Array`) to its number of successors. 
+ """ + def __init__( + self, + nsuccessors: Mapping[Array, int], + _cache: MPMSMaterializerCache | None = None): + err_on_collision = __debug__ + err_on_created_duplicate = __debug__ + + if _cache is None: + _cache = MPMSMaterializerCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + # Does not support functions, so function_cache is ignored + super().__init__(err_on_collision=err_on_collision, _cache=_cache) + + self.nsuccessors: Mapping[Array, int] = nsuccessors + + @override + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + @override + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. 
+ """ + raise AssertionError("Control shouldn't reach this point.") + + def _map_input_base(self, expr: InputArgumentBase + ) -> MPMSMaterializerAccumulator: + return MPMSMaterializerAccumulator(frozenset([expr]), expr) + + def map_placeholder(self, expr: Placeholder) -> MPMSMaterializerAccumulator: + return self._map_input_base(expr) + + def map_data_wrapper(self, expr: DataWrapper) -> MPMSMaterializerAccumulator: + return self._map_input_base(expr) + + def map_size_param(self, expr: SizeParam) -> MPMSMaterializerAccumulator: + return self._map_input_base(expr) + + def map_named_array(self, expr: NamedArray) -> MPMSMaterializerAccumulator: + raise NotImplementedError("only LoopyCallResult named array" + " supported for now.") + + def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: + children_rec = {bnd_name: self.rec(bnd) + for bnd_name, bnd in sorted(expr.bindings.items())} + new_children: Mapping[str, Array] = immutabledict({ + bnd_name: bnd.expr + for bnd_name, bnd in children_rec.items()}) + return _materialize_if_mpms( + expr.replace_if_different(bindings=new_children), + self.nsuccessors[expr], + children_rec.values()) + + def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: + rec_arrays = [self.rec(ary) for ary in expr.arrays] + new_arrays = tuple(ary.expr for ary in rec_arrays) + return _materialize_if_mpms( + expr.replace_if_different(arrays=new_arrays), + self.nsuccessors[expr], + rec_arrays) + + def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: + rec_arrays = [self.rec(ary) for ary in expr.arrays] + new_arrays = tuple(ary.expr for ary in rec_arrays) + return _materialize_if_mpms( + expr.replace_if_different(arrays=new_arrays), + self.nsuccessors[expr], + rec_arrays) + + def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: + rec_array = self.rec(expr.array) + return _materialize_if_mpms( + expr.replace_if_different(array=rec_array.expr), + self.nsuccessors[expr], + (rec_array,)) + 
+ def map_axis_permutation(self, expr: AxisPermutation + ) -> MPMSMaterializerAccumulator: + rec_array = self.rec(expr.array) + return _materialize_if_mpms( + expr.replace_if_different(array=rec_array.expr), + self.nsuccessors[expr], + (rec_array,)) + + def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: + rec_array = self.rec(expr.array) + rec_indices = {i: self.rec(idx) + for i, idx in enumerate(expr.indices) + if isinstance(idx, Array)} + new_indices = tuple(rec_indices[i].expr + if i in rec_indices + else expr.indices[i] + for i in range( + len(expr.indices))) + new_indices = ( + expr.indices + if _entries_are_identical(new_indices, expr.indices) + else new_indices) + return _materialize_if_mpms( + expr.replace_if_different(array=rec_array.expr, indices=new_indices), + self.nsuccessors[expr], + (rec_array, *tuple(rec_indices.values()))) + + def map_basic_index(self, expr: BasicIndex) -> MPMSMaterializerAccumulator: + return self._map_index_base(expr) + + def map_contiguous_advanced_index( + self, expr: AdvancedIndexInContiguousAxes) -> MPMSMaterializerAccumulator: + return self._map_index_base(expr) + + def map_non_contiguous_advanced_index( + self, expr: AdvancedIndexInNoncontiguousAxes + ) -> MPMSMaterializerAccumulator: + return self._map_index_base(expr) + + def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: + rec_array = self.rec(expr.array) + return _materialize_if_mpms( + expr.replace_if_different(array=rec_array.expr), + self.nsuccessors[expr], + (rec_array,)) + + def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: + rec_args = [self.rec(ary) for ary in expr.args] + new_args = tuple(ary.expr for ary in rec_args) + return _materialize_if_mpms( + expr.replace_if_different(args=new_args), + self.nsuccessors[expr], + rec_args) + + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays + ) -> MPMSMaterializerAccumulator: + raise NotImplementedError + + def map_loopy_call_result(self, expr: 
NamedArray) -> MPMSMaterializerAccumulator: + # loopy call result is always materialized + return MPMSMaterializerAccumulator(frozenset([expr]), expr) + + def map_distributed_send_ref_holder(self, + expr: DistributedSendRefHolder + ) -> MPMSMaterializerAccumulator: + rec_send_data = self.rec(expr.send.data) + rec_passthrough = self.rec(expr.passthrough_data) + return MPMSMaterializerAccumulator( + rec_passthrough.materialized_predecessors, + expr.replace_if_different( + send=expr.send.replace_if_different(data=rec_send_data.expr), + passthrough_data=rec_passthrough.expr)) + + def map_distributed_recv(self, expr: DistributedRecv + ) -> MPMSMaterializerAccumulator: + return MPMSMaterializerAccumulator(frozenset([expr]), expr) + + def map_named_call_result(self, expr: NamedCallResult + ) -> MPMSMaterializerAccumulator: + raise NotImplementedError("MPMSMaterializer does not support functions.") + + +def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + r""" + Materialize nodes in *expr* with MPMS materialization strategy. + MPMS stands for Multiple-Predecessors, Multiple-Successors. + + .. note:: + + - MPMS materialization strategy is a greedy materialization algorithm in + which any node with more than 1 materialized predecessor and more than + 1 successor is materialized. + - Materializing here corresponds to tagging a node with + :class:`~pytato.tags.ImplStored`. + - Does not attempt to materialize sub-expressions in + :attr:`pytato.Array.shape`. + + .. warning:: + + This is a greedy materialization algorithm and thereby this algorithm + might be too eager to materialize. Consider the graph below: + + :: + + I1 I2 + \ / + \ / + \ / + 🡦 🡧 + T + / \ + / \ + / \ + 🡧 🡦 + O1 O2 + + where, 'I1', 'I2' correspond to instances of + :class:`pytato.array.InputArgumentBase`, and, 'O1' and 'O2' are the outputs + required to be evaluated in the computation graph. 
MPMS materialization + algorithm will materialize the intermediate node 'T' as it has 2 + predecessors and 2 successors. However, the total number of memory + accesses after applying MPMS goes up as shown by the table below. + + ====== ======== ======= + .. Before After + ====== ======== ======= + Reads 4 4 + Writes 2 3 + Total 6 7 + ====== ======== ======= + + """ + from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers + materializer = MPMSMaterializer(get_nusers(expr)) + + if isinstance(expr, Array): + res = materializer(expr).expr + assert isinstance(res, Array) + elif isinstance(expr, DictOfNamedArrays): + res = expr.replace_if_different( + data={ + name: _verify_is_array(materializer(ary).expr) + for name, ary, in expr._data.items()}) + assert isinstance(res, DictOfNamedArrays) + else: + raise NotImplementedError(f"not implemented for {type(expr).__name__}.") + + from pytato import DEBUG_ENABLED + if DEBUG_ENABLED: + logger.info("materialize_with_mpms: materialized " + f"{get_num_tags_of_type(res, ImplStored)} out of " + f"{get_num_nodes(res)} nodes") + + return res + +# }}} + +# vim: foldmethod=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index 23a5b0f6e..e24a1f211 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1383,7 +1383,8 @@ def test_materialize_reduces_flops(ctx_factory: cl.CtxFactory): y2 = cse / x5 bad_graph = pt.make_dict_of_named_arrays({"y1": y1, "y2": y2}) - good_graph = pt.transform.materialize_with_mpms(bad_graph) + from pytato.transform.materialize import materialize_with_mpms + good_graph = materialize_with_mpms(bad_graph) bad_t_unit = pt.generate_loopy(bad_graph) good_t_unit = pt.generate_loopy(good_graph) @@ -1407,7 +1408,9 @@ def test_named_temporaries(ctx_factory: cl.CtxFactory): dag = pt.make_dict_of_named_arrays({"out1": 10 * tmp1 + 11 * tmp2, "out2": 22 * tmp1 + 53 * tmp2 }) - dag = pt.transform.materialize_with_mpms(dag) + + from pytato.transform.materialize import
materialize_with_mpms + dag = materialize_with_mpms(dag) def mark_materialized_nodes_as_cse(ary: pt.Array | pt.AbstractResultWithNamedArrays ) -> pt.Array: @@ -1461,7 +1464,7 @@ def test_random_dag_against_numpy(ctx_factory: cl.CtxFactory): ref_result = make_random_dag(rdagc_np) dag = make_random_dag(rdagc_pt) - from pytato.transform import materialize_with_mpms + from pytato.transform.materialize import materialize_with_mpms dict_named_arys = pt.make_dict_of_named_arrays({"result": dag}) dict_named_arys = materialize_with_mpms(dict_named_arys) if 0: diff --git a/test/test_distributed.py b/test/test_distributed.py index 7de1607ff..1492d5359 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -308,7 +308,8 @@ def gen_comm(rdagc): ]) pt_dag = pt.make_dict_of_named_arrays( {"result": make_random_dag(rdagc_comm)}) - x_comm = pt.transform.materialize_with_mpms(pt_dag) + from pytato.transform.materialize import materialize_with_mpms + x_comm = materialize_with_mpms(pt_dag) distributed_partition = pt.find_distributed_partition(comm, x_comm) pt.verify_distributed_partition(comm, distributed_partition) diff --git a/test/test_jax.py b/test/test_jax.py index d632c2673..69cf5bedd 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -111,7 +111,7 @@ def test_random_dag_against_numpy(jit): ref_result = make_random_dag(rdagc_np) dag = make_random_dag(rdagc_pt) - from pytato.transform import materialize_with_mpms + from pytato.transform.materialize import materialize_with_mpms dict_named_arys = pt.make_dict_of_named_arrays({"result": dag}) dict_named_arys = materialize_with_mpms(dict_named_arys) if 0: From 70eb56873b89359c54e64505b63b313f77014d4f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 14 Nov 2025 10:01:26 -0600 Subject: [PATCH 2/3] add get_list_of_users function in analysis --- pytato/analysis/__init__.py | 61 ++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/pytato/analysis/__init__.py 
b/pytato/analysis/__init__.py index 1213d49af..df2eca213 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,6 +26,7 @@ THE SOFTWARE. """ +from collections import defaultdict from typing import TYPE_CHECKING, Any, overload from orderedsets import FrozenOrderedSet @@ -63,6 +64,7 @@ .. currentmodule:: pytato.analysis .. autofunction:: get_nusers +.. autofunction:: get_list_of_users .. autofunction:: is_einsum_similar_to_subscript @@ -82,12 +84,12 @@ """ -# {{{ NUserCollector +# {{{ ListOfUsersCollector -class NUserCollector(Mapper[None, None, []]): +class ListOfUsersCollector(Mapper[None, None, []]): """ - A :class:`pytato.transform.CachedWalkMapper` that records the number of - times an array expression is a direct dependency of other nodes. + A :class:`pytato.transform.CachedWalkMapper` that records, for each array + expression, the nodes that directly depend on it. .. note:: @@ -97,10 +99,9 @@ class NUserCollector(Mapper[None, None, []]): send's data. """ def __init__(self) -> None: - from collections import defaultdict super().__init__() self._visited_ids: set[int] = set() - self.nusers: dict[Array, int] = defaultdict(lambda: 0) + self.array_to_users: dict[Array, list[ArrayOrNames]] = defaultdict(list) def rec(self, expr: ArrayOrNames) -> None: # See CachedWalkMapper.rec on why we chose id(x) as the cache key. 
@@ -113,38 +114,38 @@ def rec(self, expr: ArrayOrNames) -> None: def map_index_lambda(self, expr: IndexLambda) -> None: for ary in expr.bindings.values(): - self.nusers[ary] += 1 + self.array_to_users[ary].append(expr) self.rec(ary) for dim in expr.shape: if isinstance(dim, Array): - self.nusers[dim] += 1 + self.array_to_users[dim].append(expr) self.rec(dim) def map_stack(self, expr: Stack) -> None: for ary in expr.arrays: - self.nusers[ary] += 1 + self.array_to_users[ary].append(expr) self.rec(ary) def map_concatenate(self, expr: Concatenate) -> None: for ary in expr.arrays: - self.nusers[ary] += 1 + self.array_to_users[ary].append(expr) self.rec(ary) def map_loopy_call(self, expr: LoopyCall) -> None: for ary in expr.bindings.values(): if isinstance(ary, Array): - self.nusers[ary] += 1 + self.array_to_users[ary].append(expr) self.rec(ary) def map_einsum(self, expr: Einsum) -> None: for ary in expr.args: - self.nusers[ary] += 1 + self.array_to_users[ary].append(expr) self.rec(ary) for dim in expr.shape: if isinstance(dim, Array): - self.nusers[dim] += 1 + self.array_to_users[dim].append(expr) self.rec(dim) def map_named_array(self, expr: NamedArray) -> None: @@ -155,12 +156,12 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: self.rec(child) def _map_index_base(self, expr: IndexBase) -> None: - self.nusers[expr.array] += 1 + self.array_to_users[expr.array].append(expr) self.rec(expr.array) for idx in expr.indices: if isinstance(idx, Array): - self.nusers[idx] += 1 + self.array_to_users[idx].append(expr) self.rec(idx) map_basic_index = _map_index_base @@ -168,7 +169,7 @@ def _map_index_base(self, expr: IndexBase) -> None: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None: - self.nusers[expr.array] += 1 + self.array_to_users[expr.array].append(expr) self.rec(expr.array) map_roll = _map_index_remapping_base @@ -178,7 +179,7 @@ def _map_index_remapping_base(self, expr: 
IndexRemappingBase) -> None: def _map_input_base(self, expr: InputArgumentBase) -> None: for dim in expr.shape: if isinstance(dim, Array): - self.nusers[dim] += 1 + self.array_to_users[dim].append(expr) self.rec(dim) map_placeholder = _map_input_base @@ -189,20 +190,20 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder ) -> None: # Note: We do not consider 'expr.send.data' as a predecessor of *expr*, # as there is no dataflow from *expr.send.data* to *expr* - self.nusers[expr.passthrough_data] += 1 + self.array_to_users[expr.passthrough_data].append(expr) self.rec(expr.passthrough_data) self.rec(expr.send.data) def map_distributed_recv(self, expr: DistributedRecv) -> None: for dim in expr.shape: if isinstance(dim, Array): - self.nusers[dim] += 1 + self.array_to_users[dim].append(expr) self.rec(dim) def map_call(self, expr: Call) -> None: for ary in expr.bindings.values(): if isinstance(ary, Array): - self.nusers[ary] += 1 + self.array_to_users[ary].append(expr) self.rec(ary) def map_named_call_result(self, expr: NamedCallResult) -> None: @@ -216,9 +217,21 @@ def get_nusers(outputs: ArrayOrNames) -> Mapping[Array, int]: For the DAG *outputs*, returns the mapping from each array node to the number of nodes using its value within the DAG given by *outputs*. """ - nuser_collector = NUserCollector() - nuser_collector(outputs) - return nuser_collector.nusers + list_of_users_collector = ListOfUsersCollector() + list_of_users_collector(outputs) + return defaultdict(int, { + ary: len(users) + for ary, users in list_of_users_collector.array_to_users.items()}) + + +def get_list_of_users(outputs: ArrayOrNames) -> Mapping[Array, list[ArrayOrNames]]: + """ + For the DAG *outputs*, returns the mapping from each array node to the list of + nodes using its value within the DAG given by *outputs*. 
+ """ + list_of_users_collector = ListOfUsersCollector() + list_of_users_collector(outputs) + return list_of_users_collector.array_to_users # {{{ is_einsum_similar_to_subscript @@ -482,7 +495,6 @@ def __init__( ) -> None: super().__init__(_visited_functions=_visited_functions) - from collections import defaultdict self.expr_type_counts: dict[type[Any], int] = defaultdict(int) self.count_duplicates = count_duplicates @@ -562,7 +574,6 @@ class NodeMultiplicityMapper(CachedWalkMapper[[]]): def __init__(self, _visited_functions: set[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) - from collections import defaultdict self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: From b55b35bd0cfa5ee795ebe650c2bc71e4bbe74d82 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 14 Nov 2025 11:29:13 -0600 Subject: [PATCH 3/3] tweak the definition of 'multiple successors' in MPMS materializer to handle indexing with heavy reuse add more explanation for MPMS reuse tweak --- pytato/transform/materialize.py | 49 +++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/pytato/transform/materialize.py b/pytato/transform/materialize.py index a3684b446..b9158f4bf 100644 --- a/pytato/transform/materialize.py +++ b/pytato/transform/materialize.py @@ -172,7 +172,7 @@ def add( def _materialize_if_mpms(expr: Array, - nsuccessors: int, + successors: list[ArrayOrNames], predecessors: Iterable[MPMSMaterializerAccumulator] ) -> MPMSMaterializerAccumulator: """ @@ -189,6 +189,24 @@ def _materialize_if_mpms(expr: Array, (pred.materialized_predecessors for pred in predecessors), cast("frozenset[Array]", frozenset())) + nsuccessors = 0 + for successor in successors: + # Handle indexing with heavy reuse, if the sizes are known ahead of time. 
+ # This can occur when the elements of a smaller array are used repeatedly to + # compute the elements of a larger array. (Example: In meshmode's direct + # connection code, this happens when injecting data from a smaller + # discretization into a larger one, such as BTAG_ALL -> FACE_RESTR_ALL.) + # + # In this case, we would like to bias towards materialization by + # making one successor seem like n of them, if it is n times bigger. + if ( + isinstance(successor, IndexBase) + and isinstance(successor.size, int) + and isinstance(expr.size, int)): + nsuccessors += (successor.size // expr.size) if expr.size else 0 + else: + nsuccessors += 1 + if nsuccessors > 1 and len(materialized_predecessors) > 1: new_expr = expr.tagged(ImplStored()) return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) @@ -201,14 +219,15 @@ class MPMSMaterializer( """ See :func:`materialize_with_mpms` for an explanation. - .. attribute:: nsuccessors + .. attribute:: successors A mapping from a node in the expression graph (i.e. an - :class:`~pytato.Array`) to its number of successors. + :class:`~pytato.Array`) to a list of its successors (possibly including + multiple references to the same successor if it uses the node multiple times). 
""" def __init__( self, - nsuccessors: Mapping[Array, int], + successors: Mapping[Array, list[ArrayOrNames]], _cache: MPMSMaterializerCache | None = None): err_on_collision = __debug__ err_on_created_duplicate = __debug__ @@ -221,7 +240,7 @@ def __init__( # Does not support functions, so function_cache is ignored super().__init__(err_on_collision=err_on_collision, _cache=_cache) - self.nsuccessors: Mapping[Array, int] = nsuccessors + self.successors: Mapping[Array, list[ArrayOrNames]] = successors @override def _cache_add( @@ -269,7 +288,7 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: for bnd_name, bnd in children_rec.items()}) return _materialize_if_mpms( expr.replace_if_different(bindings=new_children), - self.nsuccessors[expr], + self.successors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: @@ -277,7 +296,7 @@ def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: new_arrays = tuple(ary.expr for ary in rec_arrays) return _materialize_if_mpms( expr.replace_if_different(arrays=new_arrays), - self.nsuccessors[expr], + self.successors[expr], rec_arrays) def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: @@ -285,14 +304,14 @@ def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: new_arrays = tuple(ary.expr for ary in rec_arrays) return _materialize_if_mpms( expr.replace_if_different(arrays=new_arrays), - self.nsuccessors[expr], + self.successors[expr], rec_arrays) def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) return _materialize_if_mpms( expr.replace_if_different(array=rec_array.expr), - self.nsuccessors[expr], + self.successors[expr], (rec_array,)) def map_axis_permutation(self, expr: AxisPermutation @@ -300,7 +319,7 @@ def map_axis_permutation(self, expr: AxisPermutation rec_array = self.rec(expr.array) return _materialize_if_mpms( 
expr.replace_if_different(array=rec_array.expr), - self.nsuccessors[expr], + self.successors[expr], (rec_array,)) def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: @@ -319,7 +338,7 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: else new_indices) return _materialize_if_mpms( expr.replace_if_different(array=rec_array.expr, indices=new_indices), - self.nsuccessors[expr], + self.successors[expr], (rec_array, *tuple(rec_indices.values()))) def map_basic_index(self, expr: BasicIndex) -> MPMSMaterializerAccumulator: @@ -338,7 +357,7 @@ def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) return _materialize_if_mpms( expr.replace_if_different(array=rec_array.expr), - self.nsuccessors[expr], + self.successors[expr], (rec_array,)) def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: @@ -346,7 +365,7 @@ def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: new_args = tuple(ary.expr for ary in rec_args) return _materialize_if_mpms( expr.replace_if_different(args=new_args), - self.nsuccessors[expr], + self.successors[expr], rec_args) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays @@ -427,8 +446,8 @@ def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: ====== ======== ======= """ - from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers - materializer = MPMSMaterializer(get_nusers(expr)) + from pytato.analysis import get_list_of_users, get_num_nodes, get_num_tags_of_type + materializer = MPMSMaterializer(get_list_of_users(expr)) if isinstance(expr, Array): res = materializer(expr).expr