From 00e2e003cc5d4be150e72d4d82fb6c3050130819 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 24 Oct 2025 15:51:27 -0500 Subject: [PATCH 01/32] add flop counting functions --- pytato/analysis/__init__.py | 360 +++++++++++++++++++++++++++++++++++- pytato/reductions.py | 36 ++++ pytato/scalar_expr.py | 136 ++++++++++++++ pytato/transform/calls.py | 94 +++++++++- pytato/utils.py | 21 ++- test/test_pytato.py | 281 ++++++++++++++++++++++++++++ 6 files changed, 924 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6de9e5de4..d4ca82aa8 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,16 +27,19 @@ """ from collections import defaultdict -from typing import TYPE_CHECKING, Any, overload +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper +from pytools import product from pytato.array import ( Array, + ArrayOrScalar, Concatenate, CSRMatmul, DictOfNamedArrays, @@ -50,13 +53,21 @@ Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.scalar_expr import ( + FlopCounter as ScalarFlopCounter, +) +from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, + ArrayOrNamesTc, CachedWalkMapper, CombineMapper, Mapper, VisitKeyT, + map_and_copy, ) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import is_materializable if TYPE_CHECKING: @@ -88,6 +99,13 @@ .. autoclass:: TagCountMapper .. autofunction:: get_num_tags_of_type + +.. autoclass:: UndefinedOpFlopCountError +.. autofunction:: get_default_op_name_to_num_flops +.. autofunction:: get_num_flops +.. autofunction:: get_materialized_node_flop_counts +.. autoclass:: UnmaterializedNodeFlopCounts +.. autofunction:: get_unmaterialized_node_flop_counts """ @@ -778,4 +796,344 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # }}} + +# {{{ flop counting + +def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + return ( + is_materializable(expr) + and bool(expr.tags_of_type(ImplStored))) + + +def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + return ( + is_materializable(expr) + and not bool(expr.tags_of_type(ImplStored))) + + +@dataclass +class UndefinedOpFlopCountError(ValueError): + op_name: str + + +class _PerEntryFlopCounter(CombineMapper[int, Never, []]): + def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() + self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( + op_name_to_num_flops) + self.node_to_nflops: dict[Array, int] = {} + + @override + def combine(self, *args: int) -> int: + return sum(args) + + @override + def rec(self, expr: ArrayOrNames) -> int: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: int + if _is_unmaterialized(expr): + assert isinstance(expr, Array) + self_nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(self_nflops, int): + from pytato.scalar_expr import InputGatherer as ScalarInputGatherer + var_names: set[str] = set(ScalarInputGatherer()(self_nflops)) + var_names.discard("nflops") + if var_names: + raise UndefinedOpFlopCountError(next(iter(var_names))) from None + else: + raise AssertionError from None + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self_nflops + cast("int", Mapper.rec(self, expr)) + else: + result = 0 + if isinstance(expr, Array): + self.node_to_nflops[expr] = result + return self._cache_add(inputs, result) + + +class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the number of floating point operations of each materialized + expression in a DAG. + + .. note:: + + Flops from nodes inside function calls are accumulated onto the corresponding + call node. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + _visited_functions: set[VisitKeyT] | None = None, + _function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] | None = None + ) -> None: + super().__init__(_visited_functions=_visited_functions) + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} + self.call_to_nflops: dict[Call, ArrayOrScalar] = {} + self._function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] = \ + _function_to_nflops if _function_to_nflops is not None else {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + op_name_to_num_flops=self.op_name_to_num_flops, + _visited_functions=self._visited_functions, + _function_to_nflops=self._function_to_nflops) + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for subexpr in expr.returns.values(): + # Assume that any calls that haven't been inlined have their functions' + # outputs materialized + assert not _is_unmaterialized(subexpr) + new_mapper(subexpr) + + self._function_to_nflops[expr] = ( + sum(new_mapper.materialized_node_to_nflops.values()) + + sum(new_mapper.call_to_nflops.values())) + + self.post_visit(expr) + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + self.rec_function_definition(expr.function) + for bnd in expr.bindings.values(): + # Assume that any calls that haven't been inlined have their inputs + # materialized + assert not _is_unmaterialized(bnd) + self.rec(bnd) + + self.call_to_nflops[expr] = self._function_to_nflops[expr.function] + + self.post_visit(expr) + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not _is_materialized(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + self._per_entry_flop_counter(unmaterialized_expr) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr]) + + +class _UnmaterializedSubexpressionUseCounter( + CombineMapper[dict[Array, int], Never, []]): + @override + def combine(self, *args: dict[Array, int]) -> dict[Array, int]: + result: dict[Array, int] = defaultdict(int) + for arg in args: + for ary, nuses in arg.items(): + result[ary] += nuses + return result + + @override + def rec(self, expr: ArrayOrNames) -> dict[Array, int]: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: dict[Array, int] + if _is_unmaterialized(expr): + assert isinstance(expr, Array) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self.combine( + {expr: 1}, cast("dict[Array, int]", Mapper.rec(self, expr))) + else: + result = {} + return self._cache_add(inputs, result) + + +@dataclass +class UnmaterializedNodeFlopCounts: + materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] + nflops_if_materialized: ArrayOrScalar + + +class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the accumulated number of floating point operations that each + unmaterialized expression contributes to materialized expressions in the DAG. + + .. note:: + + This mapper does not descend into functions. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + _visited_functions: set[VisitKeyT] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.unmaterialized_node_to_flop_counts: \ + dict[Array, UnmaterializedNodeFlopCounts] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not _is_materialized(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + subexpr_to_nuses = _UnmaterializedSubexpressionUseCounter()( + unmaterialized_expr) + del subexpr_to_nuses[unmaterialized_expr] + self._per_entry_flop_counter(unmaterialized_expr) + for subexpr, nuses in subexpr_to_nuses.items(): + per_entry_nflops = self._per_entry_flop_counter.node_to_nflops[subexpr] + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = product(subexpr.shape) * per_entry_nflops + flop_counts = UnmaterializedNodeFlopCounts({}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses * product(expr.shape) * per_entry_nflops) + + +# FIXME: Should this be added to normalize_outputs? +def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + # Make sure call bindings and function results are materialized + from pytato.transform.calls import normalize_calls + expr = normalize_calls(expr) + + # Make sure outputs are materialized + if isinstance(expr, DictOfNamedArrays): + output_to_materialized_output: dict[Array, Array] = { + ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary + for ary in expr._data.values()} + + def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: + if not isinstance(ary, Array): + return ary + try: + return output_to_materialized_output[ary] + except KeyError: + return ary + + expr = map_and_copy(expr, replace_with_materialized) + + return expr + + +def get_default_op_name_to_num_flops() -> dict[str, int]: + """ + Returns a mapping from operator name to floating point operation count for + operators that are almost always a single flop. + """ + return { + "+": 1, + "*": 1, + "==": 1, + "!=": 1, + "<": 1, + ">": 1, + "<=": 1, + ">=": 1, + "min": 1, + "max": 1} + + +def get_num_flops( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> ArrayOrScalar: + """Count the total number of floating point operations in the DAG *expr*.""" + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return ( + sum(fc.materialized_node_to_nflops.values()) + + sum(fc.call_to_nflops.values())) + + +def get_materialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, ArrayOrScalar]: + """ + Returns a dictionary mapping materialized nodes in DAG *expr* to their floating + point operation count. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.materialized_node_to_nflops + + +def get_unmaterialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, UnmaterializedNodeFlopCounts]: + """ + Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a + :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count + information. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.unmaterialized_node_to_flop_counts + +# }}} + + # vim: fdm=marker diff --git a/pytato/reductions.py b/pytato/reductions.py index 6efa45ac2..a35c1d98a 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -34,6 +34,7 @@ import numpy as np from constantdict import constantdict +from typing_extensions import override import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -80,10 +81,15 @@ class _NoValue: class ReductionOperation(ABC): """ + .. automethod:: scalar_op_name .. automethod:: neutral_element .. automethod:: __hash__ .. automethod:: __eq__ """ + @classmethod + @abstractmethod + def scalar_op_name(cls) -> str: + pass @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -110,16 +116,31 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "+" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 0 class ProductReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "*" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 1 class MaxReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "max" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("-inf")) @@ -130,6 +151,11 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "min" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("inf")) @@ -140,11 +166,21 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "or" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) class AnyReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "and" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index f99c2280c..c0a19b0bf 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -60,6 +60,7 @@ from loopy.symbolic import guarded_pwaff_from_expr from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass from pymbolic.mapper import ( + Collector, CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, P, @@ -73,6 +74,7 @@ ) from pymbolic.mapper.distributor import DistributeMapper as DistributeMapperBase from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase +from pymbolic.mapper.flop_counter import FlopCounterBase from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer @@ -244,6 +246,140 @@ def map_type_cast(self, inner_str = self.rec(expr.inner_expr, PREC_NONE, *args, **kwargs) return f"cast({expr.dtype}, {inner_str})" + +class InputGatherer(Collector[str, []]): + @override + def map_variable(self, expr: prim.Variable) -> set[str]: + return {expr.name} + + +class FlopCounter(FlopCounterBase): + def __init__( + self, + op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): + super().__init__() + self.op_name_to_num_flops: dict[str, ArithmeticExpression] + if op_name_to_num_flops: + self.op_name_to_num_flops = dict(op_name_to_num_flops) + else: + self.op_name_to_num_flops = {} + + def _get_op_nflops(self, name: str) -> ArithmeticExpression: + try: + return self.op_name_to_num_flops[name] + except KeyError: + from pymbolic import var + result = var("nflops")(var(name)) + self.op_name_to_num_flops[name] = result + return result + + @override + def map_call(self, expr: prim.Call) -> ArithmeticExpression: + assert isinstance(expr.function, prim.Variable) + return ( + self._get_op_nflops(expr.function.name) + + sum(self.rec(child) for child in expr.parameters)) + + @override + def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: + # Assume calculations inside subscripts are performed on non-floats + return 0 + + @override + def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + if expr.children: + return ( + self._get_op_nflops("+") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_product(self, expr: prim.Product) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("*") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_quotient(self, expr: prim.Quotient) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + return ( + self._get_op_nflops("/") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + return ( + self._get_op_nflops("//") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_power(self, expr: prim.Power) -> ArithmeticExpression: + if isinstance(expr.exponent, int): + if expr.exponent >= 0: + return ( + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("/") + + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("**") + + self.rec(expr.base) + + self.rec(expr.exponent)) + + @override + def map_comparison(self, expr: prim.Comparison) -> ArithmeticExpression: + return ( + self._get_op_nflops(expr.operator) + + self.rec(expr.left) + + self.rec(expr.right)) + + @override + def map_if(self, expr: prim.If) -> ArithmeticExpression: + return ( + self.rec(expr.condition) + + self.rec(expr.then) + + self.rec(expr.else_)) + + @override + def map_max(self, expr: prim.Max) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("max") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_min(self, expr: prim.Min) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("min") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_nan(self, expr: prim.NaN) -> ArithmeticExpression: + return 0 + + def map_reduce(self, expr: Reduce) -> ArithmeticExpression: + result = self.rec(expr.inner_expr) + nflops_op = self._get_op_nflops(expr.op.scalar_op_name()) + for lower_bd, upper_bd in expr.bounds.values(): + nops = upper_bd - lower_bd + result = result * nops + nflops_op * (nops-1) + + return result + # }}} diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index ffad1e5f7..e82c2b1e0 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,6 +1,7 @@ """ .. currentmodule:: pytato.transform.calls +.. autofunction:: normalize_calls .. autofunction:: inline_calls .. autofunction:: tag_all_calls_to_be_inlined """ @@ -32,20 +33,26 @@ from typing import TYPE_CHECKING, cast -from typing_extensions import Self +from immutabledict import immutabledict +from typing_extensions import Never, Self, override from pytato.array import ( AbstractResultWithNamedArrays, Array, + DataWrapper, DictOfNamedArrays, Placeholder, + SizeParam, + make_dict_of_named_arrays, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import InlineCallTag +from pytato.tags import ImplStored, InlineCallTag from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, + CombineMapper, CopyMapper, + Mapper, TransformMapperCache, _verify_is_array, deduplicate, @@ -56,6 +63,89 @@ from collections.abc import Mapping +# {{{ normalizing + +class _LocalStackCallBindingCollector(CombineMapper[frozenset[Array], Never, []]): + """Mapper to collect bindings of calls on the current call stack.""" + @override + def combine(self, *args: frozenset[Array]) -> frozenset[Array]: + from functools import reduce + return reduce(lambda a, b: a | b, args, cast("frozenset[Array]", frozenset())) + + @override + def map_call(self, expr: Call) -> frozenset[Array]: + return frozenset(expr.bindings.values()) + + +class _CallMaterializer(CopyMapper): + """Mapper to add materialization tags for call bindings and function results.""" + def __init__( + self, + local_call_bindings: frozenset[Array], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.local_call_bindings: frozenset[Array] = local_call_bindings + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + local_call_bindings = _LocalStackCallBindingCollector()( + make_dict_of_named_arrays(function.returns)) + return type(self)( + local_call_bindings, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + + def _materialize_if_possible(self, expr: ArrayOrNames) -> ArrayOrNames: + if ( + isinstance(expr, Array) + and not isinstance(expr, + (DataWrapper, Placeholder, SizeParam, NamedCallResult))): + return expr.tagged(ImplStored()) + else: + return expr + + @override + def map_function_definition(self, + expr: FunctionDefinition) -> FunctionDefinition: + new_mapper = self.clone_for_callee(expr) + new_returns: Mapping[str, Array] = immutabledict({ + name: self._materialize_if_possible(_verify_is_array(new_mapper(ret))) + for name, ret in expr.returns.items()}) + return expr.replace_if_different(returns=new_returns) + + @override + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = cast("ArrayOrNames", Mapper.rec(self, expr)) + if expr in self.local_call_bindings: + result = self._materialize_if_possible(result) + return self._cache_add(inputs, result) + + +def normalize_calls(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + """ + Ensure that calls/functions are defined uniformly. + + Adds any missing materialization tags for call bindings and function results. + """ + local_call_bindings = _LocalStackCallBindingCollector()(expr) + return _CallMaterializer(local_call_bindings)(expr) + +# }}} + + # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/pytato/utils.py b/pytato/utils.py index 19dc08ef7..92c1a1cbb 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import CachedMapper +from pytato.transform import ArrayOrNames, CachedMapper if TYPE_CHECKING: @@ -69,6 +69,8 @@ from pytools.tag import Tag + from pytato.function import FunctionDefinition + __doc__ = """ Helper routines @@ -80,6 +82,7 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str +.. autofunction:: is_materializable References ^^^^^^^^^^ @@ -735,4 +738,20 @@ def get_einsum_specification(expr: Einsum) -> str: for i in range(expr.ndim)) return f"{','.join(input_specs)}->{output_spec}" + + +def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool: + """ + Returns *True* if *expr* is an instance of an array type that can be materialized. + """ + from pytato.array import InputArgumentBase, NamedArray + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder + return ( + isinstance(expr, Array) + and not isinstance(expr, ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, NamedArray, DistributedRecv, + DistributedSendRefHolder))) + + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 9d07d5ccd..8189aafd7 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -820,6 +820,287 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_scalar_flop_count(): + from pytato.scalar_expr import FlopCounter + fc = FlopCounter({ + "+": 1, + "*": 1, + "/": 4, + "//": 4, + "%": 4, + "**": 8, + "<": 1, + "min": 1, + "max": 1, + "f": 32}) + + import pymbolic.primitives as prim + from pymbolic import Variable + + x = Variable("x") + y = Variable("y") + + assert fc(Variable("f")(x)) == 32 + + assert fc(x[0]) == 0 + + assert fc(x + 2) == 1 + assert fc(2 + y) == 1 + assert fc(x + y) == 1 + + assert fc(prim.Sum((2, x, y))) == 2 + + assert fc(x - 2) == 1 + assert fc(2 - y) == 2 + assert fc(x - y) == 2 + + assert fc(x * 2) == 1 + assert fc(2 * y) == 1 + assert fc(x * y) == 1 + + assert fc(prim.Product((2, x, y))) == 2 + + assert fc(x.or_(y)) == 0 + assert fc(x.and_(y)) == 0 + + assert fc(x / 2) == 4 + assert fc(2 / y) == 4 + assert fc(x / y) == 4 + + assert fc(x // 2) == 4 + + assert fc(x % 2) == 0 + + assert fc(x ** 3) == 3 + assert fc(x ** 0.3) == 8 + + assert fc(x.lt(y)) == 1 + + assert fc(prim.If(x, x, y)) == 0 + + assert fc(prim.Min((2, x, y))) == 2 + assert fc(prim.Max((2, x, y))) == 2 + + from constantdict import constantdict + + from pytato.reductions import SumReductionOperation + from pytato.scalar_expr import Reduce + + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) == 9 + + +def test_flop_count(): + from pytato.analysis import ( + UndefinedOpFlopCountError, + get_default_op_name_to_num_flops, + get_num_flops, + ) + from pytato.tags import ImplStored + + # {{{ basic expression + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = u - v + + # expr[i, j] = 2*(x[i, j] + y[i, j]) + (-1)*3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*6 + + # }}} + + # {{{ expression with operators that don't have default flop counts + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + expr = pt.cmath.exp(x / y) + + with pytest.raises(UndefinedOpFlopCountError): + get_num_flops(expr) + + op_name_to_num_flops = get_default_op_name_to_num_flops() + op_name_to_num_flops.update({ + "/": 4, + "pytato.c99.exp": 8}) + + assert get_num_flops(expr, op_name_to_num_flops) == 40*12 + + # }}} + + # {{{ multiple expressions + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = pt.make_dict_of_named_arrays({"u": u, "v": v}) + + # expr["u"][i, j] = 2*(x[i, j] + y[i, j]) + # expr["v"][i, j] = 3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*4 + + # }}} + + # {{{ subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # expr[i, j] = 2*(x[2*i, j] + y[2*i, j]) + (-1)*3*(x[2*i, j] + y[2*i, j]) + assert get_num_flops(expr) == 20*6 + + # }}} + + # {{{ materialized array + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert get_num_flops(expr) == 40 + 40*4 + + # }}} + + # {{{ materialized array and subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[2*i, j] + (-1)*3*z[2*i, j] + assert get_num_flops(expr) == 40 + 20*4 + + # }}} + + # {{{ function call + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + def f(x, y): + z = x + y + return 2*z, 3*z + + u, v = pt.trace_call(f, x, y) + expr = u - v + + # u[i, j] = 2*(x[i, j] + y[i, j]) + # v[i, j] = 3*(x[i, j] + y[i, j]) + # expr[i, j] = u[i, j] + (-1)*v[i, j] + assert get_num_flops(expr) == 40*2 + 40*2 + 40*2 + + # }}} + + # {{{ einsum + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->ijk", x, y) + + # expr[i, j, k] = x[i, j, k] * y[j, k] + assert get_num_flops(expr) == 24 + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->i", x, y) + + # expr[i] = sum(sum(x[i, j, k] * y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*1 + 2) + 3) + + # }}} + + +def test_materialized_node_flop_counts(): + from pytato.analysis import get_materialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert len(materialized_node_to_flop_count) == 2 + assert z in materialized_node_to_flop_count + assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[z] == 40 + assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 + + +def test_unmaterialized_node_flop_counts(): + from pytato.analysis import get_unmaterialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + # Make a reduction over a bunch of expressions that reference z + z = x + y + w = [i*z for i in range(1, 11)] + s = [w[0]] + for w_i in w[1:-1]: + s.append(s[-1] + w_i) + expr = s[-1] + w[-1] + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + materialized_expr = expr.tagged(ImplStored()) + + # Everything except expr stays unmaterialized + assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 + assert z in unmaterialized_node_to_flop_counts + assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) + assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) + flop_counts = unmaterialized_node_to_flop_counts[z] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*10 + assert flop_counts.nflops_if_materialized == 40 + for w_i in w: + flop_counts = unmaterialized_node_to_flop_counts[w_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2 + assert flop_counts.nflops_if_materialized == 40*2 + for i, s_i in enumerate(s): + flop_counts = unmaterialized_node_to_flop_counts[s_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2*(i+1) + 40*i + assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) From 200f021d557422d271cc7602d80b11054524c785 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 27 Oct 2025 16:17:59 -0500 Subject: [PATCH 02/32] Update baseline --- .basedpyright/baseline.json | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 49915bc9f..7c557968b 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1936,6 +1936,22 @@ "endColumn": 25, "lineCount": 1 } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 51, + "endColumn": 61, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 56, + "endColumn": 66, + "lineCount": 1 + } } ], "./pytato/array.py": [ @@ -7467,6 +7483,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 42, + "endColumn": 52, + "lineCount": 1 + } + }, { "code": "reportUnannotatedClassAttribute", "range": { From ec27f300cffa4b311425e72c7d2ef4f5c865227c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Nov 2025 16:49:36 -0600 Subject: [PATCH 03/32] add note to docs about assumptions when handling conditional expressions --- pytato/analysis/__init__.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d4ca82aa8..d461ca369 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1075,7 +1075,15 @@ def get_num_flops( expr: ArrayOrNames, op_name_to_num_flops: Mapping[str, int] | None = None, ) -> ArrayOrScalar: - """Count the total number of floating point operations in the DAG *expr*.""" + """ + Count the total number of floating point operations in the DAG *expr*. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) expr = _normalize_materialization(expr) @@ -1098,6 +1106,12 @@ def get_materialized_node_flop_counts( """ Returns a dictionary mapping materialized nodes in DAG *expr* to their floating point operation count. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1120,6 +1134,12 @@ def get_unmaterialized_node_flop_counts( Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count information. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) From cf76cfbc07964231f13b3910808450ce1064c14a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 10:46:02 -0600 Subject: [PATCH 04/32] change 'pass' -> '...' in ReductionOperation abstract methods --- pytato/reductions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index a35c1d98a..a229328c8 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -89,19 +89,19 @@ class ReductionOperation(ABC): @classmethod @abstractmethod def scalar_op_name(cls) -> str: - pass + ... @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: - pass + ... @abstractmethod def __hash__(self) -> int: - pass + ... @abstractmethod def __eq__(self, other: object) -> bool: - pass + ... class _StatelessReductionOperation(ReductionOperation): From 729e4910ebb3dfeb7e3e9fcc21b3a719c8901967 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 10:52:14 -0600 Subject: [PATCH 05/32] move op_name_to_num_flops type declaration to FlopCounter class body --- pytato/scalar_expr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index c0a19b0bf..524d4e7a7 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -254,11 +254,12 @@ def map_variable(self, expr: prim.Variable) -> set[str]: class FlopCounter(FlopCounterBase): + op_name_to_num_flops: dict[str, ArithmeticExpression] + def __init__( self, op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): super().__init__() - self.op_name_to_num_flops: dict[str, ArithmeticExpression] if op_name_to_num_flops: self.op_name_to_num_flops = dict(op_name_to_num_flops) else: From dde27460ccadeec00b5cbe27071e2954c6de428c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 15:31:55 -0600 Subject: [PATCH 06/32] add some details about how flop counts are computed --- pytato/analysis/__init__.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d461ca369..f34454741 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -973,6 +973,10 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]: @dataclass class UnmaterializedNodeFlopCounts: + """ + Floating point operation counts for an unmaterialized node. See + :func:`get_unmaterialized_node_flop_counts` for details. + """ materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] nflops_if_materialized: ArrayOrScalar @@ -1078,6 +1082,11 @@ def get_num_flops( """ Count the total number of floating point operations in the DAG *expr*. + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. The total flop count is the sum + over all materialized nodes. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, @@ -1107,6 +1116,10 @@ def get_materialized_node_flop_counts( Returns a dictionary mapping materialized nodes in DAG *expr* to their floating point operation count. + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, @@ -1135,6 +1148,15 @@ def get_unmaterialized_node_flop_counts( :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count information. + The :class:`UnmaterializedNodeFlopCounts` instance for each unmaterialized node + (i.e., a node that can be tagged with :class:`pytato.tags.ImplStored` but isn't) + contains `materialized_successor_to_contrib_nflops` and `nflops_if_materialized` + attributes. The former is a mapping from each materialized successor of the + unmaterialized node to the number of flops the node contributes to evaluating + that successor (this includes flops from the predecessors of the unmaterialized + node). The latter is the number of flops that would be required to evaluate the + unmaterialized node if it was materialized instead. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, From 4e9e6ae32d795311c30ecba5fff047373d385092 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 15:33:20 -0600 Subject: [PATCH 07/32] clarify how flop counting functions behave w.r.t. DAG functions --- pytato/analysis/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f34454741..65596c2f4 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1092,6 +1092,10 @@ def get_num_flops( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1125,6 +1129,10 @@ def get_materialized_node_flop_counts( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does not* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1162,6 +1170,10 @@ def get_unmaterialized_node_flop_counts( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does not* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) From e015656a6fbf2102436ef956dd61a3d102167d56 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 8 Dec 2025 16:57:43 -0600 Subject: [PATCH 08/32] don't try to count flops for function calls --- .basedpyright/baseline.json | 8 ---- pytato/analysis/__init__.py | 81 ++++++++++++-------------------- pytato/transform/calls.py | 94 +------------------------------------ test/test_pytato.py | 19 -------- 4 files changed, 32 insertions(+), 170 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 7c557968b..c8f199d78 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -7483,14 +7483,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 42, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 65596c2f4..1658167d2 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -869,15 +869,10 @@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): def __init__( self, op_name_to_num_flops: Mapping[str, int], - _visited_functions: set[VisitKeyT] | None = None, - _function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] | None = None ) -> None: - super().__init__(_visited_functions=_visited_functions) + super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} - self.call_to_nflops: dict[Call, ArrayOrScalar] = {} - self._function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] = \ - _function_to_nflops if _function_to_nflops is not None else {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( self.op_name_to_num_flops) @@ -885,50 +880,25 @@ def __init__( def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: return expr - @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: - return expr - @override def clone_for_callee(self, function: FunctionDefinition) -> Self: - return type(self)( - op_name_to_num_flops=self.op_name_to_num_flops, - _visited_functions=self._visited_functions, - _function_to_nflops=self._function_to_nflops) + raise AssertionError("Control shouldn't reach this point.") @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return - new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - # Assume that any calls that haven't been inlined have their functions' - # outputs materialized - assert not _is_unmaterialized(subexpr) - new_mapper(subexpr) - - self._function_to_nflops[expr] = ( - sum(new_mapper.materialized_node_to_nflops.values()) - + sum(new_mapper.call_to_nflops.values())) - - self.post_visit(expr) + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def map_call(self, expr: Call) -> None: if not self.visit(expr): return - self.rec_function_definition(expr.function) - for bnd in expr.bindings.values(): - # Assume that any calls that haven't been inlined have their inputs - # materialized - assert not _is_unmaterialized(bnd) - self.rec(bnd) - - self.call_to_nflops[expr] = self._function_to_nflops[expr.function] - - self.post_visit(expr) + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: @@ -992,9 +962,8 @@ class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): """ def __init__( self, - op_name_to_num_flops: Mapping[str, int], - _visited_functions: set[VisitKeyT] | None = None) -> None: - super().__init__(_visited_functions=_visited_functions) + op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops self.unmaterialized_node_to_flop_counts: \ dict[Array, UnmaterializedNodeFlopCounts] = {} @@ -1006,8 +975,24 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: return expr @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: - return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: @@ -1034,10 +1019,6 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: # FIXME: Should this be added to normalize_outputs? def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - # Make sure call bindings and function results are materialized - from pytato.transform.calls import normalize_calls - expr = normalize_calls(expr) - # Make sure outputs are materialized if isinstance(expr, DictOfNamedArrays): output_to_materialized_output: dict[Array, Array] = { @@ -1095,7 +1076,7 @@ def get_num_flops( .. note:: - This *does* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1107,9 +1088,7 @@ def get_num_flops( fc = MaterializedNodeFlopCounter(op_name_to_num_flops) fc(expr) - return ( - sum(fc.materialized_node_to_nflops.values()) - + sum(fc.call_to_nflops.values())) + return sum(fc.materialized_node_to_nflops.values()) def get_materialized_node_flop_counts( @@ -1132,7 +1111,7 @@ def get_materialized_node_flop_counts( .. note:: - This *does not* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1173,7 +1152,7 @@ def get_unmaterialized_node_flop_counts( .. note:: - This *does not* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index e82c2b1e0..ffad1e5f7 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,7 +1,6 @@ """ .. currentmodule:: pytato.transform.calls -.. autofunction:: normalize_calls .. autofunction:: inline_calls .. autofunction:: tag_all_calls_to_be_inlined """ @@ -33,26 +32,20 @@ from typing import TYPE_CHECKING, cast -from immutabledict import immutabledict -from typing_extensions import Never, Self, override +from typing_extensions import Self from pytato.array import ( AbstractResultWithNamedArrays, Array, - DataWrapper, DictOfNamedArrays, Placeholder, - SizeParam, - make_dict_of_named_arrays, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import ImplStored, InlineCallTag +from pytato.tags import InlineCallTag from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, - CombineMapper, CopyMapper, - Mapper, TransformMapperCache, _verify_is_array, deduplicate, @@ -63,89 +56,6 @@ from collections.abc import Mapping -# {{{ normalizing - -class _LocalStackCallBindingCollector(CombineMapper[frozenset[Array], Never, []]): - """Mapper to collect bindings of calls on the current call stack.""" - @override - def combine(self, *args: frozenset[Array]) -> frozenset[Array]: - from functools import reduce - return reduce(lambda a, b: a | b, args, cast("frozenset[Array]", frozenset())) - - @override - def map_call(self, expr: Call) -> frozenset[Array]: - return frozenset(expr.bindings.values()) - - -class _CallMaterializer(CopyMapper): - """Mapper to add materialization tags for call bindings and function results.""" - def __init__( - self, - local_call_bindings: frozenset[Array], - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None - ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) - self.local_call_bindings: frozenset[Array] = local_call_bindings - - @override - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - local_call_bindings = _LocalStackCallBindingCollector()( - make_dict_of_named_arrays(function.returns)) - return type(self)( - local_call_bindings, - _function_cache=cast( - "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) - - def _materialize_if_possible(self, expr: ArrayOrNames) -> ArrayOrNames: - if ( - isinstance(expr, Array) - and not isinstance(expr, - (DataWrapper, Placeholder, SizeParam, NamedCallResult))): - return expr.tagged(ImplStored()) - else: - return expr - - @override - def map_function_definition(self, - expr: FunctionDefinition) -> FunctionDefinition: - new_mapper = self.clone_for_callee(expr) - new_returns: Mapping[str, Array] = immutabledict({ - name: self._materialize_if_possible(_verify_is_array(new_mapper(ret))) - for name, ret in expr.returns.items()}) - return expr.replace_if_different(returns=new_returns) - - @override - def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = cast("ArrayOrNames", Mapper.rec(self, expr)) - if expr in self.local_call_bindings: - result = self._materialize_if_possible(result) - return self._cache_add(inputs, result) - - -def normalize_calls(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - """ - Ensure that calls/functions are defined uniformly. - - Adds any missing materialization tags for call bindings and function results. - """ - local_call_bindings = _LocalStackCallBindingCollector()(expr) - return _CallMaterializer(local_call_bindings)(expr) - -# }}} - - # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/test/test_pytato.py b/test/test_pytato.py index 8189aafd7..082a983c0 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -994,25 +994,6 @@ def test_flop_count(): # }}} - # {{{ function call - - x = pt.make_placeholder("x", (10, 4)) - y = pt.make_placeholder("y", (10, 4)) - - def f(x, y): - z = x + y - return 2*z, 3*z - - u, v = pt.trace_call(f, x, y) - expr = u - v - - # u[i, j] = 2*(x[i, j] + y[i, j]) - # v[i, j] = 3*(x[i, j] + y[i, j]) - # expr[i, j] = u[i, j] + (-1)*v[i, j] - assert get_num_flops(expr) == 40*2 + 40*2 + 40*2 - - # }}} - # {{{ einsum x = pt.make_placeholder("x", (2, 3, 4)) From 3b0dea93cef8d921e9f843a48511d446690161a2 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Dec 2025 13:33:13 -0600 Subject: [PATCH 09/32] move own flop counting out of rec and into its own method in _PerEntryFlopCounter --- .basedpyright/baseline.json | 4 ++-- pytato/analysis/__init__.py | 31 ++++++++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index c8f199d78..55befdcac 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1940,8 +1940,8 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 51, - "endColumn": 61, + "startColumn": 34, + "endColumn": 44, "lineCount": 1 } }, diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1658167d2..5755c59d7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -827,6 +827,18 @@ def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: def combine(self, *args: int) -> int: return sum(args) + def _get_own_flop_count(self, expr: Array) -> int: + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(nflops, int): + from pytato.scalar_expr import InputGatherer as ScalarInputGatherer + var_names: set[str] = set(ScalarInputGatherer()(nflops)) + var_names.discard("nflops") + if var_names: + raise UndefinedOpFlopCountError(next(iter(var_names))) from None + else: + raise AssertionError from None + return nflops + @override def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) @@ -836,19 +848,12 @@ def rec(self, expr: ArrayOrNames) -> int: result: int if _is_unmaterialized(expr): assert isinstance(expr, Array) - self_nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - if not isinstance(self_nflops, int): - from pytato.scalar_expr import InputGatherer as ScalarInputGatherer - var_names: set[str] = set(ScalarInputGatherer()(self_nflops)) - var_names.discard("nflops") - if var_names: - raise UndefinedOpFlopCountError(next(iter(var_names))) from None - else: - raise AssertionError from None - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = self_nflops + cast("int", Mapper.rec(self, expr)) + result = ( + self._get_own_flop_count(expr) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + + cast("int", Mapper.rec(self, expr))) else: result = 0 if isinstance(expr, Array): From 58ec2172de85c45ac051b3721a790d177793db36 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Dec 2025 13:27:46 -0600 Subject: [PATCH 10/32] change is_materializable -> is_materialized / has_taggable_materialization --- pytato/analysis/__init__.py | 47 +++++++++++++++++-------------------- pytato/utils.py | 37 +++++++++++++++++++++++------ test/test_pytato.py | 6 ++++- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5755c59d7..8d38360ec 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -52,6 +52,7 @@ ShapeType, Stack, ) +from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.scalar_expr import ( FlopCounter as ScalarFlopCounter, @@ -67,7 +68,7 @@ map_and_copy, ) from pytato.transform.lower_to_index_lambda import to_index_lambda -from pytato.utils import is_materializable +from pytato.utils import has_taggable_materialization, is_materialized if TYPE_CHECKING: @@ -799,18 +800,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # {{{ flop counting -def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - return ( - is_materializable(expr) - and bool(expr.tags_of_type(ImplStored))) - - -def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - return ( - is_materializable(expr) - and not bool(expr.tags_of_type(ImplStored))) - - @dataclass class UndefinedOpFlopCountError(ValueError): op_name: str @@ -828,7 +817,10 @@ def combine(self, *args: int) -> int: return sum(args) def _get_own_flop_count(self, expr: Array) -> int: - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + try: + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + except CannotBeLoweredToIndexLambda: + nflops = 0 if not isinstance(nflops, int): from pytato.scalar_expr import InputGatherer as ScalarInputGatherer var_names: set[str] = set(ScalarInputGatherer()(nflops)) @@ -846,8 +838,7 @@ def rec(self, expr: ArrayOrNames) -> int: return self._cache_retrieve(inputs) except KeyError: result: int - if _is_unmaterialized(expr): - assert isinstance(expr, Array) + if isinstance(expr, Array) and not is_materialized(expr): result = ( self._get_own_flop_count(expr) # Intentionally going to Mapper instead of super() to avoid @@ -907,14 +898,16 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not _is_materialized(expr): + if not is_materialized(expr): return assert isinstance(expr, Array) - unmaterialized_expr = expr.without_tags(ImplStored()) - self._per_entry_flop_counter(unmaterialized_expr) - self.materialized_node_to_nflops[expr] = ( - product(expr.shape) - * self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr]) + if has_taggable_materialization(expr): + unmaterialized_expr = expr.without_tags(ImplStored()) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter(unmaterialized_expr)) + else: + self.materialized_node_to_nflops[expr] = 0 class _UnmaterializedSubexpressionUseCounter( @@ -934,8 +927,7 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]: return self._cache_retrieve(inputs) except KeyError: result: dict[Array, int] - if _is_unmaterialized(expr): - assert isinstance(expr, Array) + if isinstance(expr, Array) and not is_materialized(expr): # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 @@ -1001,7 +993,7 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not _is_materialized(expr): + if not is_materialized(expr) or not has_taggable_materialization(expr): return assert isinstance(expr, Array) unmaterialized_expr = expr.without_tags(ImplStored()) @@ -1027,7 +1019,10 @@ def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: # Make sure outputs are materialized if isinstance(expr, DictOfNamedArrays): output_to_materialized_output: dict[Array, Array] = { - ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary + ary: ( + ary.tagged(ImplStored()) + if has_taggable_materialization(ary) + else ary) for ary in expr._data.values()} def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: diff --git a/pytato/utils.py b/pytato/utils.py index 92c1a1cbb..d5d377966 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -82,7 +82,8 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str -.. autofunction:: is_materializable +.. autofunction:: is_materialized +.. autofunction:: has_taggable_materialization References ^^^^^^^^^^ @@ -740,18 +741,40 @@ def get_einsum_specification(expr: Einsum) -> str: return f"{','.join(input_specs)}->{output_spec}" -def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool: +def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + """Returns *True* if *expr* is materialized.""" + from pytato.array import InputArgumentBase + from pytato.distributed.nodes import DistributedRecv + from pytato.tags import ImplStored + return ( + ( + isinstance(expr, Array) + and bool(expr.tags_of_type(ImplStored))) + or isinstance( + expr, + ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, + DistributedRecv))) + + +def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool: """ - Returns *True* if *expr* is an instance of an array type that can be materialized. + Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to + determine whether or not it is materialized. """ from pytato.array import InputArgumentBase, NamedArray from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder return ( isinstance(expr, Array) - and not isinstance(expr, ( - # FIXME: Is there a nice way to generalize this? - InputArgumentBase, NamedArray, DistributedRecv, - DistributedSendRefHolder))) + and not isinstance( + expr, + ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, + DistributedRecv, + NamedArray, + DistributedSendRefHolder))) # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 082a983c0..1de91c6e6 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1029,9 +1029,13 @@ def test_materialized_node_flop_counts(): # z[i, j] = x[i, j] + y[i, j] # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] - assert len(materialized_node_to_flop_count) == 2 + assert len(materialized_node_to_flop_count) == 4 + assert x in materialized_node_to_flop_count + assert y in materialized_node_to_flop_count assert z in materialized_node_to_flop_count assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[x] == 0 + assert materialized_node_to_flop_count[y] == 0 assert materialized_node_to_flop_count[z] == 40 assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 From 32853807de18cf1810b24d32f4f7b5610dd3efbb Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 10:54:02 -0600 Subject: [PATCH 11/32] use explicit isinstance() check instead of try/except around to_index_lambda() --- pytato/analysis/__init__.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 8d38360ec..bb705f88a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -42,6 +42,7 @@ ArrayOrScalar, Concatenate, CSRMatmul, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -49,10 +50,11 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Placeholder, ShapeType, Stack, ) -from pytato.diagnostic import CannotBeLoweredToIndexLambda +from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.scalar_expr import ( FlopCounter as ScalarFlopCounter, @@ -76,7 +78,6 @@ import pytools.tag - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.loopy import LoopyCall __doc__ = """ @@ -817,10 +818,16 @@ def combine(self, *args: int) -> int: return sum(args) def _get_own_flop_count(self, expr: Array) -> int: - try: - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - except CannotBeLoweredToIndexLambda: - nflops = 0 + if isinstance( + expr, + ( + DataWrapper, + Placeholder, + NamedArray, + DistributedRecv, + DistributedSendRefHolder)): + return 0 + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) if not isinstance(nflops, int): from pytato.scalar_expr import InputGatherer as ScalarInputGatherer var_names: set[str] = set(ScalarInputGatherer()(nflops)) From bd6f0393baa81281ef44ba6abbe7e731730a396d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 11:45:20 -0600 Subject: [PATCH 12/32] remove a couple of FIXMEs --- pytato/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index d5d377966..f2b50791c 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -753,7 +753,6 @@ def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: or isinstance( expr, ( - # FIXME: Is there a nice way to generalize this? InputArgumentBase, DistributedRecv))) @@ -770,7 +769,6 @@ def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> boo and not isinstance( expr, ( - # FIXME: Is there a nice way to generalize this? InputArgumentBase, DistributedRecv, NamedArray, From bb7288d693b8494b6e098c9ca95ce8359f6c7c94 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 26 Jan 2026 12:27:00 -0600 Subject: [PATCH 13/32] add placeholder class for operator flop counts that aren't specified --- pytato/analysis/__init__.py | 15 +++++++++------ pytato/scalar_expr.py | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index bb705f88a..5c1c8ad90 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -829,13 +829,16 @@ def _get_own_flop_count(self, expr: Array) -> int: return 0 nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) if not isinstance(nflops, int): - from pytato.scalar_expr import InputGatherer as ScalarInputGatherer - var_names: set[str] = set(ScalarInputGatherer()(nflops)) - var_names.discard("nflops") - if var_names: - raise UndefinedOpFlopCountError(next(iter(var_names))) from None + # Restricting to numerical result here because the flop counters that use + # this mapper subsequently multiply the result by things that are + # potentially arrays (e.g., shape components), and arrays and scalar + # expressions are not interoperable + from pytato.scalar_expr import OpFlops, OpFlopsCollector + op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) + if op_flops: + raise UndefinedOpFlopCountError(next(iter(op_flops)).op) else: - raise AssertionError from None + raise AssertionError return nflops @override diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 524d4e7a7..022517a44 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -45,6 +45,7 @@ import re from collections.abc import Iterable, Mapping, Set as AbstractSet +from functools import reduce from typing import ( TYPE_CHECKING, Any, @@ -269,8 +270,7 @@ def _get_op_nflops(self, name: str) -> ArithmeticExpression: try: return self.op_name_to_num_flops[name] except KeyError: - from pymbolic import var - result = var("nflops")(var(name)) + result = OpFlops(name) self.op_name_to_num_flops[name] = result return result @@ -483,9 +483,42 @@ class TypeCast(ExpressionBase): dtype: np.dtype[Any] inner_expr: ScalarExpression + +@expr_dataclass() +class OpFlops(prim.AlgebraicLeaf): + """ + Placeholder flop count for an operator. + + .. autoattribute:: op + """ + op: str + # }}} +class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]): + """ + Constructs a :class:`frozenset` containing all instances of + :class:`pytato.scalar_expr.OpFlops` found in a scalar expression. + """ + @override + def combine( + self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]: + return reduce( + lambda x, y: x.union(y), + values, + cast("frozenset[OpFlops]", frozenset())) + + @override + def map_algebraic_leaf( + self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]: + return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset() + + @override + def map_constant(self, expr: object) -> frozenset[OpFlops]: + return frozenset() + + class InductionVariableCollector(CombineMapper[AbstractSet[str], []]): def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]: from functools import reduce From c05cd93b350613f514ffebc86f63732f2a617bc2 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 13 Mar 2026 10:26:03 -0500 Subject: [PATCH 14/32] fix swapped and/or op names Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pytato/reductions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index a229328c8..1cb016def 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -169,7 +169,7 @@ class AllReductionOperation(_StatelessReductionOperation): @override @classmethod def scalar_op_name(cls): - return "or" + return "and" def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) @@ -179,7 +179,7 @@ class AnyReductionOperation(_StatelessReductionOperation): @override @classmethod def scalar_op_name(cls): - return "and" + return "or" def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) From e822226c18c2f764c43fe4bf861504241f5d5334 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 13 Mar 2026 10:30:21 -0500 Subject: [PATCH 15/32] fix flop counting for negative power Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pytato/scalar_expr.py | 2 +- test/test_pytato.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 022517a44..1a09fc5a0 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -328,7 +328,7 @@ def map_power(self, expr: prim.Power) -> ArithmeticExpression: else: return ( self._get_op_nflops("/") - + expr.exponent * self._get_op_nflops("*") + + (-expr.exponent) * self._get_op_nflops("*") + self.rec(expr.base)) else: return ( diff --git a/test/test_pytato.py b/test/test_pytato.py index 1de91c6e6..47d490663 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -872,6 +872,7 @@ def test_scalar_flop_count(): assert fc(x % 2) == 0 assert fc(x ** 3) == 3 + assert fc(x ** (-3)) == 7 assert fc(x ** 0.3) == 8 assert fc(x.lt(y)) == 1 From ede5e94175b5350f91fbe5489a387e92456cd790 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 13 Mar 2026 10:59:34 -0500 Subject: [PATCH 16/32] make UndefinedOpFlopCountError not a dataclass to avoid issues with __init__ --- pytato/analysis/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5c1c8ad90..a03ffd978 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -801,9 +801,8 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # {{{ flop counting -@dataclass class UndefinedOpFlopCountError(ValueError): - op_name: str + pass class _PerEntryFlopCounter(CombineMapper[int, Never, []]): @@ -836,7 +835,9 @@ def _get_own_flop_count(self, expr: Array) -> int: from pytato.scalar_expr import OpFlops, OpFlopsCollector op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) if op_flops: - raise UndefinedOpFlopCountError(next(iter(op_flops)).op) + op_name = next(iter(op_flops)).op + raise UndefinedOpFlopCountError( + f"Undefined flop count for operation '{op_name}'.") else: raise AssertionError return nflops From 45d23cc3ae6f8d2142b53b3f7eac1c13e0e7fc17 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 13 Mar 2026 14:21:33 -0500 Subject: [PATCH 17/32] update note about function traversal for MaterializedNodeFlopCounter --- pytato/analysis/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a03ffd978..aac132d16 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -870,8 +870,7 @@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): .. note:: - Flops from nodes inside function calls are accumulated onto the corresponding - call node. + This mapper does not descend into functions. """ def __init__( self, From f9956e4adcb7ccb3f0d4f637ca394f402aff8b96 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:21:39 -0500 Subject: [PATCH 18/32] add MaterializedNodeCollector/collect_materialized_nodes --- pytato/analysis/__init__.py | 82 +++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index aac132d16..a15f42b99 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -102,6 +102,9 @@ .. autoclass:: TagCountMapper .. autofunction:: get_num_tags_of_type +.. autoclass:: MaterializedNodeCollector +.. autofunction:: collect_materialized_nodes + .. autoclass:: UndefinedOpFlopCountError .. autofunction:: get_default_op_name_to_num_flops .. autofunction:: get_num_flops @@ -770,6 +773,85 @@ def get_num_tags_of_type( # }}} +# {{{ MaterializedNodeCollector + +class MaterializedNodeCollector(CachedWalkMapper[[]]): + """ + Return the nodes in a DAG that are materialized. + + See :func:`collect_materialized_nodes` for more details. + """ + def __init__( + self, + include_outputs: bool = True, + _visited_functions: set[Any] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + self.include_outputs: bool = include_outputs + self.materialized_nodes: set[Array] = set() + + @overload + def __call__(self, expr: ArrayOrNames) -> None: + ... + + @overload + def __call__(self, expr: FunctionDefinition) -> None: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + ) -> None: + super().__call__(expr) + + if self.include_outputs: + if isinstance(expr, DictOfNamedArrays): + self.materialized_nodes.update(expr._data.values()) + elif isinstance(expr, Array): + self.materialized_nodes.add(expr) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not isinstance(expr, Array): + return + + if ( + isinstance( + expr, ( + InputArgumentBase, + DistributedRecv, + CSRMatmul)) + or expr.tags_of_type(ImplStored)): + self.materialized_nodes.add(expr) + elif isinstance(expr, DistributedSendRefHolder): + self.materialized_nodes.add(expr.send.data) + + +def collect_materialized_nodes( + expr: ArrayOrNames | FunctionDefinition, + include_outputs: bool = True) -> frozenset[Array]: + """ + Return the nodes in DAG *expr* that are materialized. + + The result includes inputs, outputs (optionally), arrays tagged with + `pytato.tags.ImplStored`, as well as special arrays that are always materialized + (:class:`pytato.DistributedSend`'s data, for example). + """ + mac = MaterializedNodeCollector(include_outputs=include_outputs) + mac(expr) + return frozenset(mac.materialized_nodes) + +# }}} + + # {{{ PytatoKeyBuilder class PytatoKeyBuilder(LoopyKeyBuilder): From 8b60df2e07619f8011e7426b03eff62890559c09 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:24:41 -0500 Subject: [PATCH 19/32] add InputUseCounter in scalar_expr --- pytato/scalar_expr.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 1a09fc5a0..56b0056ca 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -44,6 +44,7 @@ import re +from collections import defaultdict from collections.abc import Iterable, Mapping, Set as AbstractSet from functools import reduce from typing import ( @@ -79,6 +80,7 @@ from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer +from pytools import product if TYPE_CHECKING: @@ -254,6 +256,36 @@ def map_variable(self, expr: prim.Variable) -> set[str]: return {expr.name} +class InputUseCounter(CombineMapper[dict[str, ArithmeticExpression], []]): + @override + def combine( + self, values: Iterable[dict[str, ArithmeticExpression]] + ) -> dict[str, ArithmeticExpression]: + result: dict[str, ArithmeticExpression] = defaultdict(int) + for val in values: + for name, nuses in val.items(): + result[name] += nuses + return result + + @override + def map_constant(self, expr: object) -> dict[str, ArithmeticExpression]: + return {} + + @override + def map_variable(self, expr: prim.Variable) -> dict[str, ArithmeticExpression]: + return {expr.name: 1} + + @override + def map_reduce(self, expr: Reduce) -> dict[str, ArithmeticExpression]: + inner_expr_result = self.rec(expr.inner_expr) + niters = product(( + upper_bd - lower_bd + for lower_bd, upper_bd in expr.bounds.values())) + return { + name: niters * inner_expr_nuses + for name, inner_expr_nuses in inner_expr_result.items()} + + class FlopCounter(FlopCounterBase): op_name_to_num_flops: dict[str, ArithmeticExpression] From 4b0978ceaa91258257e3a2e844aec9607b026830 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 25 Mar 2026 09:31:39 -0500 Subject: [PATCH 20/32] add MapAsIndexLambdaMixin --- pytato/transform/lower_to_index_lambda.py | 66 ++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 3537b0486..59d0f0026 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -2,6 +2,7 @@ .. currentmodule:: pytato.transform.lower_to_index_lambda .. autofunction:: to_index_lambda +.. autoclass:: MapAsIndexLambdaMixin """ from __future__ import annotations @@ -28,8 +29,9 @@ THE SOFTWARE. """ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from constantdict import constantdict from typing_extensions import Never @@ -65,6 +67,8 @@ from pytato.tags import AssumeNonNegative from pytato.transform import ( Mapper, + P, + ResultT, _verify_is_array, ) from pytato.utils import normalized_slice_does_not_change_axis @@ -772,4 +776,64 @@ def to_index_lambda(expr: Array) -> IndexLambda: assert isinstance(res, IndexLambda) return res + +class MapAsIndexLambdaMixin(ABC, Generic[ResultT, P]): + """ + Mixin that, where possible, lowers arrays to :class:`~pytato.array.IndexLambda` + and calls :meth:`map_as_index_lambda` on them. + + .. automethod:: map_as_index_lambda + """ + @abstractmethod + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + """ + Map *expr* via its :class:`~pytato.array.IndexLambda` representation + *idx_lambda*. + """ + + def map_index_lambda( + self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, expr, *args, **kwargs) + + def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_axis_permutation( + self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_basic_index( + self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_contiguous_advanced_index( + self, expr: AdvancedIndexInContiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_non_contiguous_advanced_index( + self, expr: AdvancedIndexInNoncontiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_concatenate( + self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + # vim:fdm=marker From 28c4c9b7427e752e87ac8583816faf7118026dfe Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:53:57 -0500 Subject: [PATCH 21/32] abandon is_materialized/has_taggable_materialization approach in favor of collecting materialized arrays ahead of time the former doesn't work because some materialized arrays can only be identified by context (e.g., the 'data' attribute of a DistributedSend). --- .basedpyright/baseline.json | 16 -- pytato/analysis/__init__.py | 379 +++++++++++++++++++++++++----------- pytato/utils.py | 40 +--- test/test_pytato.py | 23 +-- 4 files changed, 275 insertions(+), 183 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 55befdcac..49915bc9f 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1936,22 +1936,6 @@ "endColumn": 25, "lineCount": 1 } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 34, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 56, - "endColumn": 66, - "lineCount": 1 - } } ], "./pytato/array.py": [ diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a15f42b99..0e0fe0159 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,7 +28,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override @@ -62,15 +62,15 @@ from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, - ArrayOrNamesTc, CachedWalkMapper, + CacheKeyT, CombineMapper, Mapper, VisitKeyT, - map_and_copy, ) -from pytato.transform.lower_to_index_lambda import to_index_lambda -from pytato.utils import has_taggable_materialization, is_materialized +from pytato.transform.lower_to_index_lambda import ( + MapAsIndexLambdaMixin, +) if TYPE_CHECKING: @@ -78,7 +78,7 @@ import pytools.tag - from pytato.loopy import LoopyCall + from pytato.loopy import LoopyCall, LoopyCallResult __doc__ = """ .. currentmodule:: pytato.analysis @@ -887,62 +887,131 @@ class UndefinedOpFlopCountError(ValueError): pass -class _PerEntryFlopCounter(CombineMapper[int, Never, []]): - def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: +class _NonIntegralPerEntryFlopCountError(ValueError): + pass + + +class _PerEntryFlopCounter( + MapAsIndexLambdaMixin[int, [bool]], + CombineMapper[int, Never, [bool]]): + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array]) -> None: super().__init__() self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( op_name_to_num_flops) - self.node_to_nflops: dict[Array, int] = {} + self.materialized_nodes: frozenset[Array] = materialized_nodes + + @overload + def __call__( + self, + expr: ArrayOrNames, + is_root: bool = True, + ) -> int: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + is_root: bool = True, + ) -> Never: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + is_root: bool = True, + ) -> int: + return super().__call__(expr, is_root) + + @override + def get_cache_key(self, expr: ArrayOrNames, is_root: bool) -> CacheKeyT: + return expr, is_root @override def combine(self, *args: int) -> int: return sum(args) - def _get_own_flop_count(self, expr: Array) -> int: - if isinstance( - expr, - ( - DataWrapper, - Placeholder, - NamedArray, - DistributedRecv, - DistributedSendRefHolder)): + @override + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, is_root: bool) -> int: + if expr in self.materialized_nodes and not is_root: return 0 - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - if not isinstance(nflops, int): + + self_nflops = self.scalar_flop_counter(idx_lambda.expr) + + if not isinstance(self_nflops, int): # Restricting to numerical result here because the flop counters that use # this mapper subsequently multiply the result by things that are # potentially arrays (e.g., shape components), and arrays and scalar # expressions are not interoperable from pytato.scalar_expr import OpFlops, OpFlopsCollector - op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) + op_flops: frozenset[OpFlops] = OpFlopsCollector()(self_nflops) if op_flops: op_name = next(iter(op_flops)).op raise UndefinedOpFlopCountError( f"Undefined flop count for operation '{op_name}'.") else: - raise AssertionError - return nflops + raise _NonIntegralPerEntryFlopCountError( + "Unable to compute an integer-valued per-entry flop count.") + + return self.combine( + self_nflops, + *( + self.rec(bnd, False) + for _, bnd in sorted(idx_lambda.bindings.items()))) @override - def rec(self, expr: ArrayOrNames) -> int: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - result: int - if isinstance(expr, Array) and not is_materialized(expr): - result = ( - self._get_own_flop_count(expr) - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - + cast("int", Mapper.rec(self, expr))) - else: - result = 0 - if isinstance(expr, Array): - self.node_to_nflops[expr] = result - return self._cache_add(inputs, result) + def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: + return 0 + + @override + def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> int: + return 0 + + @override + def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> int: + raise NotImplementedError + + @override + def map_named_array(self, expr: NamedArray, is_root: bool) -> int: + assert isinstance(expr._container, DictOfNamedArrays) + return self.combine(self.rec(expr._container._data[expr.name], False)) + + @override + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, is_root: bool) -> int: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> int: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call_result(self, expr: LoopyCallResult, is_root: bool) -> int: + return 0 + + @override + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, is_root: bool) -> int: + # Ignore expr.send.data; it's not part of the computation being performed + return self.combine(self.rec(expr.passthrough_data, False)) + + @override + def map_distributed_recv(self, expr: DistributedRecv, is_root: bool) -> int: + return 0 + + @override + def map_call(self, expr: Call, is_root: bool) -> int: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_named_call_result(self, expr: NamedCallResult, is_root: bool) -> int: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): @@ -957,12 +1026,14 @@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): def __init__( self, op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array], ) -> None: super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_nodes: frozenset[Array] = materialized_nodes self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( - self.op_name_to_num_flops) + self.op_name_to_num_flops, self.materialized_nodes) @override def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: @@ -990,20 +1061,59 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not is_materialized(expr): + if expr not in self.materialized_nodes: return assert isinstance(expr, Array) - if has_taggable_materialization(expr): - unmaterialized_expr = expr.without_tags(ImplStored()) - self.materialized_node_to_nflops[expr] = ( - product(expr.shape) - * self._per_entry_flop_counter(unmaterialized_expr)) + try: + nflops_per_entry = self._per_entry_flop_counter(expr) + except _NonIntegralPerEntryFlopCountError as e: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e else: - self.materialized_node_to_nflops[expr] = 0 + self.materialized_node_to_nflops[expr] = ( + nflops_per_entry + * product(expr.shape)) class _UnmaterializedSubexpressionUseCounter( - CombineMapper[dict[Array, int], Never, []]): + MapAsIndexLambdaMixin[dict[Array, int], [bool]], + CombineMapper[dict[Array, int], Never, [bool]]): + def __init__( + self, + materialized_nodes: frozenset[Array]) -> None: + super().__init__() + self.materialized_nodes: frozenset[Array] = materialized_nodes + + @overload + def __call__( + self, + expr: ArrayOrNames, + is_root: bool = True, + ) -> dict[Array, int]: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + is_root: bool = True, + ) -> Never: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + is_root: bool = True, + ) -> dict[Array, int]: + return super().__call__(expr, is_root) + + @override + def get_cache_key(self, expr: ArrayOrNames, is_root: bool) -> CacheKeyT: + return expr, is_root + @override def combine(self, *args: dict[Array, int]) -> dict[Array, int]: result: dict[Array, int] = defaultdict(int) @@ -1013,21 +1123,70 @@ def combine(self, *args: dict[Array, int]) -> dict[Array, int]: return result @override - def rec(self, expr: ArrayOrNames) -> dict[Array, int]: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - result: dict[Array, int] - if isinstance(expr, Array) and not is_materialized(expr): - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = self.combine( - {expr: 1}, cast("dict[Array, int]", Mapper.rec(self, expr))) - else: - result = {} - return self._cache_add(inputs, result) + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, is_root: bool + ) -> dict[Array, int]: + if expr in self.materialized_nodes and not is_root: + return {} + + return self.combine( + {expr: 1} if not is_root else {}, + *( + self.rec(bnd, False) + for _, bnd in sorted(idx_lambda.bindings.items()))) + + @override + def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> dict[Array, int]: + raise NotImplementedError + + @override + def map_named_array(self, expr: NamedArray, is_root: bool) -> dict[Array, int]: + assert isinstance(expr._container, DictOfNamedArrays) + return self.combine(self.rec(expr._container._data[expr.name], False)) + + @override + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays, is_root: bool) -> dict[Array, int]: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> dict[Array, int]: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call_result( + self, expr: LoopyCallResult, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, is_root: bool) -> dict[Array, int]: + # Ignore expr.send.data; it's not part of the computation being performed + return self.combine(self.rec(expr.passthrough_data, False)) + + @override + def map_distributed_recv( + self, expr: DistributedRecv, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_call(self, expr: Call, is_root: bool) -> dict[Array, int]: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_named_call_result( + self, expr: NamedCallResult, is_root: bool) -> dict[Array, int]: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") @dataclass @@ -1051,13 +1210,18 @@ class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): """ def __init__( self, - op_name_to_num_flops: Mapping[str, int]) -> None: + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array]) -> None: super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_nodes: frozenset[Array] = materialized_nodes self.unmaterialized_node_to_flop_counts: \ dict[Array, UnmaterializedNodeFlopCounts] = {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( - self.op_name_to_num_flops) + self.op_name_to_num_flops, self.materialized_nodes) + self._use_counter: _UnmaterializedSubexpressionUseCounter = \ + _UnmaterializedSubexpressionUseCounter( + self.materialized_nodes) @override def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: @@ -1085,49 +1249,33 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not is_materialized(expr) or not has_taggable_materialization(expr): + if expr not in self.materialized_nodes: return assert isinstance(expr, Array) - unmaterialized_expr = expr.without_tags(ImplStored()) - subexpr_to_nuses = _UnmaterializedSubexpressionUseCounter()( - unmaterialized_expr) - del subexpr_to_nuses[unmaterialized_expr] - self._per_entry_flop_counter(unmaterialized_expr) - for subexpr, nuses in subexpr_to_nuses.items(): - per_entry_nflops = self._per_entry_flop_counter.node_to_nflops[subexpr] - if subexpr not in self.unmaterialized_node_to_flop_counts: - nflops_if_materialized = product(subexpr.shape) * per_entry_nflops - flop_counts = UnmaterializedNodeFlopCounts({}, nflops_if_materialized) - self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts - else: - flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] - assert expr not in flop_counts.materialized_successor_to_contrib_nflops - flop_counts.materialized_successor_to_contrib_nflops[expr] = ( - nuses * product(expr.shape) * per_entry_nflops) - - -# FIXME: Should this be added to normalize_outputs? -def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - # Make sure outputs are materialized - if isinstance(expr, DictOfNamedArrays): - output_to_materialized_output: dict[Array, Array] = { - ary: ( - ary.tagged(ImplStored()) - if has_taggable_materialization(ary) - else ary) - for ary in expr._data.values()} - - def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: - if not isinstance(ary, Array): - return ary - try: - return output_to_materialized_output[ary] - except KeyError: - return ary - - expr = map_and_copy(expr, replace_with_materialized) - - return expr + try: + self._per_entry_flop_counter(expr) + except _NonIntegralPerEntryFlopCountError as e: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e + else: + subexpr_to_nuses = self._use_counter(expr) + for subexpr, nuses in subexpr_to_nuses.items(): + nflops_per_entry = self._per_entry_flop_counter( + subexpr, is_root=False) + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = nflops_per_entry * product(subexpr.shape) + flop_counts = UnmaterializedNodeFlopCounts( + {}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses + * nflops_per_entry + * product(expr.shape)) def get_default_op_name_to_num_flops() -> dict[str, int]: @@ -1172,12 +1320,13 @@ def get_num_flops( """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) - expr = _normalize_materialization(expr) + + materialized_nodes = collect_materialized_nodes(expr) if op_name_to_num_flops is None: op_name_to_num_flops = get_default_op_name_to_num_flops() - fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc = MaterializedNodeFlopCounter(op_name_to_num_flops, materialized_nodes) fc(expr) return sum(fc.materialized_node_to_nflops.values()) @@ -1207,12 +1356,13 @@ def get_materialized_node_flop_counts( """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) - expr = _normalize_materialization(expr) + + materialized_nodes = collect_materialized_nodes(expr) if op_name_to_num_flops is None: op_name_to_num_flops = get_default_op_name_to_num_flops() - fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc = MaterializedNodeFlopCounter(op_name_to_num_flops, materialized_nodes) fc(expr) return fc.materialized_node_to_nflops @@ -1248,12 +1398,13 @@ def get_unmaterialized_node_flop_counts( """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) - expr = _normalize_materialization(expr) + + materialized_nodes = collect_materialized_nodes(expr) if op_name_to_num_flops is None: op_name_to_num_flops = get_default_op_name_to_num_flops() - fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops) + fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops, materialized_nodes) fc(expr) return fc.unmaterialized_node_to_flop_counts diff --git a/pytato/utils.py b/pytato/utils.py index f2b50791c..d7bc8eade 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import ArrayOrNames, CachedMapper +from pytato.transform import CachedMapper if TYPE_CHECKING: @@ -69,8 +69,6 @@ from pytools.tag import Tag - from pytato.function import FunctionDefinition - __doc__ = """ Helper routines @@ -82,8 +80,6 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str -.. autofunction:: is_materialized -.. autofunction:: has_taggable_materialization References ^^^^^^^^^^ @@ -741,38 +737,4 @@ def get_einsum_specification(expr: Einsum) -> str: return f"{','.join(input_specs)}->{output_spec}" -def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - """Returns *True* if *expr* is materialized.""" - from pytato.array import InputArgumentBase - from pytato.distributed.nodes import DistributedRecv - from pytato.tags import ImplStored - return ( - ( - isinstance(expr, Array) - and bool(expr.tags_of_type(ImplStored))) - or isinstance( - expr, - ( - InputArgumentBase, - DistributedRecv))) - - -def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool: - """ - Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to - determine whether or not it is materialized. - """ - from pytato.array import InputArgumentBase, NamedArray - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder - return ( - isinstance(expr, Array) - and not isinstance( - expr, - ( - InputArgumentBase, - DistributedRecv, - NamedArray, - DistributedSendRefHolder))) - - # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 47d490663..02c0aaacd 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1034,16 +1034,15 @@ def test_materialized_node_flop_counts(): assert x in materialized_node_to_flop_count assert y in materialized_node_to_flop_count assert z in materialized_node_to_flop_count - assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert expr in materialized_node_to_flop_count assert materialized_node_to_flop_count[x] == 0 assert materialized_node_to_flop_count[y] == 0 assert materialized_node_to_flop_count[z] == 40 - assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 + assert materialized_node_to_flop_count[expr] == 40*4 def test_unmaterialized_node_flop_counts(): from pytato.analysis import get_unmaterialized_node_flop_counts - from pytato.tags import ImplStored x = pt.make_placeholder("x", (10, 4)) y = pt.make_placeholder("y", (10, 4)) @@ -1058,8 +1057,6 @@ def test_unmaterialized_node_flop_counts(): unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) - materialized_expr = expr.tagged(ImplStored()) - # Everything except expr stays unmaterialized assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 assert z in unmaterialized_node_to_flop_counts @@ -1067,23 +1064,21 @@ def test_unmaterialized_node_flop_counts(): assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) flop_counts = unmaterialized_node_to_flop_counts[z] assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 - assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops - assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ - == 40*10 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 40*10 assert flop_counts.nflops_if_materialized == 40 for w_i in w: flop_counts = unmaterialized_node_to_flop_counts[w_i] assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 - assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops - assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ - == 40*2 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 40*2 assert flop_counts.nflops_if_materialized == 40*2 for i, s_i in enumerate(s): flop_counts = unmaterialized_node_to_flop_counts[s_i] assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 - assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops - assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ - == 40*2*(i+1) + 40*i + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == \ + 40*2*(i+1) + 40*i assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i From 74478165f33e444dc067d1ebafa239b3b0ac20de Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:57:02 -0500 Subject: [PATCH 22/32] fix flop counting for reductions --- pytato/analysis/__init__.py | 27 ++++++++++++++++++++++----- test/test_pytato.py | 12 ++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 0e0fe0159..630c46394 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,7 +28,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override @@ -58,6 +58,7 @@ from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.scalar_expr import ( FlopCounter as ScalarFlopCounter, + InputUseCounter as ScalarInputUseCounter, ) from pytato.tags import ImplStored from pytato.transform import ( @@ -958,11 +959,17 @@ def map_as_index_lambda( raise _NonIntegralPerEntryFlopCountError( "Unable to compute an integer-valued per-entry flop count.") + binding_to_nuses = ScalarInputUseCounter()(idx_lambda.expr) + # self_nflops check above should take care of non-constant use count case + assert all( + isinstance(nuses, int) + for nuses in binding_to_nuses.values()) + return self.combine( self_nflops, *( - self.rec(bnd, False) - for _, bnd in sorted(idx_lambda.bindings.items()))) + cast("int", binding_to_nuses[name]) * self.rec(bnd, False) + for name, bnd in sorted(idx_lambda.bindings.items()))) @override def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: @@ -1129,11 +1136,21 @@ def map_as_index_lambda( if expr in self.materialized_nodes and not is_root: return {} + binding_to_nuses = ScalarInputUseCounter()(idx_lambda.expr) + if any( + not isinstance(nuses, int) + for nuses in binding_to_nuses.values()): + raise ValueError( + "Unable to compute integer-valued use counts for the predecessors " + f"of array of type '{type(expr).__name__}'.") + return self.combine( {expr: 1} if not is_root else {}, *( - self.rec(bnd, False) - for _, bnd in sorted(idx_lambda.bindings.items()))) + { + ary: cast("int", binding_to_nuses[name]) * nuses + for ary, nuses in self.rec(bnd, False).items()} + for name, bnd in sorted(idx_lambda.bindings.items()))) @override def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: diff --git a/test/test_pytato.py b/test/test_pytato.py index 02c0aaacd..da8f4288a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -999,17 +999,17 @@ def test_flop_count(): x = pt.make_placeholder("x", (2, 3, 4)) y = pt.make_placeholder("y", (3, 4)) - expr = pt.einsum("ijk,jk->ijk", x, y) + expr = pt.einsum("ijk,jk->ijk", 2*x, 3*y) - # expr[i, j, k] = x[i, j, k] * y[j, k] - assert get_num_flops(expr) == 24 + # expr[i, j, k] = 2*x[i, j, k] * 3*y[j, k] + assert get_num_flops(expr) == 72 x = pt.make_placeholder("x", (2, 3, 4)) y = pt.make_placeholder("y", (3, 4)) - expr = pt.einsum("ijk,jk->i", x, y) + expr = pt.einsum("ijk,jk->i", 2*x, 3*y) - # expr[i] = sum(sum(x[i, j, k] * y[j, k], j), k) - assert get_num_flops(expr) == 2*(4 * (3*1 + 2) + 3) + # expr[i] = sum(sum(2*x[i, j, k] * 3*y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*3 + 2) + 3) # }}} From 567d181aba6e3b2f4094e57f3986841aa13f3a38 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 24 Mar 2026 14:56:55 -0500 Subject: [PATCH 23/32] handle CSR matmul in flop counting --- pytato/analysis/__init__.py | 80 ++++++++++++++++++----- test/test_pytato.py | 126 +++++++++++++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 17 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 630c46394..079e59e10 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -979,10 +979,6 @@ def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> int: return 0 - @override - def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> int: - raise NotImplementedError - @override def map_named_array(self, expr: NamedArray, is_root: bool) -> int: assert isinstance(expr._container, DictOfNamedArrays) @@ -1074,10 +1070,31 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: try: nflops_per_entry = self._per_entry_flop_counter(expr) except _NonIntegralPerEntryFlopCountError as e: - raise ValueError( - "Unable to compute a flop count for array of type " - f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " - "non-integer-valued.") from e + if isinstance(expr, CSRMatmul): + # _PerEntryFlopCounter chokes on CSRMatmul because the row reduction + # bounds are data dependent. By handling it here instead we can take + # advantage of knowing that the total number of reduction iterations + # is len(elem_values). Note: assumes no flops for elem_col_indices + # and row_starts as they are integer-valued. + nelems = expr.matrix.elem_values.shape[0] + nflops_self = ( + ( + nelems # multiplies + + nelems - expr.shape[0]) # adds + * product(expr.shape[1:])) + nflops_children = ( + nelems + * ( + self._per_entry_flop_counter(expr.matrix.elem_values) + + self._per_entry_flop_counter(expr.array)) + * product(expr.shape[1:])) + self.materialized_node_to_nflops[expr] = ( + nflops_self + nflops_children) + else: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e else: self.materialized_node_to_nflops[expr] = ( nflops_per_entry @@ -1160,10 +1177,6 @@ def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> dict[Array, int]: return {} - @override - def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> dict[Array, int]: - raise NotImplementedError - @override def map_named_array(self, expr: NamedArray, is_root: bool) -> dict[Array, int]: assert isinstance(expr._container, DictOfNamedArrays) @@ -1272,10 +1285,45 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: try: self._per_entry_flop_counter(expr) except _NonIntegralPerEntryFlopCountError as e: - raise ValueError( - "Unable to compute a flop count for array of type " - f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " - "non-integer-valued.") from e + if isinstance(expr, CSRMatmul): + # _PerEntryFlopCounter chokes on CSRMatmul because the row reduction + # bounds are data dependent. By handling it here instead we can take + # advantage of knowing that the total number of reduction iterations + # is len(elem_values). Note: assumes no flops for elem_col_indices + # and row_starts as they are integer-valued. + nelems = expr.matrix.elem_values.shape[0] + subexpr_to_nuses_per_elem = self._use_counter.combine( + self._use_counter(expr.matrix.elem_values, is_root=False), + self._use_counter(expr.array, is_root=False)) + for subexpr, nuses_per_elem in subexpr_to_nuses_per_elem.items(): + try: + nflops_per_entry = self._per_entry_flop_counter( + subexpr, is_root=False) + except _NonIntegralPerEntryFlopCountError: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(subexpr).__name__}'; per-entry flop count is " + "unexpectedly non-integer-valued.") from e + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = ( + nflops_per_entry * product(subexpr.shape)) + flop_counts = UnmaterializedNodeFlopCounts( + {}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in \ + flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses_per_elem + * nelems + * nflops_per_entry + * product(expr.shape[1:])) + else: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e else: subexpr_to_nuses = self._use_counter(expr) for subexpr, nuses in subexpr_to_nuses.items(): diff --git a/test/test_pytato.py b/test/test_pytato.py index da8f4288a..6f107c4af 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1013,11 +1013,60 @@ def test_flop_count(): # }}} + # {{{ CSR matmul (trivial predecessors) + + x = pt.make_csr_matrix( + shape=(8, 10), + elem_values=pt.make_placeholder("x_elem_values", (16,)), + elem_col_indices=pt.make_placeholder("x_elem_col_indices", (16,)), + row_starts=pt.make_placeholder("x_row_starts", (9,))) + y = pt.make_placeholder("y", (10, 5, 3)) + expr = x @ y + + assert get_num_flops(expr) == 5*3*( + 16 # multiplies + + 16 - 8 # adds + ) + + # }}} + + # {{{ CSR matmul (nontrivial predecessors) + + elem_values = pt.zeros(12) + 1 + elem_col_indices = pt.make_data_wrapper(np.array([ + 0, + 0, 1, + 0, 1, 2, + 1, 2, 3, + 2, 3, + 3])) + row_starts = pt.make_data_wrapper(np.array([0, 1, 3, 6, 9, 11, 12])) + x = pt.make_csr_matrix( + shape=(6, 4), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = pt.zeros((4, 3, 2)) + 1 + expr = x @ y + + assert get_num_flops(expr) == 3*2*( + 3 # row 1 + + 3*2 + 1 # row 2 + + 3*3 + 2 # row 3 + + 3*3 + 2 # row 4 + + 3*2 + 1 # row 5 + + 3 # row 6 + ) + + # }}} + def test_materialized_node_flop_counts(): from pytato.analysis import get_materialized_node_flop_counts from pytato.tags import ImplStored + # {{{ basic DAG + x = pt.make_placeholder("x", (10, 4)) y = pt.make_placeholder("y", (10, 4)) @@ -1040,10 +1089,52 @@ def test_materialized_node_flop_counts(): assert materialized_node_to_flop_count[z] == 40 assert materialized_node_to_flop_count[expr] == 40*4 + # }}} + + # {{{ CSR matmul + + zeros_10 = pt.make_data_wrapper(np.zeros(10)) + zeros_20 = pt.make_data_wrapper(np.zeros(20)) + elem_values = zeros_20 + 1 + elem_col_indices = pt.make_data_wrapper((1 + np.arange(0, 20)) // 2) + row_starts = pt.make_data_wrapper(2*np.arange(0, 11)) + x = pt.make_csr_matrix( + shape=(10, 10), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = zeros_10 + 1 + z = x @ y + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + assert len(materialized_node_to_flop_count) == 6 + assert zeros_20 in materialized_node_to_flop_count + assert elem_col_indices in materialized_node_to_flop_count + assert row_starts in materialized_node_to_flop_count + assert zeros_10 in materialized_node_to_flop_count + assert z in materialized_node_to_flop_count + assert expr in materialized_node_to_flop_count + assert materialized_node_to_flop_count[zeros_20] == 0 + assert materialized_node_to_flop_count[elem_col_indices] == 0 + assert materialized_node_to_flop_count[row_starts] == 0 + assert materialized_node_to_flop_count[zeros_10] == 0 + # flops from elem_values/y/z (2 elems per row so y gets used twice) + assert materialized_node_to_flop_count[z] == 20 + 20 + 20 + 10 + # flops from u/v/expr (no flops from z because it's materialized) + assert materialized_node_to_flop_count[expr] == 40 + + # }}} + def test_unmaterialized_node_flop_counts(): from pytato.analysis import get_unmaterialized_node_flop_counts + # {{{ basic DAG + x = pt.make_placeholder("x", (10, 4)) y = pt.make_placeholder("y", (10, 4)) @@ -1057,7 +1148,7 @@ def test_unmaterialized_node_flop_counts(): unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) - # Everything except expr stays unmaterialized + # Everything except x/y/expr stays unmaterialized assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 assert z in unmaterialized_node_to_flop_counts assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) @@ -1081,6 +1172,39 @@ def test_unmaterialized_node_flop_counts(): 40*2*(i+1) + 40*i assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + # }}} + + # {{{ CSR matmul + + elem_values = pt.make_data_wrapper(np.zeros(20)) + 1 + elem_col_indices = pt.make_data_wrapper((1 + np.arange(0, 20)) // 2) + row_starts = pt.make_data_wrapper(2*np.arange(0, 11)) + x = pt.make_csr_matrix( + shape=(10, 10), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = pt.make_data_wrapper(np.zeros(10)) + 1 + expr = x @ y + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + assert len(unmaterialized_node_to_flop_counts) == 2 + assert elem_values in unmaterialized_node_to_flop_counts + assert y in unmaterialized_node_to_flop_counts + flop_counts = unmaterialized_node_to_flop_counts[elem_values] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 20 + assert flop_counts.nflops_if_materialized == 20 + flop_counts = unmaterialized_node_to_flop_counts[y] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 20 + assert flop_counts.nflops_if_materialized == 10 + + # }}} + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) From dfb31f657ecc28acefda752e114ad392491b582f Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 27 Mar 2026 10:18:54 -0500 Subject: [PATCH 24/32] handle unused bindings Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pytato/analysis/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 079e59e10..c0e115404 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -968,7 +968,7 @@ def map_as_index_lambda( return self.combine( self_nflops, *( - cast("int", binding_to_nuses[name]) * self.rec(bnd, False) + cast("int", binding_to_nuses.get(name, 0)) * self.rec(bnd, False) for name, bnd in sorted(idx_lambda.bindings.items()))) @override @@ -1165,7 +1165,7 @@ def map_as_index_lambda( {expr: 1} if not is_root else {}, *( { - ary: cast("int", binding_to_nuses[name]) * nuses + ary: cast("int", binding_to_nuses.get(name, 0)) * nuses for ary, nuses in self.rec(bnd, False).items()} for name, bnd in sorted(idx_lambda.bindings.items()))) From 177e50a490bf168d69b3430d3d7937cdec35b6a4 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 27 Mar 2026 10:50:49 -0500 Subject: [PATCH 25/32] enumerate all undefined op flop counts in exception message Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pytato/analysis/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c0e115404..0a898db1c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -952,9 +952,10 @@ def map_as_index_lambda( from pytato.scalar_expr import OpFlops, OpFlopsCollector op_flops: frozenset[OpFlops] = OpFlopsCollector()(self_nflops) if op_flops: - op_name = next(iter(op_flops)).op + op_names = sorted({of.op for of in op_flops}) + formatted_ops = ", ".join(f"'{name}'" for name in op_names) raise UndefinedOpFlopCountError( - f"Undefined flop count for operation '{op_name}'.") + f"Undefined flop count for operation(s): {formatted_ops}.") else: raise _NonIntegralPerEntryFlopCountError( "Unable to compute an integer-valued per-entry flop count.") From 4198cf0dea1f4bd91cc255a4e37dcb67c7b5e07c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Mar 2026 10:36:11 -0500 Subject: [PATCH 26/32] treat LoopyCallResult as materialized --- pytato/analysis/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 0a898db1c..d903477ee 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -56,6 +56,7 @@ ) from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.loopy import LoopyCallResult from pytato.scalar_expr import ( FlopCounter as ScalarFlopCounter, InputUseCounter as ScalarInputUseCounter, @@ -79,7 +80,7 @@ import pytools.tag - from pytato.loopy import LoopyCall, LoopyCallResult + from pytato.loopy import LoopyCall __doc__ = """ .. currentmodule:: pytato.analysis @@ -829,6 +830,7 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: expr, ( InputArgumentBase, DistributedRecv, + LoopyCallResult, CSRMatmul)) or expr.tags_of_type(ImplStored)): self.materialized_nodes.add(expr) From 2f5cd3af9b9ed7bbb5181a82fb0ecf916cb5b51e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Mar 2026 10:36:42 -0500 Subject: [PATCH 27/32] forbid loopy calls in flop counting --- pytato/analysis/__init__.py | 38 +++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d903477ee..2a38dc53c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -993,11 +993,13 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, is_root: bool) -> in @override def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> int: + # Shouldn't have loopy calls raise AssertionError("Control shouldn't reach here.") @override def map_loopy_call_result(self, expr: LoopyCallResult, is_root: bool) -> int: - return 0 + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") @override def map_distributed_send_ref_holder( @@ -1049,6 +1051,21 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: def clone_for_callee(self, function: FunctionDefinition) -> Self: raise AssertionError("Control shouldn't reach this point.") + @override + def map_loopy_call(self, expr: LoopyCall) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + def map_loopy_call_result(self, expr: LoopyCallResult) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): @@ -1192,12 +1209,14 @@ def map_dict_of_named_arrays( @override def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> dict[Array, int]: + # Shouldn't have loopy calls raise AssertionError("Control shouldn't reach here.") @override def map_loopy_call_result( self, expr: LoopyCallResult, is_root: bool) -> dict[Array, int]: - return {} + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") @override def map_distributed_send_ref_holder( @@ -1264,6 +1283,21 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: def clone_for_callee(self, function: FunctionDefinition) -> Self: raise AssertionError("Control shouldn't reach this point.") + @override + def map_loopy_call(self, expr: LoopyCall) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + def map_loopy_call_result(self, expr: LoopyCallResult) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): From f2244f2902d06077c8c4268c57d51e5642bc12cb Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Mar 2026 10:55:32 -0500 Subject: [PATCH 28/32] add type annotation for scalar_op_name overrides --- pytato/reductions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index 1cb016def..ab8f265aa 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -118,7 +118,7 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "+" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -128,7 +128,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class ProductReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "*" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -138,7 +138,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MaxReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "max" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -153,7 +153,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "min" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -168,7 +168,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "and" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -178,7 +178,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AnyReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "or" def neutral_element(self, dtype: np.dtype[Any]) -> Any: From 564f09ea264f60d6b0068b5dbeacac52a9f51b99 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Mar 2026 16:48:01 -0500 Subject: [PATCH 29/32] make power flop count more precise --- pytato/scalar_expr.py | 40 +++++++++++++++++++++++++++++++++------- test/test_pytato.py | 11 +++++++++-- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 56b0056ca..d4533fa95 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -353,15 +353,41 @@ def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: @override def map_power(self, expr: prim.Power) -> ArithmeticExpression: if isinstance(expr.exponent, int): - if expr.exponent >= 0: + # The calculation below is based on the following code (which is an + # approximation of what is done in loopy) + # def pow(x, n): + # if n == 0: return 1 + # if n == 1: return x + # if n == 2: return x*x + # if n < 0: + # x = 1/x + # n = -n + # y = 1 + # while n > 1: + # if n % 2: + # y = x * y + # x = x * x + # n = n/2 + # return x*y + if expr.exponent == 0: + return 0 + elif expr.exponent > 0 and expr.exponent <= 2: return ( - expr.exponent * self._get_op_nflops("*") - + self.rec(expr.base)) - else: - return ( - self._get_op_nflops("/") - + (-expr.exponent) * self._get_op_nflops("*") + (expr.exponent - 1) * self._get_op_nflops("*") + self.rec(expr.base)) + nmults = 1 + remaining_exp = abs(expr.exponent) + while remaining_exp > 1: + if remaining_exp % 2: + nmults += 1 + nmults += 1 + remaining_exp //= 2 + nflops = ( + nmults * self._get_op_nflops("*") + + self.rec(expr.base)) + if expr.exponent < 0: + nflops += self._get_op_nflops("/") + return nflops else: return ( self._get_op_nflops("**") diff --git a/test/test_pytato.py b/test/test_pytato.py index 6f107c4af..67cc33880 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -871,9 +871,16 @@ def test_scalar_flop_count(): assert fc(x % 2) == 0 - assert fc(x ** 3) == 3 - assert fc(x ** (-3)) == 7 + assert fc(x ** 0) == 0 + assert fc(x ** 1) == 0 + # x * x + assert fc(x ** 2) == 1 + # compute x^2, x^4, x^8, x^16, x^32; multiply x^32 * x^16 * x^8 * x^4 * x * 1 + assert fc(x ** 61) == 5 + 5 + # divide; compute x^2, x^4, x^8, x^16; multiply x^16 * x^4 * x^2 * 1 + assert fc(x ** -22) == 4 + 4 + 3 assert fc(x ** 0.3) == 8 + assert fc(x ** y) == 8 assert fc(x.lt(y)) == 1 From 7ca7a50eef9a8425bb34438eb1b628ef09d89bd0 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Mar 2026 15:52:19 -0500 Subject: [PATCH 30/32] fix remainder flop counting --- pytato/scalar_expr.py | 7 +++++++ test/test_pytato.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index d4533fa95..f48a81fbe 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -350,6 +350,13 @@ def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + self.rec(expr.numerator) + self.rec(expr.denominator)) + @override + def map_remainder(self, expr: prim.Remainder) -> ArithmeticExpression: + return ( + self._get_op_nflops("%") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + @override def map_power(self, expr: prim.Power) -> ArithmeticExpression: if isinstance(expr.exponent, int): diff --git a/test/test_pytato.py b/test/test_pytato.py index 67cc33880..4ac57a79c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -869,7 +869,7 @@ def test_scalar_flop_count(): assert fc(x // 2) == 4 - assert fc(x % 2) == 0 + assert fc(x % 2) == 4 assert fc(x ** 0) == 0 assert fc(x ** 1) == 0 From aae675c7ead4db8baedc665057ef60a59857747a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Mar 2026 15:57:00 -0500 Subject: [PATCH 31/32] test recursion in test_scalar_flop_count --- pytato/scalar_expr.py | 4 +-- test/test_pytato.py | 70 +++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index f48a81fbe..5f53d9e8e 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -315,8 +315,8 @@ def map_call(self, expr: prim.Call) -> ArithmeticExpression: @override def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: - # Assume calculations inside subscripts are performed on non-floats - return 0 + # Assume index calculations are performed on non-floats + return self.rec(expr.aggregate) @override def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] diff --git a/test/test_pytato.py b/test/test_pytato.py index 4ac57a79c..9932ec84e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -837,64 +837,68 @@ def test_scalar_flop_count(): import pymbolic.primitives as prim from pymbolic import Variable - x = Variable("x") - y = Variable("y") + x = 2*Variable("x") + y = 3 + Variable("y") - assert fc(Variable("f")(x)) == 32 + assert fc(x) == 1 + assert fc(y) == 1 - assert fc(x[0]) == 0 + assert fc(Variable("f")(x)) == 32 + 1 - assert fc(x + 2) == 1 - assert fc(2 + y) == 1 - assert fc(x + y) == 1 + assert fc(x[0]) == 0 + 1 - assert fc(prim.Sum((2, x, y))) == 2 + assert fc(x + 2) == 1 + 1 + assert fc(2 + y) == 1 + 1 + assert fc(x + y) == 1 + 2 - assert fc(x - 2) == 1 - assert fc(2 - y) == 2 - assert fc(x - y) == 2 + assert fc(prim.Sum((2, x, y))) == 2 + 2 - assert fc(x * 2) == 1 - assert fc(2 * y) == 1 - assert fc(x * y) == 1 + assert fc(x - 2) == 1 + 1 + assert fc(2 - y) == 2 + 1 + assert fc(x - y) == 2 + 2 - assert fc(prim.Product((2, x, y))) == 2 + assert fc(x * 2) == 1 + 1 + assert fc(2 * y) == 1 + 1 + assert fc(x * y) == 1 + 2 - assert fc(x.or_(y)) == 0 - assert fc(x.and_(y)) == 0 + assert fc(prim.Product((2, x, y))) == 2 + 2 - assert fc(x / 2) == 4 - assert fc(2 / y) == 4 - assert fc(x / y) == 4 + assert fc(x.or_(y)) == 0 + 2 + assert fc(x.and_(y)) == 0 + 2 - assert fc(x // 2) == 4 + assert fc(x / 2) == 4 + 1 + assert fc(2 / y) == 4 + 1 + assert fc(x / y) == 4 + 2 - assert fc(x % 2) == 4 + assert fc(x // 2) == 4 + 1 + + assert fc(x % 2) == 4 + 1 assert fc(x ** 0) == 0 - assert fc(x ** 1) == 0 + assert fc(x ** 1) == 0 + 1 # x * x - assert fc(x ** 2) == 1 + assert fc(x ** 2) == 1 + 1 # compute x^2, x^4, x^8, x^16, x^32; multiply x^32 * x^16 * x^8 * x^4 * x * 1 - assert fc(x ** 61) == 5 + 5 + assert fc(x ** 61) == 5 + 5 + 1 # divide; compute x^2, x^4, x^8, x^16; multiply x^16 * x^4 * x^2 * 1 - assert fc(x ** -22) == 4 + 4 + 3 - assert fc(x ** 0.3) == 8 - assert fc(x ** y) == 8 + assert fc(x ** -22) == 4 + 4 + 3 + 1 + assert fc(x ** 0.3) == 8 + 1 + assert fc(x ** y) == 8 + 2 - assert fc(x.lt(y)) == 1 + assert fc(x.lt(y)) == 1 + 2 - assert fc(prim.If(x, x, y)) == 0 + assert fc(prim.If(x, x, y)) == 0 + 3 - assert fc(prim.Min((2, x, y))) == 2 - assert fc(prim.Max((2, x, y))) == 2 + assert fc(prim.Min((2, x, y))) == 2 + 2 + assert fc(prim.Max((2, x, y))) == 2 + 2 from constantdict import constantdict from pytato.reductions import SumReductionOperation from pytato.scalar_expr import Reduce - assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) == 9 + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) \ + == 9 + 10 def test_flop_count(): From f919d8568529f22e4574cd9e41b906cfc104b9d3 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Mar 2026 15:58:14 -0500 Subject: [PATCH 32/32] add note about flop counting for subscript indices --- pytato/analysis/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 2a38dc53c..22139a26e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1416,6 +1416,11 @@ def get_num_flops( this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + .. note:: Does not support functions. Function calls must be inlined before calling. @@ -1452,6 +1457,11 @@ def get_materialized_node_flop_counts( this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + .. note:: Does not support functions. Function calls must be inlined before calling. @@ -1494,6 +1504,11 @@ def get_unmaterialized_node_flop_counts( this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + .. note:: Does not support functions. Function calls must be inlined before calling.