diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6de9e5de4..22139a26e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,18 +27,22 @@ """ from collections import defaultdict -from typing import TYPE_CHECKING, Any, overload +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper +from pytools import product from pytato.array import ( Array, + ArrayOrScalar, Concatenate, CSRMatmul, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -46,17 +50,29 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Placeholder, ShapeType, Stack, ) +from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.loopy import LoopyCallResult +from pytato.scalar_expr import ( + FlopCounter as ScalarFlopCounter, + InputUseCounter as ScalarInputUseCounter, +) +from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, CachedWalkMapper, + CacheKeyT, CombineMapper, Mapper, VisitKeyT, ) +from pytato.transform.lower_to_index_lambda import ( + MapAsIndexLambdaMixin, +) if TYPE_CHECKING: @@ -64,7 +80,6 @@ import pytools.tag - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.loopy import LoopyCall __doc__ = """ @@ -88,6 +103,16 @@ .. autoclass:: TagCountMapper .. autofunction:: get_num_tags_of_type + +.. autoclass:: MaterializedNodeCollector +.. autofunction:: collect_materialized_nodes + +.. autoclass:: UndefinedOpFlopCountError +.. autofunction:: get_default_op_name_to_num_flops +.. autofunction:: get_num_flops +.. autofunction:: get_materialized_node_flop_counts +.. autoclass:: UnmaterializedNodeFlopCounts +.. autofunction:: get_unmaterialized_node_flop_counts """ @@ -750,6 +775,86 @@ def get_num_tags_of_type( # }}} +# {{{ MaterializedNodeCollector + +class MaterializedNodeCollector(CachedWalkMapper[[]]): + """ + Return the nodes in a DAG that are materialized. + + See :func:`collect_materialized_nodes` for more details. + """ + def __init__( + self, + include_outputs: bool = True, + _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + self.include_outputs: bool = include_outputs + self.materialized_nodes: set[Array] = set() + + @overload + def __call__(self, expr: ArrayOrNames) -> None: + ... + + @overload + def __call__(self, expr: FunctionDefinition) -> None: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + ) -> None: + super().__call__(expr) + + if self.include_outputs: + if isinstance(expr, DictOfNamedArrays): + self.materialized_nodes.update(expr._data.values()) + elif isinstance(expr, Array): + self.materialized_nodes.add(expr) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not isinstance(expr, Array): + return + + if ( + isinstance( + expr, ( + InputArgumentBase, + DistributedRecv, + LoopyCallResult, + CSRMatmul)) + or expr.tags_of_type(ImplStored)): + self.materialized_nodes.add(expr) + elif isinstance(expr, DistributedSendRefHolder): + self.materialized_nodes.add(expr.send.data) + + +def collect_materialized_nodes( + expr: ArrayOrNames | FunctionDefinition, + include_outputs: bool = True) -> frozenset[Array]: + """ + Return the nodes in DAG *expr* that are materialized. + + The result includes inputs, outputs (optionally), arrays tagged with + `pytato.tags.ImplStored`, as well as special arrays that are always materialized + (:class:`pytato.DistributedSend`'s data, for example). + """ + mac = MaterializedNodeCollector(include_outputs=include_outputs) + mac(expr) + return frozenset(mac.materialized_nodes) + +# }}} + + # {{{ PytatoKeyBuilder class PytatoKeyBuilder(LoopyKeyBuilder): @@ -778,4 +883,650 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # }}} + +# {{{ flop counting + +class UndefinedOpFlopCountError(ValueError): + pass + + +class _NonIntegralPerEntryFlopCountError(ValueError): + pass + + +class _PerEntryFlopCounter( + MapAsIndexLambdaMixin[int, [bool]], + CombineMapper[int, Never, [bool]]): + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array]) -> None: + super().__init__() + self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( + op_name_to_num_flops) + self.materialized_nodes: frozenset[Array] = materialized_nodes + + @overload + def __call__( + self, + expr: ArrayOrNames, + is_root: bool = True, + ) -> int: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + is_root: bool = True, + ) -> Never: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + is_root: bool = True, + ) -> int: + return super().__call__(expr, is_root) + + @override + def get_cache_key(self, expr: ArrayOrNames, is_root: bool) -> CacheKeyT: + return expr, is_root + + @override + def combine(self, *args: int) -> int: + return sum(args) + + @override + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, is_root: bool) -> int: + if expr in self.materialized_nodes and not is_root: + return 0 + + self_nflops = self.scalar_flop_counter(idx_lambda.expr) + + if not isinstance(self_nflops, int): + # Restricting to numerical result here because the flop counters that use + # this mapper subsequently multiply the result by things that are + # potentially arrays (e.g., shape components), and arrays and scalar + # expressions are not interoperable + from pytato.scalar_expr import OpFlops, OpFlopsCollector + op_flops: frozenset[OpFlops] = OpFlopsCollector()(self_nflops) + if op_flops: + op_names = sorted({of.op for of in op_flops}) + formatted_ops = ", ".join(f"'{name}'" for name in op_names) + raise UndefinedOpFlopCountError( + f"Undefined flop count for operation(s): {formatted_ops}.") + else: + raise _NonIntegralPerEntryFlopCountError( + "Unable to compute an integer-valued per-entry flop count.") + + binding_to_nuses = ScalarInputUseCounter()(idx_lambda.expr) + # self_nflops check above should take care of non-constant use count case + assert all( + isinstance(nuses, int) + for nuses in binding_to_nuses.values()) + + return self.combine( + self_nflops, + *( + cast("int", binding_to_nuses.get(name, 0)) * self.rec(bnd, False) + for name, bnd in sorted(idx_lambda.bindings.items()))) + + @override + def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: + return 0 + + @override + def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> int: + return 0 + + @override + def map_named_array(self, expr: NamedArray, is_root: bool) -> int: + assert isinstance(expr._container, DictOfNamedArrays) + return self.combine(self.rec(expr._container._data[expr.name], False)) + + @override + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, is_root: bool) -> int: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> int: + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call_result(self, expr: LoopyCallResult, is_root: bool) -> int: + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, is_root: bool) -> int: + # Ignore expr.send.data; it's not part of the computation being performed + return self.combine(self.rec(expr.passthrough_data, False)) + + @override + def map_distributed_recv(self, expr: DistributedRecv, is_root: bool) -> int: + return 0 + + @override + def map_call(self, expr: Call, is_root: bool) -> int: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_named_call_result(self, expr: NamedCallResult, is_root: bool) -> int: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + +class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the number of floating point operations of each materialized + expression in a DAG. + + .. note:: + + This mapper does not descend into functions. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array], + ) -> None: + super().__init__() + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_nodes: frozenset[Array] = materialized_nodes + self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops, self.materialized_nodes) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_loopy_call(self, expr: LoopyCall) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + def map_loopy_call_result(self, expr: LoopyCallResult) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if expr not in self.materialized_nodes: + return + assert isinstance(expr, Array) + try: + nflops_per_entry = self._per_entry_flop_counter(expr) + except _NonIntegralPerEntryFlopCountError as e: + if isinstance(expr, CSRMatmul): + # _PerEntryFlopCounter chokes on CSRMatmul because the row reduction + # bounds are data dependent. By handling it here instead we can take + # advantage of knowing that the total number of reduction iterations + # is len(elem_values). Note: assumes no flops for elem_col_indices + # and row_starts as they are integer-valued. + nelems = expr.matrix.elem_values.shape[0] + nflops_self = ( + ( + nelems # multiplies + + nelems - expr.shape[0]) # adds + * product(expr.shape[1:])) + nflops_children = ( + nelems + * ( + self._per_entry_flop_counter(expr.matrix.elem_values) + + self._per_entry_flop_counter(expr.array)) + * product(expr.shape[1:])) + self.materialized_node_to_nflops[expr] = ( + nflops_self + nflops_children) + else: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e + else: + self.materialized_node_to_nflops[expr] = ( + nflops_per_entry + * product(expr.shape)) + + +class _UnmaterializedSubexpressionUseCounter( + MapAsIndexLambdaMixin[dict[Array, int], [bool]], + CombineMapper[dict[Array, int], Never, [bool]]): + def __init__( + self, + materialized_nodes: frozenset[Array]) -> None: + super().__init__() + self.materialized_nodes: frozenset[Array] = materialized_nodes + + @overload + def __call__( + self, + expr: ArrayOrNames, + is_root: bool = True, + ) -> dict[Array, int]: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + is_root: bool = True, + ) -> Never: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + is_root: bool = True, + ) -> dict[Array, int]: + return super().__call__(expr, is_root) + + @override + def get_cache_key(self, expr: ArrayOrNames, is_root: bool) -> CacheKeyT: + return expr, is_root + + @override + def combine(self, *args: dict[Array, int]) -> dict[Array, int]: + result: dict[Array, int] = defaultdict(int) + for arg in args: + for ary, nuses in arg.items(): + result[ary] += nuses + return result + + @override + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, is_root: bool + ) -> dict[Array, int]: + if expr in self.materialized_nodes and not is_root: + return {} + + binding_to_nuses = ScalarInputUseCounter()(idx_lambda.expr) + if any( + not isinstance(nuses, int) + for nuses in binding_to_nuses.values()): + raise ValueError( + "Unable to compute integer-valued use counts for the predecessors " + f"of array of type '{type(expr).__name__}'.") + + return self.combine( + {expr: 1} if not is_root else {}, + *( + { + ary: cast("int", binding_to_nuses.get(name, 0)) * nuses + for ary, nuses in self.rec(bnd, False).items()} + for name, bnd in sorted(idx_lambda.bindings.items()))) + + @override + def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_named_array(self, expr: NamedArray, is_root: bool) -> dict[Array, int]: + assert isinstance(expr._container, DictOfNamedArrays) + return self.combine(self.rec(expr._container._data[expr.name], False)) + + @override + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays, is_root: bool) -> dict[Array, int]: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> dict[Array, int]: + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call_result( + self, expr: LoopyCallResult, is_root: bool) -> dict[Array, int]: + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, is_root: bool) -> dict[Array, int]: + # Ignore expr.send.data; it's not part of the computation being performed + return self.combine(self.rec(expr.passthrough_data, False)) + + @override + def map_distributed_recv( + self, expr: DistributedRecv, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_call(self, expr: Call, is_root: bool) -> dict[Array, int]: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_named_call_result( + self, expr: NamedCallResult, is_root: bool) -> dict[Array, int]: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + +@dataclass +class UnmaterializedNodeFlopCounts: + """ + Floating point operation counts for an unmaterialized node. See + :func:`get_unmaterialized_node_flop_counts` for details. + """ + materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] + nflops_if_materialized: ArrayOrScalar + + +class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the accumulated number of floating point operations that each + unmaterialized expression contributes to materialized expressions in the DAG. + + .. note:: + + This mapper does not descend into functions. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array]) -> None: + super().__init__() + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_nodes: frozenset[Array] = materialized_nodes + self.unmaterialized_node_to_flop_counts: \ + dict[Array, UnmaterializedNodeFlopCounts] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops, self.materialized_nodes) + self._use_counter: _UnmaterializedSubexpressionUseCounter = \ + _UnmaterializedSubexpressionUseCounter( + self.materialized_nodes) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_loopy_call(self, expr: LoopyCall) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + def map_loopy_call_result(self, expr: LoopyCallResult) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if expr not in self.materialized_nodes: + return + assert isinstance(expr, Array) + try: + self._per_entry_flop_counter(expr) + except _NonIntegralPerEntryFlopCountError as e: + if isinstance(expr, CSRMatmul): + # _PerEntryFlopCounter chokes on CSRMatmul because the row reduction + # bounds are data dependent. By handling it here instead we can take + # advantage of knowing that the total number of reduction iterations + # is len(elem_values). Note: assumes no flops for elem_col_indices + # and row_starts as they are integer-valued. + nelems = expr.matrix.elem_values.shape[0] + subexpr_to_nuses_per_elem = self._use_counter.combine( + self._use_counter(expr.matrix.elem_values, is_root=False), + self._use_counter(expr.array, is_root=False)) + for subexpr, nuses_per_elem in subexpr_to_nuses_per_elem.items(): + try: + nflops_per_entry = self._per_entry_flop_counter( + subexpr, is_root=False) + except _NonIntegralPerEntryFlopCountError: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(subexpr).__name__}'; per-entry flop count is " + "unexpectedly non-integer-valued.") from e + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = ( + nflops_per_entry * product(subexpr.shape)) + flop_counts = UnmaterializedNodeFlopCounts( + {}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in \ + flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses_per_elem + * nelems + * nflops_per_entry + * product(expr.shape[1:])) + else: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e + else: + subexpr_to_nuses = self._use_counter(expr) + for subexpr, nuses in subexpr_to_nuses.items(): + nflops_per_entry = self._per_entry_flop_counter( + subexpr, is_root=False) + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = nflops_per_entry * product(subexpr.shape) + flop_counts = UnmaterializedNodeFlopCounts( + {}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses + * nflops_per_entry + * product(expr.shape)) + + +def get_default_op_name_to_num_flops() -> dict[str, int]: + """ + Returns a mapping from operator name to floating point operation count for + operators that are almost always a single flop. + """ + return { + "+": 1, + "*": 1, + "==": 1, + "!=": 1, + "<": 1, + ">": 1, + "<=": 1, + ">=": 1, + "min": 1, + "max": 1} + + +def get_num_flops( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> ArrayOrScalar: + """ + Count the total number of floating point operations in the DAG *expr*. + + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. The total flop count is the sum + over all materialized nodes. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + + .. note:: + + Does not support functions. Function calls must be inlined before calling. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + + materialized_nodes = collect_materialized_nodes(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops, materialized_nodes) + fc(expr) + + return sum(fc.materialized_node_to_nflops.values()) + + +def get_materialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, ArrayOrScalar]: + """ + Returns a dictionary mapping materialized nodes in DAG *expr* to their floating + point operation count. + + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + + .. note:: + + Does not support functions. Function calls must be inlined before calling. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + + materialized_nodes = collect_materialized_nodes(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops, materialized_nodes) + fc(expr) + + return fc.materialized_node_to_nflops + + +def get_unmaterialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, UnmaterializedNodeFlopCounts]: + """ + Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a + :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count + information. + + The :class:`UnmaterializedNodeFlopCounts` instance for each unmaterialized node + (i.e., a node that can be tagged with :class:`pytato.tags.ImplStored` but isn't) + contains `materialized_successor_to_contrib_nflops` and `nflops_if_materialized` + attributes. The former is a mapping from each materialized successor of the + unmaterialized node to the number of flops the node contributes to evaluating + that successor (this includes flops from the predecessors of the unmaterialized + node). The latter is the number of flops that would be required to evaluate the + unmaterialized node if it was materialized instead. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + + .. note:: + + Does not support functions. Function calls must be inlined before calling. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + + materialized_nodes = collect_materialized_nodes(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops, materialized_nodes) + fc(expr) + + return fc.unmaterialized_node_to_flop_counts + +# }}} + + # vim: fdm=marker diff --git a/pytato/reductions.py b/pytato/reductions.py index 6efa45ac2..ab8f265aa 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -34,6 +34,7 @@ import numpy as np from constantdict import constantdict +from typing_extensions import override import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -80,22 +81,27 @@ class _NoValue: class ReductionOperation(ABC): """ + .. automethod:: scalar_op_name .. automethod:: neutral_element .. automethod:: __hash__ .. automethod:: __eq__ """ + @classmethod + @abstractmethod + def scalar_op_name(cls) -> str: + ... @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: - pass + ... @abstractmethod def __hash__(self) -> int: - pass + ... @abstractmethod def __eq__(self, other: object) -> bool: - pass + ... class _StatelessReductionOperation(ReductionOperation): @@ -110,16 +116,31 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "+" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 0 class ProductReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "*" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 1 class MaxReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "max" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("-inf")) @@ -130,6 +151,11 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "min" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("inf")) @@ -140,11 +166,21 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "and" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) class AnyReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "or" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index f99c2280c..5f53d9e8e 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -44,7 +44,9 @@ import re +from collections import defaultdict from collections.abc import Iterable, Mapping, Set as AbstractSet +from functools import reduce from typing import ( TYPE_CHECKING, Any, @@ -60,6 +62,7 @@ from loopy.symbolic import guarded_pwaff_from_expr from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass from pymbolic.mapper import ( + Collector, CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, P, @@ -73,9 +76,11 @@ ) from pymbolic.mapper.distributor import DistributeMapper as DistributeMapperBase from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase +from pymbolic.mapper.flop_counter import FlopCounterBase from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer +from pytools import product if TYPE_CHECKING: @@ -244,6 +249,203 @@ def map_type_cast(self, inner_str = self.rec(expr.inner_expr, PREC_NONE, *args, **kwargs) return f"cast({expr.dtype}, {inner_str})" + +class InputGatherer(Collector[str, []]): + @override + def map_variable(self, expr: prim.Variable) -> set[str]: + return {expr.name} + + +class InputUseCounter(CombineMapper[dict[str, ArithmeticExpression], []]): + @override + def combine( + self, values: Iterable[dict[str, ArithmeticExpression]] + ) -> dict[str, ArithmeticExpression]: + result: dict[str, ArithmeticExpression] = defaultdict(int) + for val in values: + for name, nuses in val.items(): + result[name] += nuses + return result + + @override + def map_constant(self, expr: object) -> dict[str, ArithmeticExpression]: + return {} + + @override + def map_variable(self, expr: prim.Variable) -> dict[str, ArithmeticExpression]: + return {expr.name: 1} + + @override + def map_reduce(self, expr: Reduce) -> dict[str, ArithmeticExpression]: + inner_expr_result = self.rec(expr.inner_expr) + niters = product(( + upper_bd - lower_bd + for lower_bd, upper_bd in expr.bounds.values())) + return { + name: niters * inner_expr_nuses + for name, inner_expr_nuses in inner_expr_result.items()} + + +class FlopCounter(FlopCounterBase): + op_name_to_num_flops: dict[str, ArithmeticExpression] + + def __init__( + self, + op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): + super().__init__() + if op_name_to_num_flops: + self.op_name_to_num_flops = dict(op_name_to_num_flops) + else: + self.op_name_to_num_flops = {} + + def _get_op_nflops(self, name: str) -> ArithmeticExpression: + try: + return self.op_name_to_num_flops[name] + except KeyError: + result = OpFlops(name) + self.op_name_to_num_flops[name] = result + return result + + @override + def map_call(self, expr: prim.Call) -> ArithmeticExpression: + assert isinstance(expr.function, prim.Variable) + return ( + self._get_op_nflops(expr.function.name) + + sum(self.rec(child) for child in expr.parameters)) + + @override + def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: + # Assume index calculations are performed on non-floats + return self.rec(expr.aggregate) + + @override + def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + if expr.children: + return ( + self._get_op_nflops("+") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_product(self, expr: prim.Product) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("*") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_quotient(self, expr: prim.Quotient) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + return ( + self._get_op_nflops("/") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + return ( + self._get_op_nflops("//") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_remainder(self, expr: prim.Remainder) -> ArithmeticExpression: + return ( + self._get_op_nflops("%") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_power(self, expr: prim.Power) -> ArithmeticExpression: + if isinstance(expr.exponent, int): + # The calculation below is based on the following code (which is an + # approximation of what is done in loopy) + # def pow(x, n): + # if n == 0: return 1 + # if n == 1: return x + # if n == 2: return x*x + # if n < 0: + # x = 1/x + # n = -n + # y = 1 + # while n > 1: + # if n % 2: + # y = x * y + # x = x * x + # n = n/2 + # return x*y + if expr.exponent == 0: + return 0 + elif expr.exponent > 0 and expr.exponent <= 2: + return ( + (expr.exponent - 1) * self._get_op_nflops("*") + + self.rec(expr.base)) + nmults = 1 + remaining_exp = abs(expr.exponent) + while remaining_exp > 1: + if remaining_exp % 2: + nmults += 1 + nmults += 1 + remaining_exp //= 2 + nflops = ( + nmults * self._get_op_nflops("*") + + self.rec(expr.base)) + if expr.exponent < 0: + nflops += self._get_op_nflops("/") + return nflops + else: + return ( + self._get_op_nflops("**") + + self.rec(expr.base) + + self.rec(expr.exponent)) + + @override + def map_comparison(self, expr: prim.Comparison) -> ArithmeticExpression: + return ( + self._get_op_nflops(expr.operator) + + self.rec(expr.left) + + self.rec(expr.right)) + + @override + def map_if(self, expr: prim.If) -> ArithmeticExpression: + return ( + self.rec(expr.condition) + + self.rec(expr.then) + + self.rec(expr.else_)) + + @override + def map_max(self, expr: prim.Max) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("max") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_min(self, expr: prim.Min) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("min") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_nan(self, expr: prim.NaN) -> ArithmeticExpression: + return 0 + + def map_reduce(self, expr: Reduce) -> ArithmeticExpression: + result = self.rec(expr.inner_expr) + nflops_op = self._get_op_nflops(expr.op.scalar_op_name()) + for lower_bd, upper_bd in expr.bounds.values(): + nops = upper_bd - lower_bd + result = result * nops + nflops_op * (nops-1) + + return result + # }}} @@ -346,9 +548,42 @@ class TypeCast(ExpressionBase): dtype: np.dtype[Any] inner_expr: ScalarExpression + +@expr_dataclass() +class OpFlops(prim.AlgebraicLeaf): + """ + Placeholder flop count for an operator. + + .. autoattribute:: op + """ + op: str + # }}} +class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]): + """ + Constructs a :class:`frozenset` containing all instances of + :class:`pytato.scalar_expr.OpFlops` found in a scalar expression. + """ + @override + def combine( + self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]: + return reduce( + lambda x, y: x.union(y), + values, + cast("frozenset[OpFlops]", frozenset())) + + @override + def map_algebraic_leaf( + self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]: + return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset() + + @override + def map_constant(self, expr: object) -> frozenset[OpFlops]: + return frozenset() + + class InductionVariableCollector(CombineMapper[AbstractSet[str], []]): def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]: from functools import reduce diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 3537b0486..59d0f0026 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -2,6 +2,7 @@ .. currentmodule:: pytato.transform.lower_to_index_lambda .. autofunction:: to_index_lambda +.. autoclass:: MapAsIndexLambdaMixin """ from __future__ import annotations @@ -28,8 +29,9 @@ THE SOFTWARE. """ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from constantdict import constantdict from typing_extensions import Never @@ -65,6 +67,8 @@ from pytato.tags import AssumeNonNegative from pytato.transform import ( Mapper, + P, + ResultT, _verify_is_array, ) from pytato.utils import normalized_slice_does_not_change_axis @@ -772,4 +776,64 @@ def to_index_lambda(expr: Array) -> IndexLambda: assert isinstance(res, IndexLambda) return res + +class MapAsIndexLambdaMixin(ABC, Generic[ResultT, P]): + """ + Mixin that, where possible, lowers arrays to :class:`~pytato.array.IndexLambda` + and calls :meth:`map_as_index_lambda` on them. + + .. automethod:: map_as_index_lambda + """ + @abstractmethod + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + """ + Map *expr* via its :class:`~pytato.array.IndexLambda` representation + *idx_lambda*. + """ + + def map_index_lambda( + self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, expr, *args, **kwargs) + + def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_axis_permutation( + self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_basic_index( + self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_contiguous_advanced_index( + self, expr: AdvancedIndexInContiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_non_contiguous_advanced_index( + self, expr: AdvancedIndexInNoncontiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_concatenate( + self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + # vim:fdm=marker diff --git a/pytato/utils.py b/pytato/utils.py index 19dc08ef7..d7bc8eade 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -735,4 +735,6 @@ def get_einsum_specification(expr: Einsum) -> str: for i in range(expr.ndim)) return f"{','.join(input_specs)}->{output_spec}" + + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 9d07d5ccd..9932ec84e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -820,6 +820,403 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_scalar_flop_count(): + from pytato.scalar_expr import FlopCounter + fc = FlopCounter({ + "+": 1, + "*": 1, + "/": 4, + "//": 4, + "%": 4, + "**": 8, + "<": 1, + "min": 1, + "max": 1, + "f": 32}) + + import pymbolic.primitives as prim + from pymbolic import Variable + + x = 2*Variable("x") + y = 3 + Variable("y") + + assert fc(x) == 1 + assert fc(y) == 1 + + assert fc(Variable("f")(x)) == 32 + 1 + + assert fc(x[0]) == 0 + 1 + + assert fc(x + 2) == 1 + 1 + assert fc(2 + y) == 1 + 1 + assert fc(x + y) == 1 + 2 + + assert fc(prim.Sum((2, x, y))) == 2 + 2 + + assert fc(x - 2) == 1 + 1 + assert fc(2 - y) == 2 + 1 + assert fc(x - y) == 2 + 2 + + assert fc(x * 2) == 1 + 1 + assert fc(2 * y) == 1 + 1 + assert fc(x * y) == 1 + 2 + + assert fc(prim.Product((2, x, y))) == 2 + 2 + + assert fc(x.or_(y)) == 0 + 2 + assert fc(x.and_(y)) == 0 + 2 + + assert fc(x / 2) == 4 + 1 + assert fc(2 / y) == 4 + 1 + assert fc(x / y) == 4 + 2 + + assert fc(x // 2) == 4 + 1 + + assert fc(x % 2) == 4 + 1 + + assert fc(x ** 0) == 0 + assert fc(x ** 1) == 0 + 1 + # x * x + assert fc(x ** 2) == 1 + 1 + # compute x^2, x^4, x^8, x^16, x^32; multiply x^32 * x^16 * x^8 * x^4 * x * 1 + assert fc(x ** 61) == 5 + 5 + 1 + # divide; compute x^2, x^4, x^8, x^16; multiply x^16 * x^4 * x^2 * 1 + assert fc(x ** -22) == 4 + 4 + 3 + 1 + assert fc(x ** 0.3) == 8 + 1 + assert fc(x ** y) == 8 + 2 + + assert fc(x.lt(y)) == 1 + 2 + + assert fc(prim.If(x, x, y)) == 0 + 3 + + assert fc(prim.Min((2, x, y))) == 2 + 2 + assert fc(prim.Max((2, x, y))) == 2 + 2 + + from constantdict import constantdict + + from pytato.reductions import SumReductionOperation + from pytato.scalar_expr import Reduce + + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) \ + == 9 + 10 + + +def test_flop_count(): + from pytato.analysis import ( + UndefinedOpFlopCountError, + get_default_op_name_to_num_flops, + get_num_flops, + ) + from pytato.tags import ImplStored + + # {{{ basic expression + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = u - v + + # expr[i, j] = 2*(x[i, j] + y[i, j]) + (-1)*3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*6 + + # }}} + + # {{{ expression with operators that don't have default flop counts + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + expr = pt.cmath.exp(x / y) + + with pytest.raises(UndefinedOpFlopCountError): + get_num_flops(expr) + + op_name_to_num_flops = get_default_op_name_to_num_flops() + op_name_to_num_flops.update({ + "/": 4, + "pytato.c99.exp": 8}) + + assert get_num_flops(expr, op_name_to_num_flops) == 40*12 + + # }}} + + # {{{ multiple expressions + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = pt.make_dict_of_named_arrays({"u": u, "v": v}) + + # expr["u"][i, j] = 2*(x[i, j] + y[i, j]) + # expr["v"][i, j] = 3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*4 + + # }}} + + # {{{ subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # expr[i, j] = 2*(x[2*i, j] + y[2*i, j]) + (-1)*3*(x[2*i, j] + y[2*i, j]) + assert get_num_flops(expr) == 20*6 + + # }}} + + # {{{ materialized array + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert get_num_flops(expr) == 40 + 40*4 + + # }}} + + # {{{ materialized array and subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[2*i, j] + (-1)*3*z[2*i, j] + assert get_num_flops(expr) == 40 + 20*4 + + # }}} + + # {{{ einsum + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->ijk", 2*x, 3*y) + + # expr[i, j, k] = 2*x[i, j, k] * 3*y[j, k] + assert get_num_flops(expr) == 72 + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->i", 2*x, 3*y) + + # expr[i] = sum(sum(2*x[i, j, k] * 3*y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*3 + 2) + 3) + + # }}} + + # {{{ CSR matmul (trivial predecessors) + + x = pt.make_csr_matrix( + shape=(8, 10), + elem_values=pt.make_placeholder("x_elem_values", (16,)), + elem_col_indices=pt.make_placeholder("x_elem_col_indices", (16,)), + row_starts=pt.make_placeholder("x_row_starts", (9,))) + y = pt.make_placeholder("y", (10, 5, 3)) + expr = x @ y + + assert get_num_flops(expr) == 5*3*( + 16 # multiplies + + 16 - 8 # adds + ) + + # }}} + + # {{{ CSR matmul (nontrivial predecessors) + + elem_values = pt.zeros(12) + 1 + elem_col_indices = pt.make_data_wrapper(np.array([ + 0, + 0, 1, + 0, 1, 2, + 1, 2, 3, + 2, 3, + 3])) + row_starts = pt.make_data_wrapper(np.array([0, 1, 3, 6, 9, 11, 12])) + x = pt.make_csr_matrix( + shape=(6, 4), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = pt.zeros((4, 3, 2)) + 1 + expr = x @ y + + assert get_num_flops(expr) == 3*2*( + 3 # row 1 + + 3*2 + 1 # row 2 + + 3*3 + 2 # row 3 + + 3*3 + 2 # row 4 + + 3*2 + 1 # row 5 + + 3 # row 6 + ) + + # }}} + + +def test_materialized_node_flop_counts(): + from pytato.analysis import get_materialized_node_flop_counts + from pytato.tags import ImplStored + + # {{{ basic DAG + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert len(materialized_node_to_flop_count) == 4 + assert x in materialized_node_to_flop_count + assert y in materialized_node_to_flop_count + assert z in materialized_node_to_flop_count + assert expr in materialized_node_to_flop_count + assert materialized_node_to_flop_count[x] == 0 + assert materialized_node_to_flop_count[y] == 0 + assert materialized_node_to_flop_count[z] == 40 + assert materialized_node_to_flop_count[expr] == 40*4 + + # }}} + + # {{{ CSR matmul + + zeros_10 = pt.make_data_wrapper(np.zeros(10)) + zeros_20 = pt.make_data_wrapper(np.zeros(20)) + elem_values = zeros_20 + 1 + elem_col_indices = pt.make_data_wrapper((1 + np.arange(0, 20)) // 2) + row_starts = pt.make_data_wrapper(2*np.arange(0, 11)) + x = pt.make_csr_matrix( + shape=(10, 10), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = zeros_10 + 1 + z = x @ y + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + assert len(materialized_node_to_flop_count) == 6 + assert zeros_20 in materialized_node_to_flop_count + assert elem_col_indices in materialized_node_to_flop_count + assert row_starts in materialized_node_to_flop_count + assert zeros_10 in materialized_node_to_flop_count + assert z in materialized_node_to_flop_count + assert expr in materialized_node_to_flop_count + assert materialized_node_to_flop_count[zeros_20] == 0 + assert materialized_node_to_flop_count[elem_col_indices] == 0 + assert materialized_node_to_flop_count[row_starts] == 0 + assert materialized_node_to_flop_count[zeros_10] == 0 + # flops from elem_values/y/z (2 elems per row so y gets used twice) + assert materialized_node_to_flop_count[z] == 20 + 20 + 20 + 10 + # flops from u/v/expr (no flops from z because it's materialized) + assert materialized_node_to_flop_count[expr] == 40 + + # }}} + + +def test_unmaterialized_node_flop_counts(): + from pytato.analysis import get_unmaterialized_node_flop_counts + + # {{{ basic DAG + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + # Make a reduction over a bunch of expressions that reference z + z = x + y + w = [i*z for i in range(1, 11)] + s = [w[0]] + for w_i in w[1:-1]: + s.append(s[-1] + w_i) + expr = s[-1] + w[-1] + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + # Everything except x/y/expr stays unmaterialized + assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 + assert z in unmaterialized_node_to_flop_counts + assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) + assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) + flop_counts = unmaterialized_node_to_flop_counts[z] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 40*10 + assert flop_counts.nflops_if_materialized == 40 + for w_i in w: + flop_counts = unmaterialized_node_to_flop_counts[w_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 40*2 + assert flop_counts.nflops_if_materialized == 40*2 + for i, s_i in enumerate(s): + flop_counts = unmaterialized_node_to_flop_counts[s_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == \ + 40*2*(i+1) + 40*i + assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + + # }}} + + # {{{ CSR matmul + + elem_values = pt.make_data_wrapper(np.zeros(20)) + 1 + elem_col_indices = pt.make_data_wrapper((1 + np.arange(0, 20)) // 2) + row_starts = pt.make_data_wrapper(2*np.arange(0, 11)) + x = pt.make_csr_matrix( + shape=(10, 10), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = pt.make_data_wrapper(np.zeros(10)) + 1 + expr = x @ y + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + assert len(unmaterialized_node_to_flop_counts) == 2 + assert elem_values in unmaterialized_node_to_flop_counts + assert y in unmaterialized_node_to_flop_counts + flop_counts = unmaterialized_node_to_flop_counts[elem_values] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 20 + assert flop_counts.nflops_if_materialized == 20 + flop_counts = unmaterialized_node_to_flop_counts[y] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 20 + assert flop_counts.nflops_if_materialized == 10 + + # }}} + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4))