diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index dd72c4312..548880722 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1793,94 +1793,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 4, - "endColumn": 7, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 11, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 11, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 4, - "endColumn": 7, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 15, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 15, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 15, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 4, - "endColumn": 8, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 11, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 11, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -7513,94 +7425,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnusedVariable", - "range": { - "startColumn": 20, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 15, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 22, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 29, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 35, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 48, - "endColumn": 59, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -7825,6 +7649,54 @@ "lineCount": 1 } }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 14, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 17, + "endColumn": 27, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 15, + "endColumn": 21, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 8, + "endColumn": 14, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 17, + "endColumn": 47, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 15, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportIncompatibleMethodOverride", "range": { diff --git a/pytato/__init__.py b/pytato/__init__.py index d99432942..e4e8b8857 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -149,7 +149,7 @@ def set_debug_enabled(flag: bool) -> None: ) from pytato.distributed.tags import number_distributed_tags from pytato.distributed.verify import verify_distributed_partition -from pytato.function import trace_call +from pytato.function import Call, FunctionDefinition, NamedCallResult, trace_call from pytato.loopy import LoopyCall from pytato.pad import pad from pytato.reductions import all, amax, amin, any, prod, sum @@ -157,6 +157,9 @@ def set_debug_enabled(flag: bool) -> None: from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.target.loopy.codegen import generate_loopy from pytato.target.python.jax import generate_jax + +# FIXME: Should some of the functions from pytato.transform be imported here? +# (deduplicate, map_and_copy, etc.) from pytato.transform.calls import inline_calls, tag_all_calls_to_be_inlined from pytato.transform.dead_code_elimination import eliminate_dead_code from pytato.transform.lower_to_index_lambda import to_index_lambda @@ -179,6 +182,7 @@ def set_debug_enabled(flag: bool) -> None: "Axis", "AxisPermutation", "BasicIndex", + "Call", "Concatenate", "DataWrapper", "DictOfNamedArrays", @@ -188,6 +192,7 @@ def set_debug_enabled(flag: bool) -> None: "DistributedSend", "DistributedSendRefHolder", "Einsum", + "FunctionDefinition", "IndexBase", "IndexLambda", "IndexRemappingBase", @@ -195,6 +200,7 @@ def set_debug_enabled(flag: bool) -> None: "LoopyCall", "LoopyPyOpenCLTarget", "NamedArray", + "NamedCallResult", "Placeholder", "ReductionDescriptor", "Reshape", diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 00e36ce32..940338182 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,13 +27,12 @@ """ from collections import defaultdict -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override from loopy.tools import LoopyKeyBuilder -from pymbolic.mapper.optimize import optimize_mapper from pytato.array import ( Array, @@ -51,15 +50,16 @@ from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.transform import ( ArrayOrNames, - CachedWalkMapper, - CombineMapper, + MapAndReduceMapper, Mapper, + NodeCollector, + NodeSet, VisitKeyT, ) if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Callable, Iterable, Mapping import pytools.tag @@ -74,20 +74,42 @@ .. autofunction:: is_einsum_similar_to_subscript -.. autofunction:: get_num_nodes +.. autoclass:: DirectPredecessorsGetter +.. autoclass:: ListOfDirectPredecessorsGetter .. autofunction:: get_node_type_counts - +.. autofunction:: get_num_nodes .. autofunction:: get_node_multiplicities +.. autofunction:: get_num_node_instances_of +.. autofunction:: collect_node_instances_of +.. autofunction:: get_num_tags_of_type +.. autofunction:: collect_materialized_nodes +""" -.. autofunction:: get_num_call_sites -.. autoclass:: DirectPredecessorsGetter -.. autoclass:: ListOfDirectPredecessorsGetter +# {{{ reduce_dicts -.. autoclass:: TagCountMapper -.. autofunction:: get_num_tags_of_type -""" +KeyT = TypeVar("KeyT") +ValueT = TypeVar("ValueT") + + +def reduce_dicts( + # FIXME: Is there a way to make argument type annotation more specific? + function: Callable[..., ValueT], + iterable: Iterable[dict[KeyT, ValueT]]) -> dict[KeyT, ValueT]: + """ + Apply *function* to the collection of values corresponding to each unique key in + *iterable*. + """ + key_to_list_of_values: dict[KeyT, list[ValueT]] = defaultdict(list) + for d in iterable: + for key, value in d.items(): + key_to_list_of_values[key].append(value) + return { + key: function(*list_of_values) + for key, list_of_values in key_to_list_of_values.items()} + +# }}} # {{{ ListOfUsersCollector @@ -482,251 +504,534 @@ def __call__( # }}} -# {{{ NodeCountMapper +# {{{ get_node_type_counts -@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) -class NodeCountMapper(CachedWalkMapper[[]]): - """ - Counts the number of nodes of a given type in a DAG. - - .. autoattribute:: expr_type_counts - .. autoattribute:: count_duplicates +NodeTypeCountDict: TypeAlias = dict[type[ArrayOrNames | FunctionDefinition], int] - Dictionary mapping node types to number of nodes of that type. - """ +# FIXME: I'm on the fence about whether these mapper classes should be kept around +# if they can be replaced with a call to map_and_reduce (which will be the case +# for most of these mappers once count_dict is removed). AFAIK the only real use +# case for using the mapper directly is if you want to compute a result for a +# collection of subexpressions by calling it multiple times while accumulating the set +# of visited nodes. But in most cases you could do that by putting them in a +# DictOfNamedArrays first and then call it once instead? *shrug* +# FIXME: optimize_mapper? +class NodeTypeCountMapper(MapAndReduceMapper[NodeTypeCountDict]): + """Count the number of nodes of each type in a DAG.""" def __init__( self, - count_duplicates: bool = False, - _visited_functions: set[VisitKeyT] | None = None, - ) -> None: - super().__init__(_visited_functions=_visited_functions) - - self.expr_type_counts: dict[type[Any], int] = defaultdict(int) - self.count_duplicates: bool = count_duplicates - - @override - def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: - # Returns unique nodes only if count_duplicates is False - return id(expr) if self.count_duplicates else expr - - @override - def get_function_definition_cache_key( - self, expr: FunctionDefinition) -> int | FunctionDefinition: - # Returns unique nodes only if count_duplicates is False - return id(expr) if self.count_duplicates else expr + traverse_functions: bool = True, + map_duplicates: bool = False, + map_in_different_functions: bool = True, + map_dict: bool = False) -> None: + super().__init__( + map_fn=lambda expr: {type(expr): 1}, + reduce_fn=lambda *args: reduce_dicts( + lambda *values: sum(values, 0), args), + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + + # FIXME: Remove this once count_dict argument has been eliminated from + # get_node_type_counts + self.map_dict: bool = map_dict @override def clone_for_callee(self, function: FunctionDefinition) -> Self: return type(self)( - count_duplicates=self.count_duplicates, - _visited_functions=self._visited_functions) + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions, + map_dict=self.map_dict) + # FIXME: Remove this once count_dict argument has been eliminated from + # get_node_type_counts @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not isinstance(expr, DictOfNamedArrays): - self.expr_type_counts[type(expr)] += 1 + def map_dict_of_named_arrays( + self, + expr: DictOfNamedArrays, + visited_node_keys: set[VisitKeyT] | None) -> NodeTypeCountDict: + if self.map_dict: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) + else: + return self.reduce_fn( + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) def get_node_type_counts( - outputs: ArrayOrNames, - count_duplicates: bool = False - ) -> dict[type[Any], int]: + outputs: ArrayOrNames | FunctionDefinition, *, + traverse_functions: bool = True, + count_duplicates: bool = False, + count_in_different_functions: bool | None = None, + count_dict: bool | None = None) -> NodeTypeCountDict: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. - - Instances of `DictOfNamedArrays` are excluded from counting. """ + if count_in_different_functions is None: + from warnings import warn + warn( + "The default value of 'count_in_different_functions' will change " + "from False to True in Q3 2026. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_in_different_functions = False - ncm = NodeCountMapper(count_duplicates) - ncm(outputs) + # FIXME: Deprecate/remove count_dict argument entirely after default value is + # changed + if count_dict is None: + from warnings import warn + warn( + "The default value of 'count_dict' will change " + "from False to True in Q3 2026. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_dict = False + + ntcm = NodeTypeCountMapper( + traverse_functions=traverse_functions, + map_duplicates=count_duplicates, + map_in_different_functions=count_in_different_functions, + map_dict=count_dict) + return ntcm(outputs) + +# }}} + + +# {{{ get_num_nodes + +# FIXME: optimize_mapper? +class NodeCountMapper(MapAndReduceMapper[int]): + """Count the total number of nodes in a DAG.""" + def __init__( + self, + traverse_functions: bool = True, + map_duplicates: bool = False, + map_in_different_functions: bool = True, + map_dict: bool = False) -> None: + super().__init__( + map_fn=lambda _: 1, + reduce_fn=lambda *args: sum(args, 0), + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + + # FIXME: Remove this once count_dict argument has been eliminated from + # get_num_nodes + self.map_dict: bool = map_dict + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions, + map_dict=self.map_dict) - return ncm.expr_type_counts + # FIXME: Remove this once count_dict argument has been eliminated from + # get_num_nodes + @override + def map_dict_of_named_arrays( + self, + expr: DictOfNamedArrays, + visited_node_keys: set[VisitKeyT] | None) -> int: + if self.map_dict: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) + else: + return self.reduce_fn( + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) def get_num_nodes( - outputs: ArrayOrNames, - count_duplicates: bool | None = None - ) -> int: + outputs: ArrayOrNames | FunctionDefinition, *, + traverse_functions: bool = True, + count_duplicates: bool = False, + count_in_different_functions: bool | None = None, + count_dict: bool | None = None) -> int: """ Returns the number of nodes in DAG *outputs*. - Instances of `DictOfNamedArrays` are excluded from counting. """ - if count_duplicates is None: + if count_in_different_functions is None: from warnings import warn warn( - "The default value of 'count_duplicates' will change " - "from True to False in 2025. " + "The default value of 'count_in_different_functions' will change " + "from False to True in Q3 2026. " "For now, pass the desired value explicitly.", DeprecationWarning, stacklevel=2) - count_duplicates = True + count_in_different_functions = False - ncm = NodeCountMapper(count_duplicates) - ncm(outputs) + # FIXME: Deprecate/remove count_dict argument entirely after default value is + # changed + if count_dict is None: + from warnings import warn + warn( + "The default value of 'count_dict' will change " + "from False to True in Q3 2026. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_dict = False - return sum(ncm.expr_type_counts.values()) + ncm = NodeCountMapper( + traverse_functions=traverse_functions, + map_duplicates=count_duplicates, + map_in_different_functions=count_in_different_functions, + map_dict=count_dict) + return ncm(outputs) # }}} -# {{{ NodeMultiplicityMapper +# {{{ get_node_multiplicities +NodeMultiplicityDict: TypeAlias = dict[ArrayOrNames | FunctionDefinition, int] -class NodeMultiplicityMapper(CachedWalkMapper[[]]): + +# FIXME: optimize_mapper? +class NodeMultiplicityMapper(MapAndReduceMapper[NodeMultiplicityDict]): """ Computes the multiplicity of each unique node in a DAG. - The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s - that equal `x`. - - .. autoattribute:: expr_multiplicity_counts + See :func:`get_node_multiplicities` for details. """ - def __init__(self, _visited_functions: set[Any] | None = None) -> None: - super().__init__(_visited_functions=_visited_functions) - - self.expr_multiplicity_counts: \ - dict[ArrayOrNames | FunctionDefinition, int] = defaultdict(int) - - @override - def get_cache_key(self, expr: ArrayOrNames) -> int: - # Returns each node, including nodes that are duplicates - return id(expr) + def __init__( + self, + traverse_functions: bool = True, + map_duplicates: bool = True, + map_in_different_functions: bool = True, + map_dict: bool = False) -> None: + super().__init__( + map_fn=lambda expr: {expr: 1}, + reduce_fn=lambda *args: reduce_dicts( + lambda *values: sum(values, 0), args), + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + + # FIXME: Remove this once count_dict argument has been eliminated from + # get_node_multiplicities + self.map_dict: bool = map_dict @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: - # Returns each node, including nodes that are duplicates - return id(expr) + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions, + map_dict=self.map_dict) + # FIXME: Remove this once count_dict argument has been eliminated from + # get_node_multiplicities @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not isinstance(expr, DictOfNamedArrays): - self.expr_multiplicity_counts[expr] += 1 + def map_dict_of_named_arrays( + self, + expr: DictOfNamedArrays, + visited_node_keys: set[VisitKeyT] | None) -> NodeMultiplicityDict: + if self.map_dict: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) + else: + return self.reduce_fn( + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) def get_node_multiplicities( - outputs: ArrayOrNames) -> dict[ArrayOrNames | FunctionDefinition, int]: + outputs: ArrayOrNames | FunctionDefinition, + traverse_functions: bool = True, + count_duplicates: bool = True, + count_in_different_functions: bool = True, + count_dict: bool | None = None) -> NodeMultiplicityDict: """ - Returns the multiplicity per `expr`. + Computes the multiplicity of each unique node in a DAG. + + The multiplicity of a node `x` is the number of times an object equal to `x` will + be mapped during a cached DAG traversal. This varies depending on the combination + of options used. + + :param count_duplicates: If *True*, distinct node instances equal to `x` will be + counted. + :param count_in_different_functions: If *True*, instances equal to `x` in + different functions will be counted. """ - nmm = NodeMultiplicityMapper() - nmm(outputs) + # FIXME: Deprecate/remove count_dict argument entirely after default value is + # changed + if count_dict is None: + from warnings import warn + warn( + "The default value of 'count_dict' will change " + "from False to True in Q3 2026. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_dict = False - return nmm.expr_multiplicity_counts + nmm = NodeMultiplicityMapper( + traverse_functions=traverse_functions, + map_duplicates=count_duplicates, + map_in_different_functions=count_in_different_functions, + map_dict=count_dict) + return nmm(outputs) # }}} -# {{{ CallSiteCountMapper - -@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) -class CallSiteCountMapper(CachedWalkMapper[[]]): - """ - Counts the number of :class:`~pytato.Call` nodes in a DAG. +# {{{ get_num_node_instances_of - .. attribute:: count +# FIXME: optimize_mapper? +class NodeInstanceCountMapper(MapAndReduceMapper[int]): + """Count the number of nodes in a DAG that are instances of *node_type*.""" + def __init__( + self, + node_type: + type[ArrayOrNames | FunctionDefinition] + | tuple[type[ArrayOrNames | FunctionDefinition], ...], + traverse_functions: bool = True, + map_duplicates: bool = False, + map_in_different_functions: bool = True) -> None: + super().__init__( + map_fn=lambda expr: int(isinstance(expr, node_type)), + reduce_fn=lambda *args: sum(args, 0), + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + + self.node_type: \ + type[ArrayOrNames | FunctionDefinition] \ + | tuple[type[ArrayOrNames | FunctionDefinition], ...] = node_type - The number of nodes. + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + node_type=self.node_type, + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions) + + +def get_num_node_instances_of( + outputs: ArrayOrNames | FunctionDefinition, + node_type: + type[ArrayOrNames | FunctionDefinition] + | tuple[type[ArrayOrNames | FunctionDefinition], ...], + traverse_functions: bool = True, + count_duplicates: bool = False, + count_in_different_functions: bool = True) -> int: """ + Returns the number of nodes in DAG *outputs* that are instances of *node_type*. + """ + nicm = NodeInstanceCountMapper( + node_type=node_type, + traverse_functions=traverse_functions, + map_duplicates=count_duplicates, + map_in_different_functions=count_in_different_functions) + return nicm(outputs) - def __init__(self, _visited_functions: set[VisitKeyT] | None = None) -> None: - super().__init__(_visited_functions=_visited_functions) - self.count = 0 +# }}} - @override - def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) - @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: - return id(expr) +# {{{ collect_node_instances_of - @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if isinstance(expr, Call): - self.count += 1 +# FIXME: optimize_mapper? +class NodeInstanceCollector(NodeCollector): + """Return the nodes in a DAG that are instances of *node_type*.""" + def __init__( + self, + node_type: + type[ArrayOrNames | FunctionDefinition] + | tuple[type[ArrayOrNames | FunctionDefinition], ...], + traverse_functions: bool = True) -> None: + super().__init__( + collect_fn=lambda expr: isinstance(expr, node_type), + traverse_functions=traverse_functions) + + self.node_type: \ + type[ArrayOrNames | FunctionDefinition] \ + | tuple[type[ArrayOrNames | FunctionDefinition], ...] = node_type @override - def map_function_definition(self, expr: FunctionDefinition) -> None: - if not self.visit(expr): - return + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + node_type=self.node_type, + traverse_functions=self.traverse_functions) + + +def collect_node_instances_of( + outputs: ArrayOrNames | FunctionDefinition, + node_type: + type[ArrayOrNames | FunctionDefinition] + | tuple[type[ArrayOrNames | FunctionDefinition], ...], + traverse_functions: bool = True) -> NodeSet: + """Return the nodes in DAG *outputs* that are instances of *node_type*.""" + nic = NodeInstanceCollector( + node_type=node_type, + traverse_functions=traverse_functions) + return nic(outputs) + +# }}} - new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - new_mapper(subexpr) - self.count += new_mapper.count - self.post_visit(expr) +# {{{ get_num_call_sites +# FIXME: optimize_mapper? +class CallSiteCountMapper(MapAndReduceMapper[int]): + """Count the number of :class:`~pytato.Call` nodes in a DAG.""" + def __init__( + self, + traverse_functions: bool = True, + map_duplicates: bool = False, + map_in_different_functions: bool = True) -> None: + from warnings import warn + warn( + "CallSiteCountMapper is deprecated and will be removed in Q3 2026. " + "Use NodeInstanceCountMapper instead.", + DeprecationWarning, stacklevel=2) -def get_num_call_sites(outputs: ArrayOrNames) -> int: - """Returns the number of nodes in DAG *outputs*.""" - cscm = CallSiteCountMapper() - cscm(outputs) + super().__init__( + map_fn=lambda expr: int(isinstance(expr, Call)), + reduce_fn=lambda *args: sum(args, 0), + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) - return cscm.count + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions) + + +def get_num_call_sites( + outputs: ArrayOrNames | FunctionDefinition, + traverse_functions: bool = True, + count_duplicates: bool = True, + count_in_different_functions: bool = True) -> int: + """Returns the number of :class:`pytato.function.Call` nodes in DAG *outputs*.""" + from warnings import warn + warn( + "get_num_call_sites is deprecated and will be removed in Q3 2026. " + "Use get_num_node_instances_of instead.", + DeprecationWarning, stacklevel=2) + + nicm = NodeInstanceCountMapper( + node_type=Call, + traverse_functions=traverse_functions, + map_duplicates=count_duplicates, + map_in_different_functions=count_in_different_functions) + return nicm(outputs) # }}} -# {{{ TagCountMapper +# {{{ get_num_tags_of_type -class TagCountMapper(CombineMapper[int, Never]): +# FIXME: optimize_mapper? +class TagCountMapper(MapAndReduceMapper[int]): """ - Returns the number of nodes in a DAG that are tagged with all the tag types in + Count the number of nodes in a DAG that are tagged with all the tag types in *tag_types*. """ - def __init__( self, tag_types: type[pytools.tag.Tag] - | Iterable[type[pytools.tag.Tag]]) -> None: - super().__init__() + | Iterable[type[pytools.tag.Tag]], + traverse_functions: bool = False, + map_duplicates: bool = False, + map_in_different_functions: bool = True) -> None: + super().__init__( + map_fn=lambda expr: int( + isinstance(expr, Array) + and ( + self.tag_types + <= frozenset(type(tag) for tag in expr.tags))), + reduce_fn=lambda *args: sum(args, 0), + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + if isinstance(tag_types, type): tag_types = frozenset((tag_types,)) elif not isinstance(tag_types, frozenset): tag_types = frozenset(tag_types) - self._tag_types = tag_types - - def combine(self, *args: int) -> int: - return sum(args) - - def rec(self, expr: ArrayOrNames) -> int: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - s = Mapper.rec(self, expr) - if ( - isinstance(expr, Array) - and ( - self._tag_types - <= frozenset(type(tag) for tag in expr.tags))): - result = 1 + s - else: - result = 0 + s + self.tag_types: frozenset[type[pytools.tag.Tag]] = tag_types - self._cache_add(inputs, 0) - return result + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + tag_types=self.tag_types, + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions) def get_num_tags_of_type( - outputs: ArrayOrNames, - tag_types: type[pytools.tag.Tag] | Iterable[type[pytools.tag.Tag]]) -> int: + outputs: ArrayOrNames | FunctionDefinition, + tag_types: type[pytools.tag.Tag] | Iterable[type[pytools.tag.Tag]], + traverse_functions: bool | None = None, + count_duplicates: bool = False, + count_in_different_functions: bool = True) -> int: """Returns the number of nodes in DAG *outputs* that are tagged with all the tag types in *tag_types*.""" + if traverse_functions is None: + from warnings import warn + warn( + "The default value of 'traverse_functions' will change " + "from False to True in Q3 2026. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + traverse_functions = False - tcm = TagCountMapper(tag_types) - + tcm = TagCountMapper( + tag_types=tag_types, + traverse_functions=traverse_functions, + map_duplicates=count_duplicates, + map_in_different_functions=count_in_different_functions) return tcm(outputs) # }}} +# {{{ collect_materialized_nodes + +# FIXME: optimize_mapper? +class MaterializedNodeCollector(NodeCollector): + """Return the nodes in a DAG that are materialized.""" + def __init__( + self, + traverse_functions: bool = True) -> None: + def collect_fn(expr: ArrayOrNames | FunctionDefinition) -> bool: + # FIXME: This isn't right; need is_materialized() function from + # https://github.com/inducer/pytato/pull/623 + from pytato.tags import ImplStored + return bool(expr.tags_of_type(ImplStored)) + + super().__init__( + collect_fn=collect_fn, + traverse_functions=traverse_functions) + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + traverse_functions=self.traverse_functions) + + +def collect_materialized_nodes( + outputs: ArrayOrNames | FunctionDefinition, + node_type: type[ArrayOrNames | FunctionDefinition], + traverse_functions: bool = True) -> NodeSet: + """Return the nodes in DAG *outputs* that are materialized.""" + mnc = MaterializedNodeCollector( + traverse_functions=traverse_functions) + return mnc(outputs) + +# }}} + + # {{{ PytatoKeyBuilder class PytatoKeyBuilder(LoopyKeyBuilder): diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 0ed13a63f..f5fc0b858 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -118,10 +118,15 @@ .. autoclass:: CachedWalkMapper .. autoclass:: TopoSortMapper .. autoclass:: CachedMapAndCopyMapper +.. autoclass:: MapOnceMapper +.. autoclass:: MapAndReduceMapper +.. autoclass:: NodeCollector .. autofunction:: copy_dict_of_named_arrays .. autofunction:: deduplicate .. autofunction:: get_dependencies .. autofunction:: map_and_copy +.. autofunction:: map_and_reduce +.. autofunction:: collect_nodes .. autofunction:: deduplicate_data_wrappers .. automodule:: pytato.transform.lower_to_index_lambda .. automodule:: pytato.transform.remove_broadcasts_einsum @@ -1303,6 +1308,9 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper +# FIXME: Might be able to replace this with NodeCollector. Would need to be slightly +# careful, as DependencyMapper excludes DictOfNamedArrays and NodeCollector does not +# (unless specified via collect_fn). class DependencyMapper(CombineMapper[R, Never]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of @@ -1381,6 +1389,9 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: # {{{ SubsetDependencyMapper +# FIXME: Might be able to implement this as a NodeCollector. Would need to be slightly +# careful, as DependencyMapper excludes DictOfNamedArrays and NodeCollector does not +# (unless specified via collect_fn). class SubsetDependencyMapper(DependencyMapper): """ Mapper to combine the dependencies of an expression that are a subset of @@ -1889,6 +1900,448 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # }}} +# {{{ MapOnceMapper + +# FIXME: optimize_mapper? +class MapOnceMapper(Mapper[ResultT, FunctionResultT, [set[VisitKeyT] | None]]): + """ + Mapper that maps each node to a result only once per call (excepting duplicates + and different function contexts, depending on configuration). Each time the + mapper visits a node after the first visit, the value *no_result_value* or + *no_function_result_value* is returned instead. This is intended to be used as a + base class for mappers that must avoid double counting of results. + + .. attribute:: no_result_value + + The value to return for nodes that have already been visited. + + .. attribute:: no_function_result_value + + The value to return for function nodes that have already been visited. + + .. attribute:: map_duplicates + + Controls whether duplicate nodes (equal nodes with different `id()`) are to + be mapped separately or treated as the same. + + .. attribute:: map_in_different_functions + + Controls whether nodes in different function call contexts are to be mapped + separately or treated as the same. + """ + def __init__( + self, + no_result_value: ResultT, + no_function_result_value: FunctionResultT, + map_duplicates: bool, + map_in_different_functions: bool, + ) -> None: + super().__init__() + + self.no_result_value: ResultT = no_result_value + self.no_function_result_value: FunctionResultT = no_function_result_value + + self.map_duplicates: bool = map_duplicates + self.map_in_different_functions: bool = map_in_different_functions + + def get_visit_key(self, expr: ArrayOrNames) -> VisitKeyT: + return id(expr) if self.map_duplicates else expr + + def get_function_definition_visit_key(self, expr: FunctionDefinition) -> VisitKeyT: + return id(expr) if self.map_duplicates else expr + + @override + def rec( + self, + expr: ArrayOrNames, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + assert visited_node_keys is not None + + key = self.get_visit_key(expr) + + if key in visited_node_keys: + return self.no_result_value + + # Intentionally going to Mapper instead of super() to avoid + # double visiting when subclasses of MapOnceMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = Mapper.rec(self, expr, visited_node_keys) + + visited_node_keys.add(key) + return result + + @override + def rec_function_definition( + self, + expr: FunctionDefinition, + visited_node_keys: set[VisitKeyT] | None) -> FunctionResultT: + assert visited_node_keys is not None + + key = self.get_function_definition_visit_key(expr) + + if key in visited_node_keys: + return self.no_function_result_value + + # Intentionally going to Mapper instead of super() to avoid + # double visiting when subclasses of MapOnceMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = Mapper.rec_function_definition(self, expr, visited_node_keys) + + visited_node_keys.add(key) + return result + + def rec_idx_or_size_tuple( + self, + situp: tuple[IndexOrShapeExpr, ...], + visited_node_keys: set[VisitKeyT] | None) -> tuple[ResultT, ...]: + return tuple( + self.rec(s, visited_node_keys) + for s in situp if isinstance(s, Array)) + + def get_visited_node_keys_for_callee( + self, visited_node_keys: set[VisitKeyT] | None) -> set[VisitKeyT]: + assert visited_node_keys is not None + return ( + visited_node_keys + if not self.map_in_different_functions + else set()) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + no_result_value=self.no_result_value, + no_function_result_value=self.no_function_result_value, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions) + + @overload + def __call__( + self, + expr: ArrayOrNames, + visited_node_keys: set[VisitKeyT] | None = None, + ) -> ResultT: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + visited_node_keys: set[VisitKeyT] | None = None, + ) -> FunctionResultT: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + visited_node_keys: set[VisitKeyT] | None = None, + ) -> ResultT | FunctionResultT: + if visited_node_keys is None: + visited_node_keys = set() + + return super().__call__(expr, visited_node_keys) + +# }}} + + +# {{{ MapAndReduceMapper + +# FIXME: basedpyright doesn't like this +# @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class MapAndReduceMapper(MapOnceMapper[ResultT, ResultT]): + """ + Mapper that, for each node, calls *map_fn* on that node and then calls + *reduce_fn* on the output together with the already-reduced results from its + predecessors. + + .. note:: + + Each node is mapped/reduced only the first time it is visited (excepting + duplicates and different function contexts, depending on configuration). + On subsequent visits, the returned result is `reduce_fn()`. + + .. attribute:: traverse_functions + + Controls whether or not the mapper should descend into function definitions. + + .. attribute:: map_duplicates + + Controls whether duplicate nodes (equal nodes with different `id()`) are to + be mapped separately or treated as the same. + + .. attribute:: map_in_different_functions + + Controls whether nodes in different function call contexts are to be mapped + separately or treated as the same. + """ + def __init__( + self, + map_fn: Callable[[ArrayOrNames | FunctionDefinition], ResultT], + # FIXME: Is there a way to make argument type annotation more specific? + reduce_fn: Callable[..., ResultT], + traverse_functions: bool = True, + map_duplicates: bool = False, + map_in_different_functions: bool = True, + ) -> None: + super().__init__( + no_result_value=reduce_fn(), + no_function_result_value=reduce_fn(), + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + + self.map_fn: Callable[[ArrayOrNames | FunctionDefinition], ResultT] = map_fn + self.reduce_fn: Callable[..., ResultT] = reduce_fn + + self.traverse_functions: bool = traverse_functions + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + map_fn=self.map_fn, + reduce_fn=self.reduce_fn, + traverse_functions=self.traverse_functions, + map_duplicates=self.map_duplicates, + map_in_different_functions=self.map_in_different_functions) + + def map_index_lambda( + self, + expr: IndexLambda, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *self.rec_idx_or_size_tuple(expr.shape, visited_node_keys), + *( + self.rec(subexpr, visited_node_keys) + for subexpr in expr.bindings.values())) + + def map_placeholder( + self, + expr: Placeholder, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *self.rec_idx_or_size_tuple(expr.shape, visited_node_keys)) + + def map_stack( + self, + expr: Stack, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(arr, visited_node_keys) for arr in expr.arrays)) + + def map_concatenate( + self, + expr: Concatenate, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(arr, visited_node_keys) for arr in expr.arrays)) + + def map_roll( + self, + expr: Roll, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr.array, visited_node_keys)) + + def map_axis_permutation( + self, + expr: AxisPermutation, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr.array, visited_node_keys)) + + def _map_index_base( + self, + expr: IndexBase, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr.array, visited_node_keys), + *self.rec_idx_or_size_tuple(expr.indices, visited_node_keys)) + + def map_basic_index( + self, + expr: BasicIndex, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self._map_index_base(expr, visited_node_keys) + + def map_contiguous_advanced_index( + self, + expr: AdvancedIndexInContiguousAxes, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self._map_index_base(expr, visited_node_keys) + + def map_non_contiguous_advanced_index( + self, + expr: AdvancedIndexInNoncontiguousAxes, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self._map_index_base(expr, visited_node_keys) + + def map_data_wrapper( + self, + expr: DataWrapper, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *self.rec_idx_or_size_tuple(expr.shape, visited_node_keys)) + + def map_size_param( + self, + expr: SizeParam, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn(self.map_fn(expr)) + + def map_einsum( + self, + expr: Einsum, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(arg, visited_node_keys) for arg in expr.args)) + + def map_named_array( + self, + expr: NamedArray, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr._container, visited_node_keys)) + + def map_dict_of_named_arrays( + self, + expr: DictOfNamedArrays, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *(self.rec(val.expr, visited_node_keys) for val in expr.values())) + + def map_loopy_call( + self, + expr: LoopyCall, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *( + self.rec(subexpr, visited_node_keys) + for subexpr in expr.bindings.values() + if isinstance(subexpr, Array))) + + def map_loopy_call_result( + self, + expr: LoopyCallResult, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr._container, visited_node_keys)) + + def map_reshape( + self, + expr: Reshape, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr.array, visited_node_keys), + *self.rec_idx_or_size_tuple(expr.newshape, visited_node_keys)) + + def map_distributed_send_ref_holder( + self, + expr: DistributedSendRefHolder, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr.send.data, visited_node_keys), + self.rec(expr.passthrough_data, visited_node_keys)) + + def map_distributed_recv( + self, + expr: DistributedRecv, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + *self.rec_idx_or_size_tuple(expr.shape, visited_node_keys)) + + def map_function_definition( + self, + expr: FunctionDefinition, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + if not self.traverse_functions: + return self.no_function_result_value + + new_mapper = self.clone_for_callee(expr) + new_visited_node_keys = self.get_visited_node_keys_for_callee( + visited_node_keys) + return self.reduce_fn( + self.map_fn(expr), + *( + new_mapper(ret, new_visited_node_keys) + for ret in expr.returns.values())) + + def map_call( + self, + expr: Call, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec_function_definition(expr.function, visited_node_keys), + *(self.rec(bnd, visited_node_keys) for bnd in expr.bindings.values())) + + def map_named_call_result( + self, + expr: NamedCallResult, + visited_node_keys: set[VisitKeyT] | None) -> ResultT: + return self.reduce_fn( + self.map_fn(expr), + self.rec(expr._container, visited_node_keys)) + +# }}} + + +# {{{ NodeCollector + +NodeSet: TypeAlias = frozenset[ArrayOrNames | FunctionDefinition] + + +# FIXME: optimize_mapper? +class NodeCollector(MapAndReduceMapper[NodeSet]): + """ + Return the set of nodes in a DAG that match a condition specified via + *collect_fn*. + + .. attribute:: traverse_functions + + Controls whether or not the mapper should descend into function definitions. + """ + def __init__( + self, + collect_fn: Callable[[ArrayOrNames | FunctionDefinition], bool], + traverse_functions: bool = True) -> None: + from functools import reduce + super().__init__( + map_fn=lambda expr: ( + frozenset([expr]) if collect_fn(expr) else frozenset()), + reduce_fn=lambda *args: reduce( + cast("Callable[[NodeSet, NodeSet], NodeSet]", frozenset.union), + args, + cast("NodeSet", frozenset())), + traverse_functions=traverse_functions, + map_duplicates=False, + map_in_different_functions=False) + + self.collect_fn: Callable[[ArrayOrNames | FunctionDefinition], bool] = \ + collect_fn + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + collect_fn=self.collect_fn, + traverse_functions=self.traverse_functions) + +# }}} + + # {{{ mapper frontends def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, @@ -1933,6 +2386,46 @@ def map_and_copy(expr: ArrayOrNamesTc, return CachedMapAndCopyMapper(map_fn)(expr) +def map_and_reduce( + expr: ArrayOrNames | FunctionDefinition, + map_fn: Callable[[ArrayOrNames | FunctionDefinition], ResultT], + # FIXME: Is there a way to make argument type annotation more specific? + reduce_fn: Callable[..., ResultT], + traverse_functions: bool = True, + map_duplicates: bool = False, + map_in_different_functions: bool = True) -> ResultT: + """ + For each node, call *map_fn* on that node and then call *reduce_fn* on the output + together with the already-reduced results from its predecessors. + + See :class:`MapAndReduceMapper` for more details. + """ + mrm = MapAndReduceMapper( + map_fn=map_fn, + reduce_fn=reduce_fn, + traverse_functions=traverse_functions, + map_duplicates=map_duplicates, + map_in_different_functions=map_in_different_functions) + return mrm(expr) + + +def collect_nodes( + expr: ArrayOrNames | FunctionDefinition, + collect_fn: Callable[[ArrayOrNames | FunctionDefinition], bool], + traverse_functions: bool = True + ) -> frozenset[ArrayOrNames | FunctionDefinition]: + """ + Return the set of nodes in a DAG that match a condition specified via + *collect_fn*. + + See :class:`NodeCollector` for more details. + """ + nc = NodeCollector( + collect_fn=collect_fn, + traverse_functions=traverse_functions) + return nc(expr) + + def materialize_with_mpms(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: from warnings import warn warn( diff --git a/test/test_distributed.py b/test/test_distributed.py index 06733c927..31b515da2 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -668,6 +668,7 @@ def _gather_random_dist_partitions(ctx_factory: cl.CtxFactory): dag = get_random_pt_dag_with_send_recv_nodes( seed, rank=comm.rank, size=comm.size, convert_dws_to_placeholders=True) + dag = pt.make_dict_of_named_arrays({"result": dag}) my_partition = pt.find_distributed_partition(comm, dag) diff --git a/test/test_pytato.py b/test/test_pytato.py index 8746810ae..b22482c76 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -34,6 +34,7 @@ import numpy as np import pytest from testlib import RandomDAGContext, make_random_dag +from typing_extensions import override from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests, @@ -245,7 +246,7 @@ def test_accessing_dict_of_named_arrays_validation(): x = pt.make_placeholder(name="x", shape=10) y1y2 = pt.make_dict_of_named_arrays({"y1": 2*x, "y2": 3*x}) - assert isinstance(y1y2["y1"], pt.array.NamedArray) + assert isinstance(y1y2["y1"], pt.NamedArray) assert y1y2["y1"].shape == (2*x).shape assert y1y2["y1"].dtype == (2*x).dtype @@ -603,33 +604,20 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) # noqa: PLR0124 -def test_empty_dag_count(): - from pytato.analysis import get_node_type_counts, get_num_nodes - - empty_dag = pt.make_dict_of_named_arrays({}) - - # Verify that get_num_nodes returns 0 for an empty DAG - assert get_num_nodes(empty_dag, count_duplicates=False) == 0 - - counts = get_node_type_counts(empty_dag) - assert len(counts) == 0 - - def test_single_node_dag_count(): from pytato.analysis import get_node_type_counts, get_num_nodes data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays( - {"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_data_wrapper(data) # Get counts per node type - node_counts = get_node_type_counts(single_node_dag) + node_counts = get_node_type_counts(single_node_dag, count_dict=True) # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} # Get total number of nodes - total_nodes = get_num_nodes(single_node_dag, count_duplicates=False) + total_nodes = get_num_nodes(single_node_dag, count_dict=True) assert total_nodes == 1 @@ -637,18 +625,19 @@ def test_single_node_dag_count(): def test_small_dag_count(): from pytato.analysis import get_node_type_counts, get_num_nodes - # Make a DAG using two nodes and one operation + # Make a DAG using two placeholders, one operation, and a dict of named arrays a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) - b = a + 1 - dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + b = pt.make_placeholder(name="b", shape=(2, 2), dtype=np.float64) + dag = pt.make_dict_of_named_arrays({"res0": a + 1, "res1": b}) - # Verify that get_num_nodes returns 2 for a DAG with two nodes - assert get_num_nodes(dag, count_duplicates=False) == 2 + # Verify that get_num_nodes returns 4 for a DAG with four nodes + assert get_num_nodes(dag, count_dict=True) == 4 - counts = get_node_type_counts(dag) - assert len(counts) == 2 - assert counts[pt.array.Placeholder] == 1 # "a" - assert counts[pt.array.IndexLambda] == 1 # single operation + counts = get_node_type_counts(dag, count_dict=True) + assert len(counts) == 3 + assert counts[pt.Placeholder] == 2 # "a" / "b" + assert counts[pt.IndexLambda] == 1 # single operation + assert counts[pt.DictOfNamedArrays] == 1 # dict def test_large_dag_count(): @@ -660,15 +649,59 @@ def test_large_dag_count(): dag = make_large_dag(iterations, seed=42) # Verify that the number of nodes is equal to iterations + 1 (placeholder) - assert get_num_nodes(dag, count_duplicates=False) == iterations + 1 + assert get_num_nodes(dag, count_dict=True) == iterations + 1 - counts = get_node_type_counts(dag) + counts = get_node_type_counts(dag, count_dict=True) assert len(counts) >= 1 - assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert counts[pt.Placeholder] == 1 + assert counts[pt.IndexLambda] == 100 # 100 operations assert sum(counts.values()) == iterations + 1 +def test_dag_with_function_count(): + from pytato.analysis import get_node_type_counts, get_num_nodes + + def f(x): + return x, 2*x + + x = pt.make_placeholder(name="x", shape=(2, 2), dtype=np.float64) + + call = pt.trace_call(f, x) + + dag = pt.transform.deduplicate( + pt.make_dict_of_named_arrays({ + "result0": call[0], + "result1": call[1]})) + + # placeholder x 2 (one for input x, one for f parameter x), 2*x, function, call, + # named call result for call[0], named call result for call[1], dict of named + # arrays + + assert get_num_nodes(dag, count_dict=True) == 8 + + counts = get_node_type_counts(dag, count_dict=True) + assert len(counts) == 6 + assert counts[pt.Placeholder] == 2 + assert counts[pt.IndexLambda] == 1 + assert counts[pt.FunctionDefinition] == 1 + assert counts[pt.Call] == 1 + assert counts[pt.NamedCallResult] == 2 + assert counts[pt.DictOfNamedArrays] == 1 + + +class SelfCachingCachedWalkMapper(pt.transform.CachedWalkMapper[[]]): + @override + def get_cache_key(self, expr: pt.transform.ArrayOrNames) -> pt.transform.VisitKeyT: + return expr + + +def get_num_nodes_alt(expr: pt.transform.ArrayOrNames) -> int: + mapper = SelfCachingCachedWalkMapper() + mapper(expr) + + return len(mapper._visited_arrays_or_names) + + def test_random_dag_count(): from testlib import get_random_pt_dag @@ -676,8 +709,7 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - assert get_num_nodes(dag, count_duplicates=False) == len( - pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag, count_dict=True) == get_num_nodes_alt(dag) def test_random_dag_with_comm_count(): @@ -690,8 +722,7 @@ def test_random_dag_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes( seed=i, rank=rank, size=size) - assert get_num_nodes(dag, count_duplicates=False) == len( - pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag, count_dict=True) == get_num_nodes_alt(dag) def test_small_dag_with_duplicates_count(): @@ -706,31 +737,30 @@ def test_small_dag_with_duplicates_count(): dag = make_small_dag_with_duplicates() # Get the number of expressions, including duplicates - node_count = get_num_nodes(dag, count_duplicates=True) + node_count = get_num_nodes(dag, count_duplicates=True, count_dict=True) expected_node_count = 4 assert node_count == expected_node_count # Get the number of occurrences of each unique expression - node_multiplicity = get_node_multiplicities(dag) + node_multiplicity = get_node_multiplicities( + dag, count_duplicates=True, count_dict=True) assert any(count > 1 for count in node_multiplicity.values()) # Get difference in duplicates num_duplicates = sum(count - 1 for count in node_multiplicity.values()) - counts = get_node_type_counts(dag, count_duplicates=True) + counts = get_node_type_counts(dag, count_duplicates=True, count_dict=True) expected_counts = { - pt.array.Placeholder: 1, - pt.array.IndexLambda: 3 + pt.Placeholder: 1, + pt.IndexLambda: 3 } for node_type, expected_count in expected_counts.items(): assert counts[node_type] == expected_count # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len( - pt.transform.DependencyMapper(err_on_collision=False)(dag)) - assert node_count - num_duplicates == get_num_nodes( - dag, count_duplicates=False) + assert node_count - num_duplicates == get_num_nodes_alt(dag) + assert node_count - num_duplicates == get_num_nodes(dag, count_dict=True) def test_large_dag_with_duplicates_count(): @@ -746,10 +776,10 @@ def test_large_dag_with_duplicates_count(): dag = make_large_dag_with_duplicates(iterations, seed=42) # Get the number of expressions, including duplicates - node_count = get_num_nodes(dag, count_duplicates=True) + node_count = get_num_nodes(dag, count_duplicates=True, count_dict=True) # Get the number of occurrences of each unique expression - node_multiplicity = get_node_multiplicities(dag) + node_multiplicity = get_node_multiplicities(dag, count_dict=True) assert any(count > 1 for count in node_multiplicity.values()) expected_node_count = sum(count for count in node_multiplicity.values()) @@ -758,16 +788,215 @@ def test_large_dag_with_duplicates_count(): # Get difference in duplicates num_duplicates = sum(count - 1 for count in node_multiplicity.values()) - counts = get_node_type_counts(dag, count_duplicates=True) + counts = get_node_type_counts(dag, count_duplicates=True, count_dict=True) - assert counts[pt.array.Placeholder] == 1 + assert counts[pt.Placeholder] == 1 assert sum(counts.values()) == expected_node_count # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len( - pt.transform.DependencyMapper(err_on_collision=False)(dag)) - assert node_count - num_duplicates == get_num_nodes( - dag, count_duplicates=False) + assert node_count - num_duplicates == get_num_nodes_alt(dag) + assert node_count - num_duplicates == get_num_nodes(dag, count_dict=True) + + +def test_dag_with_duplicates_and_function_count(): + from pytato.analysis import ( + get_node_multiplicities, + get_node_type_counts, + get_num_nodes, + ) + + zeros = pt.zeros(shape=(2, 2), dtype=np.float64) + + def f(): + return 2*zeros, 2*zeros + + call = pt.trace_call(f) + + dag = pt.make_dict_of_named_arrays({ + "result0": call[0], + "result1": call[1], + "result2": 2*zeros}) + + # {{{ counting with duplicates + + # zeros x 2 (once inside f, once outside), 2*zeros x 3 (twice inside f, once + # outside), function, call, named call result for call[0], named call result for + # call[1], dict of named arrays + + assert get_num_nodes( + dag, count_duplicates=True, count_in_different_functions=True, + count_dict=True) == 10 + + counts = get_node_type_counts( + dag, count_duplicates=True, count_in_different_functions=True, + count_dict=True) + assert len(counts) == 5 + assert counts[pt.IndexLambda] == 5 + assert counts[pt.FunctionDefinition] == 1 + assert counts[pt.Call] == 1 + assert counts[pt.NamedCallResult] == 2 + assert counts[pt.DictOfNamedArrays] == 1 + + multiplicities = get_node_multiplicities( + dag, traverse_functions=True, count_in_different_functions=True, + count_dict=True) + assert len(multiplicities) == 7 + assert sum(count for count in multiplicities.values()) == 10 + assert multiplicities[zeros] == 2 + assert multiplicities[2*zeros] == 3 + + multiplicities = get_node_multiplicities( + dag, traverse_functions=False, count_in_different_functions=True, + count_dict=True) + assert len(multiplicities) == 6 + assert sum(count for count in multiplicities.values()) == 6 + assert multiplicities[zeros] == 1 + assert multiplicities[2*zeros] == 1 + + # }}} + + # {{{ counting without duplicates + + # zeros x 2 (once inside f, once outside), 2*zeros x 2 (once inside f, once + # outside), function, call, named call result for call[0], named call result for + # call[1], dict of named arrays + + assert get_num_nodes( + dag, count_in_different_functions=True, count_dict=True) == 9 + + counts = get_node_type_counts( + dag, count_in_different_functions=True, count_dict=True) + assert len(counts) == 5 + assert counts[pt.IndexLambda] == 4 + assert counts[pt.FunctionDefinition] == 1 + assert counts[pt.Call] == 1 + assert counts[pt.NamedCallResult] == 2 + assert counts[pt.DictOfNamedArrays] == 1 + + multiplicities = get_node_multiplicities( + dag, traverse_functions=True, count_duplicates=False, + count_in_different_functions=True, count_dict=True) + assert len(multiplicities) == 7 + assert sum(count for count in multiplicities.values()) == 9 + assert multiplicities[zeros] == 2 + assert multiplicities[2*zeros] == 2 + + multiplicities = get_node_multiplicities( + dag, traverse_functions=False, count_duplicates=False, + count_in_different_functions=True, count_dict=True) + assert len(multiplicities) == 6 + assert sum(count for count in multiplicities.values()) == 6 + assert multiplicities[zeros] == 1 + assert multiplicities[2*zeros] == 1 + + # }}} + + +def test_get_num_node_instances_of(): + from pytato.analysis import get_num_node_instances_of + + indices = pt.make_data_wrapper(np.array([1, 0], dtype=np.int32)) + + x = pt.make_placeholder(name="x", shape=(2, 2), dtype=np.float64) + y = x[:, indices] + z = y[0] + + def f(a): + return a[:, indices] + + u = pt.trace_call(f, x) + v = u[0] + + dag = z + v + + # indices (once outside f, once inside), x, y, z, a, a[:, indices], f, call(f), + # u, v, z + v + assert get_num_node_instances_of( + dag, + ( + pt.Array, + pt.AbstractResultWithNamedArrays, + pt.FunctionDefinition)) == 12 + + # Same as above, but with indices only counted once + assert get_num_node_instances_of( + dag, + ( + pt.Array, + pt.AbstractResultWithNamedArrays, + pt.FunctionDefinition), + count_in_different_functions=False) == 11 + + # Everything but f and call(f) + assert get_num_node_instances_of(dag, pt.Array) == 10 + + # y, z, a[:, indices], v + assert get_num_node_instances_of(dag, pt.IndexBase) == 4 + + # z, v + assert get_num_node_instances_of(dag, pt.BasicIndex) == 2 + + # y, a[:, indices] + assert get_num_node_instances_of(dag, pt.AdvancedIndexInContiguousAxes) == 2 + + +def test_collect_node_instances_of(): + from pytato.analysis import collect_node_instances_of + + indices = pt.make_data_wrapper(np.array([1, 0], dtype=np.int32)) + + x = pt.make_placeholder(name="x", shape=(2, 2), dtype=np.float64) + y = x[:, indices] + z = y[0] + + # Save a and b so we can reference them below + f_nodes = [] + + def f(a): + b = a[:, indices] + assert not f_nodes + f_nodes.extend([a, b]) + return b + + u = pt.trace_call(f, x) + v = u[0] + + dag = z + v + + a = f_nodes[0] + b = f_nodes[1] + + assert ( + collect_node_instances_of( + dag, + ( + pt.Array, + pt.AbstractResultWithNamedArrays, + pt.FunctionDefinition)) + == frozenset([ + indices, x, y, z, a, b, u._container.function, u._container, u, v, dag])) + + assert ( + collect_node_instances_of(dag, pt.Array) + == frozenset([indices, x, y, z, a, b, u, v, dag])) + + assert ( + collect_node_instances_of(dag, pt.IndexBase) + == frozenset([y, z, b, v])) + + assert ( + collect_node_instances_of(dag, pt.BasicIndex) + == frozenset([z, v])) + + assert ( + collect_node_instances_of(dag, pt.AdvancedIndexInContiguousAxes) + == frozenset([y, b])) + + +def test_collect_materialized_nodes(): + # FIXME: Add tests after fixing collect_materialized_nodes to use + # is_materialized() function from https://github.com/inducer/pytato/pull/623 + pytest.fail("Not implemented yet.") def test_rec_get_user_nodes(): @@ -1252,12 +1481,11 @@ class ExistentTag(Tag): rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), axis_len=axis_len, use_numpy=False) - out = make_random_dag(rdagc_pt).tagged(ExistentTag()) - - dag = pt.transform.deduplicate(pt.make_dict_of_named_arrays({"out": out})) + dag = make_random_dag(rdagc_pt).tagged(ExistentTag()) - # get_num_nodes() returns an extra DictOfNamedArrays node - assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) + assert ( + get_num_tags_of_type(dag, frozenset()) + == get_num_nodes(dag, count_dict=True)) assert get_num_tags_of_type(dag, NonExistentTag) == 0 assert get_num_tags_of_type(dag, frozenset((ExistentTag,))) == 1 @@ -1267,7 +1495,9 @@ class ExistentTag(Tag): a = pt.make_data_wrapper(np.arange(27)) dag = a+a+a+a+a+a+a+a - assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) + assert ( + get_num_tags_of_type(dag, frozenset()) + == get_num_nodes(dag, count_dict=True)) def test_expand_dims_input_validate(): diff --git a/test/testlib.py b/test/testlib.py index 0b587f400..c58d62735 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -10,7 +10,7 @@ from pytools.tag import Tag import pytato as pt -from pytato.transform import Mapper +from pytato.transform import ArrayOrNames, Mapper if TYPE_CHECKING: @@ -299,7 +299,7 @@ def get_random_pt_dag(seed: int, axis_len: int = 4, convert_dws_to_placeholders: bool = False, allow_duplicate_nodes: bool = False - ) -> pt.DictOfNamedArrays: + ) -> ArrayOrNames: if additional_generators is None: additional_generators = [] @@ -309,14 +309,13 @@ def get_random_pt_dag(seed: int, axis_len=axis_len, use_numpy=False, allow_duplicate_nodes=allow_duplicate_nodes, additional_generators=additional_generators) - dag = pt.make_dict_of_named_arrays({"result": make_random_dag(rdagc_comm)}) + dag = make_random_dag(rdagc_comm) if convert_dws_to_placeholders: from pytools import UniqueNameGenerator vng = UniqueNameGenerator() - def make_dws_placeholder(expr: pt.transform.ArrayOrNames - ) -> pt.transform.ArrayOrNames: + def make_dws_placeholder(expr: ArrayOrNames) -> ArrayOrNames: if isinstance(expr, pt.DataWrapper): return pt.make_placeholder(vng("_pt_ph"), expr.shape, expr.dtype) @@ -336,7 +335,7 @@ def get_random_pt_dag_with_send_recv_nodes( comm_fake_probability: int = 500, axis_len: int = 4, convert_dws_to_placeholders: bool = False - ) -> pt.DictOfNamedArrays: + ) -> ArrayOrNames: comm_tag = 17 def gen_comm(rdagc: RandomDAGContext) -> pt.Array: @@ -356,7 +355,7 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: additional_generators=[(comm_fake_probability, gen_comm)]) -def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: +def make_large_dag(iterations: int, seed: int = 0) -> ArrayOrNames: """ Builds a DAG with emphasis on number of operations. """ @@ -376,23 +375,20 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: current = operation(current, value) # DAG should have `iterations` number of operations - return pt.make_dict_of_named_arrays({"result": current}) + return current -def make_small_dag_with_duplicates() -> pt.DictOfNamedArrays: +def make_small_dag_with_duplicates() -> ArrayOrNames: x = pt.make_placeholder(name="x", shape=(2, 2), dtype=np.float64) expr1 = 2 * x expr2 = 2 * x - y = expr1 + expr2 - # Has duplicates of the 2*x operation - return pt.make_dict_of_named_arrays({"result": y}) + return expr1 + expr2 -def make_large_dag_with_duplicates(iterations: int, - seed: int = 0) -> pt.DictOfNamedArrays: +def make_large_dag_with_duplicates(iterations: int, seed: int = 0) -> ArrayOrNames: random.seed(seed) rng = np.random.default_rng(seed) @@ -419,8 +415,7 @@ def make_large_dag_with_duplicates(iterations: int, all_exprs = [current, *duplicates] combined_expr = pt.stack(all_exprs, axis=0) - result = pt.sum(combined_expr, axis=0) - return pt.make_dict_of_named_arrays({"result": result}) + return pt.sum(combined_expr, axis=0) # }}}