diff --git a/doc/array.rst b/doc/array.rst index a31fb2e53..56772affa 100644 --- a/doc/array.rst +++ b/doc/array.rst @@ -5,12 +5,17 @@ Array Expressions .. automodule:: pytato.array +Functions in Pytato IR +---------------------- + +.. automodule:: pytato.function Raising :class:`~pytato.array.IndexLambda` nodes ------------------------------------------------ .. automodule:: pytato.raising + Calling :mod:`loopy` kernels in an array expression --------------------------------------------------- diff --git a/pytato/__init__.py b/pytato/__init__.py index afb08653c..6536a9fa3 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -91,8 +91,10 @@ def set_debug_enabled(flag: bool) -> None: from pytato.visualization import (get_dot_graph, show_dot_graph, get_ascii_graph, show_ascii_graph, get_dot_graph_from_partition) +from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls import pytato.analysis as analysis import pytato.tags as tags +import pytato.function as function import pytato.transform as transform from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv, DistributedRecv, DistributedSend, @@ -108,6 +110,8 @@ def set_debug_enabled(flag: bool) -> None: from pytato.transform.remove_broadcasts_einsum import ( rewrite_einsums_with_no_broadcasts) from pytato.transform.metadata import unify_axes_tags +from pytato.function import trace_call +from pytato.transform.calls import concatenate_calls from pytato.partition import generate_code_for_partition @@ -154,11 +158,17 @@ def set_debug_enabled(flag: bool) -> None: "broadcast_to", "pad", + "trace_call", + + "concatenate_calls", + "make_distributed_recv", "make_distributed_send", "DistributedRecv", "DistributedSend", "staple_distributed_send", "DistributedSendRefHolder", "DistributedGraphPart", "DistributedGraphPartition", + "tag_all_calls_to_be_inlined", "inline_calls", + "find_distributed_partition", "number_distributed_tags", @@ -174,6 +184,6 @@ def 
set_debug_enabled(flag: bool) -> None: "unify_axes_tags", # sub-modules - "analysis", "tags", "transform", + "analysis", "tags", "transform", "function", ) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a3c5e99ce..31f6c788e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -31,9 +31,11 @@ DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, ShapeType) +from pytato.function import FunctionDefinition, Call from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -47,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_num_call_sites + .. autoclass:: DirectPredecessorsGetter """ @@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: return ncm.count # }}} + + +# {{{ CallSiteCountMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class CallSiteCountMapper(CachedWalkMapper): + """ + Counts the number of nodes in a DAG. + + .. attribute:: count + + The number of nodes. + """ + + def __init__(self) -> None: + super().__init__() + self.count = 0 + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + @memoize_method + def map_function_definition(self, /, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee() + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.count += new_mapper.count + + self.post_visit(expr, *args, **kwargs) + + # type-ignore-reason: dropped the extra `*args, **kwargs`. 
+ def post_visit(self, expr: Any) -> None: # type: ignore[override] + if isinstance(expr, Call): + self.count += 1 + + +def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Returns the number of nodes in DAG *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + cscm = CallSiteCountMapper() + cscm(outputs) + + return cscm.count + +# }}} diff --git a/pytato/codegen.py b/pytato/codegen.py index 63067fbd9..0bc85d649 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -23,7 +23,7 @@ """ import dataclasses -from typing import Union, Dict, Tuple, List, Any +from typing import Union, Dict, Tuple, List, Any, Optional from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder, DataInterface, SizeParam, InputArgumentBase, @@ -102,12 +102,17 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] ====================================== ===================================== """ - def __init__(self, target: Target) -> None: + def __init__(self, target: Target, + kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None + ) -> None: super().__init__() self.bound_arguments: Dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target - self.kernels_seen: Dict[str, lp.LoopKernel] = {} + self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {} + + def clone_for_callee(self) -> CodeGenPreprocessor: + return CodeGenPreprocessor(self.target, self.kernels_seen) def map_size_param(self, expr: SizeParam) -> Array: name = expr.name @@ -262,6 +267,7 @@ class PreprocessResult: def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: """Preprocess a computation for code generation.""" from pytato.transform import copy_dict_of_named_arrays + from pytato.transform.calls import inline_calls check_validity_of_outputs(outputs) @@ -289,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, 
target: Target) -> PreprocessResult: # }}} - mapper = CodeGenPreprocessor(target) + new_outputs = inline_calls(outputs) + assert isinstance(new_outputs, DictOfNamedArrays) - new_outputs = copy_dict_of_named_arrays(outputs, mapper) + mapper = CodeGenPreprocessor(target) + new_outputs = copy_dict_of_named_arrays(new_outputs, mapper) return PreprocessResult(outputs=new_outputs, - compute_order=tuple(output_order), - bound_arguments=mapper.bound_arguments) + compute_order=tuple(output_order), + bound_arguments=mapper.bound_arguments) # vim: fdm=marker diff --git a/pytato/equality.py b/pytato/equality.py index 5eb1814c6..42c2978cd 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -31,6 +31,8 @@ IndexBase, IndexLambda, NamedArray, Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) +from pytato.function import Call, NamedCallResult, FunctionDefinition +from pytools import memoize_method if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -273,6 +275,31 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.tags == expr2.tags ) + @memoize_method + def map_function_definition(self, expr1: FunctionDefinition, expr2: Any + ) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.parameters == expr2.parameters + and (set(expr1.returns.keys()) == set(expr2.returns.keys())) + and all(self.rec(expr1.returns[k], expr2.returns[k]) + for k in expr1.returns) + and expr1.tags == expr2.tags + ) + + def map_call(self, expr1: Call, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.map_function_definition(expr1.function, expr2.function) + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + for name, bnd in expr1.bindings.items()) + ) + + def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == 
expr2.name + and self.rec(expr1._container, expr2._container)) + # }}} # vim: fdm=marker diff --git a/pytato/function.py b/pytato/function.py new file mode 100644 index 000000000..9f3733b03 --- /dev/null +++ b/pytato/function.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +__doc__ = """ +.. currentmodule:: pytato + +.. autofunction:: trace_call + +.. currentmodule:: pytato.function + +.. autoclass:: Call +.. autoclass:: NamedCallResult +.. autoclass:: FunctionDefinition +.. autoclass:: ReturnType + +.. class:: ReturnT + + A type variable corresponding to the return type of the function + :func:`pytato.trace_call`. +""" + +__copyright__ = """ +Copyright (C) 2022 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import attrs +import re +import enum + +from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional, + Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping) +from immutables import Map +from functools import cached_property +from pytato.array import (Array, AbstractResultWithNamedArrays, AxesT, + Placeholder, NamedArray, ShapeType, _dtype_any, + InputArgumentBase) +from pytools import memoize_method +from pytools.tag import Tag, Taggable + +ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array]) + + +# {{{ Call/NamedCallResult + + +@enum.unique +class ReturnType(enum.Enum): + """ + Records the function body's return type in :class:`FunctionDefinition`. + """ + ARRAY = 0 + DICT_OF_ARRAYS = 1 + TUPLE_OF_ARRAYS = 2 + + +@attrs.define(frozen=True, repr=False, eq=False) +class FunctionDefinition(Taggable): + r""" + A function definition that represents its outputs as instances of + :class:`~pytato.Array` with the inputs being + :class:`~pytato.array.Placeholder`\ s. The outputs of the function + can be a single :class:`pytato.Array`, a tuple of :class:`pytato.Array`\ s or an + instance of ``Dict[str, Array]``. + + .. attribute:: parameters + + Names of the input :class:`~pytato.array.Placeholder`\ s to the + function node. This is a superset of the names of + :class:`~pytato.array.Placeholder` instances encountered in + :attr:`returns`. Unused parameters are allowed. + + .. attribute:: return_type + + An instance of :class:`ReturnType`. + + .. attribute:: returns + + The outputs of the function call which are array expressions that + depend on the *parameters*. The keys of the mapping depend on + :attr:`return_type` as: + + - If the function returns a single :class:`pytato.Array`, then + *returns* contains a single array expression with ``"_"`` as the + key. 
+ - If the function returns a :class:`tuple` of + :class:`pytato.Array`\ s, then *returns* contains entries with + the key ``"_N"`` mapping the ``N``-th entry of the result-tuple. + - If the function returns a :class:`dict` mapping identifiers to + :class:`pytato.Array`\ s, then *returns* uses the same mapping. + + .. automethod:: get_placeholder + + .. note:: + + A :class:`FunctionDefinition` comes with its own namespace based on + :attr:`parameters`. A :class:`~pytato.transform.Mapper`-implementer + must ensure **not** to reuse the cached result between the caller's + expressions and a function definition's expressions to avoid unsound + cache hits that could lead to incorrect mappings. + + """ + parameters: FrozenSet[str] + return_type: ReturnType + returns: Map[str, Array] + tags: FrozenSet[Tag] = attrs.field(kw_only=True) + + def __hash__(self) -> int: + return hash((self.tags, self.parameters, self.returns)) + + @cached_property + def _placeholders(self) -> Mapping[str, Placeholder]: + from pytato.transform import InputGatherer + from functools import reduce + + mapper = InputGatherer() + + all_input_args: FrozenSet[InputArgumentBase] = reduce( + frozenset.union, + (mapper(ary) for ary in self.returns.values()), + frozenset() + ) + + return Map({input_arg.name: input_arg + for input_arg in all_input_args + if isinstance(input_arg, Placeholder)}) + + def get_placeholder(self, name: str) -> Placeholder: + """ + Returns the instance of :class:`pytato.array.Placeholder` corresponding + to the parameter *name* in function body. 
+ """ + return self._placeholders[name] + + def __call__(self, **kwargs: Array + ) -> Union[Array, + Tuple[Array, ...], + Dict[str, Array]]: + from pytato.array import _get_default_axes, _get_default_tags + from pytato.utils import are_shapes_equal + + # {{{ sanity checks + + if self.parameters > frozenset(kwargs): + missing_params = self.parameters - frozenset(kwargs) + raise ValueError(f"Missing arguments: '{missing_params}'.") + + if frozenset(kwargs) > self.parameters: + extra_params = self.parameters - frozenset(kwargs) + raise ValueError(f"Unexpected arguments: '{extra_params}'.") + + for argname, expected_arg in self._placeholders.items(): + if expected_arg.dtype != kwargs[argname].dtype: + raise ValueError(f"Argument '{argname}' expected to " + f" be of type '{expected_arg.dtype}', got" + f" '{kwargs[argname].dtype}'.") + if not are_shapes_equal(expected_arg.shape, kwargs[argname].shape): + raise ValueError(f"Argument '{argname}' expected to " + f" have shape '{expected_arg.shape}', got" + f" '{kwargs[argname].shape}'.") + + # }}} + + call_site = Call( + self, + bindings=Map(kwargs), + result_tags=Map({name: _get_default_tags() + for name in self.returns}), + result_axes=Map({name: _get_default_axes(ret.ndim) + for name, ret in self.returns.items()}), + tags=_get_default_tags() + ) + + if self.return_type == ReturnType.ARRAY: + return call_site["_"] + elif self.return_type == ReturnType.TUPLE_OF_ARRAYS: + return tuple(call_site[f"_{iarg}"] + for iarg in range(len(self.returns))) + elif self.return_type == ReturnType.DICT_OF_ARRAYS: + return {kw: call_site[kw] for kw in self.returns} + else: + raise NotImplementedError(self.return_type) + + +class NamedCallResult(NamedArray): + """ + One of the arrays that are returned from a call to :class:`FunctionDefinition`. + + .. attribute:: traced_call + + The function invocation that led to *self*. + + .. attribute:: name + + The name by which the returned array is referred to in + :attr:`FunctionDefinition.returns`. 
+ """ + _mapper_method: ClassVar[str] = "map_named_call_result" + + def __init__(self, + traced_call: "Call", + name: str) -> None: + super().__init__(traced_call, name, + axes=traced_call.function.returns[name].axes, + tags=traced_call.function.returns[name].tags) + + def with_tagged_axis(self, iaxis: int, + tags: Union[Sequence[Tag], Tag]) -> Array: + raise ValueError("Tagging a NamedCallResult's axis is illegal, use" + " Call.with_tagged_axis instead") + + def tagged(self, + tags: Union[Iterable[Tag], Tag, None]) -> NamedCallResult: + raise ValueError("Tagging a NamedCallResult is illegal, use" + " Call.tagged instead") + + def without_tags(self, + tags: Union[Iterable[Tag], Tag, None], + verify_existence: bool = True, + ) -> NamedCallResult: + raise ValueError("Untagging a NamedCallResult is illegal, use" + " Call.without_tags instead") + + @property + def shape(self) -> ShapeType: + assert isinstance(self._container, Call) + return self._container.function.returns[self.name].shape + + @property + def dtype(self) -> _dtype_any: + assert isinstance(self._container, Call) + return self._container.function.returns[self.name].dtype + + +@attrs.define(frozen=True, eq=False, repr=False) +class Call(AbstractResultWithNamedArrays): + """ + Records an invocation to a :class:`FunctionDefinition`. + + .. attribute:: function + + The instance of :class:`FunctionDefinition` being called by this call-site. + + .. attribute:: bindings + + A mapping from the placeholder names of :class:`FunctionDefinition` to + their corresponding parameters in the invocation to :attr:`function`. + + .. attribute:: result_tags + + A mapping from the results of the call invocation to their + corresponding tags. + + .. attribute:: result_axes + + A mapping from the results of the call invocation to their corresponding + :class`~pytato.array.AxesT` for recording the metadata for each of their + dimensions. + + .. automethod:: with_result_axis_tagged + .. automethod:: with_result_tagged + + .. 
note:: + + To distinguish the metadata of the results in the caller expression to + the metadata in the callee expression we avoid using the same metadata + between the two. This is the motivation for introducing the attributes + ``result_tags`` and ``result_axes``. + """ + function: FunctionDefinition + bindings: Map[str, Array] + result_tags: Map[str, FrozenSet[Tag]] = attrs.field(kw_only=True) + result_axes: Map[str, AxesT] = attrs.field(kw_only=True) + + _mapper_method: ClassVar[str] = "map_call" + + copy = attrs.evolve + + def __post_init__(self) -> None: + # check that the invocation parameters and the function definition + # parameters agree with each other. + assert frozenset(self.bindings) == self.function.parameters + assert set(self.result_tags.keys()) == set(self.function.returns.keys()) + assert set(self.result_axes.keys()) == set(self.function.returns.keys()) + super().__post_init__() + + def __contains__(self, name: object) -> bool: + return name in self.function.returns + + def __iter__(self) -> Iterator[str]: + return iter(self.function.returns) + + def __getitem__(self, name: str) -> NamedCallResult: + return NamedCallResult(self, name) + + def __len__(self) -> int: + return len(self.function.returns) + + def with_result_axis_tagged(self, name: str, + iaxis: int, + tags: Union[Sequence[Tag], Tag]) -> Call: + """ + Returns a copy of *self* with the result corresponding to *name*\'s + *iaxis*-th axis tagged with *tags*. Also, see + :meth:`pytato.Array.with_tagged_axis`. 
+ """ + from attrs import evolve as replace + + if not (0 <= iaxis < self[name].ndim): + raise ValueError(f"Got iaxis={iaxis} for an array of dimension" + " {self[name].dim}") + + new_axes = tuple(axis.tagged(tags) if i == iaxis + else axis + for i, axis in enumerate(self.result_axes[name])) + return replace(self, + result_axes=self.result_axes.set(name, new_axes), + ) + + def with_result_tagged(self, name: str, + tags: Union[Sequence[Tag], Tag]) -> Call: + """ + Returns a copy of *self* with the result corresponding to *name*\'s + tagged with *tags*. Also, see :meth:`pytato.Array.tagged`. + """ + from attrs import evolve as replace + from pytools.tag import check_tag_uniqueness, normalize_tags + new_tags = check_tag_uniqueness(normalize_tags(tags) + | self.result_tags[name]) + + return replace(self, result_tags=self.result_tags.set(name, new_tags)) + + @memoize_method + def __hash__(self) -> int: + return hash((self.function, self.bindings, self.tags, self.result_tags, + self.result_axes)) + + +# }}} + + +# {{{ user-facing routines + +class _Guess: + pass + + +RE_ARGNAME = re.compile(r"^_pt_(\d+)$") + + +def trace_call(f: Callable[..., ReturnT], + *args: Array, + identifier: Optional[Hashable] = _Guess, + **kwargs: Array) -> ReturnT: + """ + Returns the expressions returned after calling *f* with the arguments + *args* and keyword arguments *kwargs*. The subexpressions in the returned + expressions are outlined as a :class:`~pytato.function.FunctionDefinition`. + + :arg identifier: A hashable object that acts as + :attr:`pytato.tags.FunctionIdentifier.identifier` for the + :class:`~pytato.tags.FunctionIdentifier` tagged to the outlined + :class:`~pytato.function.FunctionDefinition`. If ``None`` the function + definition is not tagged with a + :class:`~pytato.tags.FunctionIdentifier` tag, if ``_Guess`` the + function identifier is guessed from ``f.__name__``. 
+ """ + from pytato.tags import FunctionIdentifier + from pytato.array import _get_default_tags + + if identifier is _Guess: + # partials might not have a __name__ attribute + identifier = getattr(f, "__name__", None) + + for kw in kwargs: + if RE_ARGNAME.match(kw): + # avoid collision between argument names + raise ValueError(f"Kw argument named '{kw}' not allowed.") + + # Get placeholders from the ``args``, ``kwargs``. + pl_args = tuple(Placeholder(f"in__pt_{iarg}", arg.shape, arg.dtype, + axes=arg.axes, tags=arg.tags) + for iarg, arg in enumerate(args)) + pl_kwargs = {kw: Placeholder(f"in_{kw}", arg.shape, + arg.dtype, axes=arg.axes, tags=arg.tags) + for kw, arg in kwargs.items()} + + # Pass the placeholders + output = f(*pl_args, **pl_kwargs) + + if isinstance(output, Array): + returns = {"_": output} + return_type = ReturnType.ARRAY + elif isinstance(output, tuple): + assert all(isinstance(el, Array) for el in output) + returns = {f"_{iout}": out for iout, out in enumerate(output)} + return_type = ReturnType.TUPLE_OF_ARRAYS + elif isinstance(output, dict): + assert all(isinstance(el, Array) for el in output.values()) + returns = output + return_type = ReturnType.DICT_OF_ARRAYS + else: + raise ValueError("The function being traced must return one of" + f"pytato.Array, tuple, dict. Got {type(output)}.") + + # construct the function + function = FunctionDefinition( + frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), + return_type, + Map(returns), + tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)]) + if identifier + else frozenset()) + ) + + # type-ignore-reason: return type is dependent on dynamic state i.e. 
+ # ret_type and hence mypy is unhappy + return function( # type: ignore[return-value] + **{pl.name: arg for pl, arg in zip(pl_args, args)}, + **{pl_kwargs[kw].name: arg for kw, arg in kwargs.items()} + ) + +# }}} + +# vim:foldmethod=marker diff --git a/pytato/tags.py b/pytato/tags.py index 8d080f66e..cdfea12af 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -10,9 +10,12 @@ .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative .. autoclass:: ExpandedDimsReshape +.. autoclass:: FunctionIdentifier +.. autoclass:: CallImplementationTag +.. autoclass:: InlineCallTag """ -from typing import Tuple +from typing import Tuple, Hashable from pytools.tag import Tag, UniqueTag from dataclasses import dataclass @@ -125,3 +128,33 @@ class ExpandedDimsReshape(UniqueTag): frozenset({ExpandedDimsReshape(new_dims=(0, 2, 4))}) """ new_dims: Tuple[int, ...] + + +@dataclass(frozen=True) +class FunctionIdentifier(UniqueTag): + """ + A tag that can be attached to a :class:`~pytato.function.FunctionDefinition` + node to describe the function's identifier. One can use this to refer to + all instances of :class:`~pytato.function.FunctionDefinition`, for ex. in + transformations like :func:`~pytato.transform.calls.concatenate_calls`. + + .. attribute:: identifier + """ + identifier: Hashable + + +@dataclass(frozen=True) +class CallImplementationTag(UniqueTag): + """ + A tag that can be attached to a :class:`~pytato.function.Call` node to + direct a :class:`~pytato.target.Target` how the call site should be + lowered. + """ + + +@dataclass(frozen=True) +class InlineCallTag(CallImplementationTag): + r""" + A :class:`CallImplementationTag` that directs the + :class:`pytato.target.Target` to inline the call-site. 
+ """ diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index ad7b73491..7d8760ed6 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -46,6 +46,7 @@ from pytato.transform import Mapper from pytato.scalar_expr import ScalarExpression, INT_CLASSES from pytato.codegen import preprocess, normalize_outputs, SymbolicIndex +from pytato.function import Call, NamedCallResult from pytato.loopy import LoopyCall from pytato.tags import (ImplStored, ImplInlined, Named, PrefixNamed, ImplementationStrategy) @@ -575,6 +576,19 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef": state.update_kernel(kernel) + def map_named_call_result(self, expr: NamedCallResult, + state: CodeGenState) -> None: + raise NotImplementedError("LoopyTarget does not support outlined calls" + " (yet). As a fallback, the call" + " could be inlined using" + " pt.mark_all_calls_to_be_inlined.") + + def map_call(self, expr: Call, state: CodeGenState) -> None: + raise NotImplementedError("LoopyTarget does not support outlined calls" + " (yet). As a fallback, the call" + " could be inlined using" + " pt.mark_all_calls_to_be_inlined.") + # }}} @@ -972,36 +986,30 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], .. note:: - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics: - - - Inames that index over an :class:`~pytato.array.Array`'s axis in the - allocation instruction are tagged with the corresponding - :class:`~pytato.array.Axis`'s tags. The caller may choose to not - propagate axis tags of type *axis_tag_t_to_not_propagate*. - - :attr:`pytato.Array.tags` of inputs/outputs in *outputs* - would be copied over to the tags of the corresponding - :class:`loopy.ArrayArg`. The caller may choose to not - propagate array tags of type *array_tag_t_to_not_propagate*. 
- - Arrays tagged with :class:`pytato.tags.ImplStored` would have their - tags copied over to the tags of corresponding - :class:`loopy.TemporaryVariable`. The caller may choose to not - propagate array tags of type *array_tag_t_to_not_propagate*. + - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics: + + - Inames that index over an :class:`~pytato.array.Array`'s axis in the + allocation instruction are tagged with the corresponding + :class:`~pytato.array.Axis`'s tags. The caller may choose to not + propagate axis tags of type *axis_tag_t_to_not_propagate*. + - :attr:`pytato.Array.tags` of inputs/outputs in *outputs* + would be copied over to the tags of the corresponding + :class:`loopy.ArrayArg`. The caller may choose to not + propagate array tags of type *array_tag_t_to_not_propagate*. + - Arrays tagged with :class:`pytato.tags.ImplStored` would have their + tags copied over to the tags of corresponding + :class:`loopy.TemporaryVariable`. The caller may choose to not + propagate array tags of type *array_tag_t_to_not_propagate*. + + .. warning:: + + Currently only :class:`~pytato.function.Call` nodes that are tagged with + :class:`pytato.tags.InlineCallTag` can be lowered to :mod:`loopy` IR. 
""" result_is_dict = isinstance(result, (dict, DictOfNamedArrays)) orig_outputs: DictOfNamedArrays = normalize_outputs(result) - # optimization: remove any ImplStored tags on outputs to avoid redundant - # store-load operations (see https://github.com/inducer/pytato/issues/415) - orig_outputs = DictOfNamedArrays( - {name: (output.without_tags(ImplStored(), - verify_existence=False) - if not isinstance(output, - InputArgumentBase) - else output) - for name, output in orig_outputs._data.items()}, - tags=orig_outputs.tags) - del result if cl_device is not None: @@ -1017,6 +1025,18 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], preproc_result = preprocess(orig_outputs, target) outputs = preproc_result.outputs + # optimization: remove any ImplStored tags on outputs to avoid redundant + # store-load operations (see https://github.com/inducer/pytato/issues/415) + # (This must be done after all the calls have been inlined) + outputs = DictOfNamedArrays( + {name: (output.without_tags(ImplStored(), + verify_existence=False) + if not isinstance(output, + InputArgumentBase) + else output) + for name, output in outputs._data.items()}, + tags=outputs.tags) + compute_order = preproc_result.compute_order if options is None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index cd5bb1680..def11668a 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pytools import memoize_method + __copyright__ = """ Copyright (C) 2020 Matt Wala Copyright (C) 2020-21 Kaushik Kulkarni @@ -29,6 +31,7 @@ import logging import numpy as np from abc import abstractmethod +from immutables import Map from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, Hashable) @@ -44,10 +47,12 @@ from pytato.distributed.nodes import ( DistributedSendRefHolder, DistributedRecv, DistributedSend) from pytato.loopy 
import LoopyCall, LoopyCallResult +from pytato.function import Call, NamedCallResult, FunctionDefinition from dataclasses import dataclass from pytato.tags import ImplStored from pymbolic.mapper.optimize import optimize_mapper + ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) @@ -57,6 +62,7 @@ CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = FrozenSet[Array] +_SelfMapper = TypeVar("_SelfMapper", bound="Mapper") __doc__ = """ .. currentmodule:: pytato.transform @@ -91,6 +97,14 @@ .. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes + +Transforming call sites +----------------------- + +.. automodule:: pytato.transform.calls + +.. currentmodule:: pytato.transform + Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -101,6 +115,11 @@ .. class:: CombineT A type variable representing the type of a :class:`CombineMapper`. + +.. class:: _SelfMapper + + A type variable used to represent the type of a mapper in + :meth:`CopyMapper.clone_for_callee`. """ transform_logger = logging.getLogger(__file__) @@ -214,6 +233,8 @@ class CopyMapper(CachedMapper[ArrayOrNames]): The typical use of this mapper is to override individual ``map_`` methods in subclasses to permit term rewriting on an expression graph. + .. automethod:: clone_for_callee + .. note:: This does not copy the data of a :class:`pytato.array.DataWrapper`. @@ -230,6 +251,13 @@ def __call__(self, # type: ignore[override] expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + return type(self)() + def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] 
) -> Tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars @@ -373,6 +401,32 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + @memoize_method + def map_function_definition(self, + expr: FunctionDefinition) -> FunctionDefinition: + # spawn a new to avoid unsound cache hits, since the namespace of the + # function's body is different from that of the caller. + new_mapper = self.clone_for_callee() + new_returns = {name: new_mapper(ret) + for name, ret in expr.returns.items()} + return FunctionDefinition(expr.parameters, + expr.return_type, + Map(new_returns), + tags=expr.tags + ) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + return Call(self.map_function_definition(expr.function), + Map({name: self.rec(bnd) + for name, bnd in expr.bindings.items()}), + result_axes=expr.result_axes, + result_tags=expr.result_tags, + tags=expr.tags, + ) + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + return NamedCallResult(self.rec(expr._container), expr.name) + class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): """ @@ -576,6 +630,26 @@ def map_distributed_recv(self, expr: DistributedRecv, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> FunctionDefinition: + raise NotImplementedError("Function definitions are purposefully left" + " unimplemented as the default arguments to a new" + " DAG traversal are tricky to guess.") + + def map_call(self, expr: Call, + *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: + return Call(self.map_function_definition(expr.function, *args, **kwargs), + Map({name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()}), + result_axes=expr.result_axes, + 
result_tags=expr.result_tags, + tags=expr.tags, + ) + + def map_named_call_result(self, expr: NamedCallResult, + *args: Any, **kwargs: Any) -> Array: + return NamedCallResult(self.rec(expr._container, *args, **kwargs), expr.name) + # }}} @@ -686,6 +760,18 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> CombineT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_function_definition(self, expr: FunctionDefinition) -> CombineT: + raise NotImplementedError("Combining results from a callee expression" + " is context-dependent. Derived classes" + " must override map_function_definition.") + + def map_call(self, expr: Call) -> CombineT: + return self.combine(self.map_function_definition(expr.function), + *[self.rec(bnd) for bnd in expr.bindings.values()]) + + def map_named_call_result(self, expr: NamedCallResult) -> CombineT: + return self.rec(expr._container) + # }}} @@ -753,6 +839,18 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) + def map_function_definition(self, expr: FunctionDefinition) -> R: + # do not include arrays from the function's body as it would involve + # putting arrays from different namespaces into the same collection. + return frozenset() + + def map_call(self, expr: Call) -> R: + return self.combine(self.map_function_definition(expr.function), + *[self.rec(bnd) for bnd in expr.bindings.values()]) + + def map_named_call_result(self, expr: NamedCallResult) -> R: + return self.rec(expr._container) + # }}} @@ -797,6 +895,23 @@ def map_data_wrapper(self, expr: DataWrapper) -> FrozenSet[InputArgumentBase]: def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]: return frozenset([expr]) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FrozenSet[InputArgumentBase]: + # get rid of placeholders local to the function. 
+ new_mapper = InputGatherer() + all_callee_inputs = new_mapper.combine(*[new_mapper(ret) + for ret in expr.returns.values()]) + result: Set[InputArgumentBase] = set() + for inp in all_callee_inputs: + if isinstance(inp, Placeholder): + if inp.name not in expr.parameters: + result.add(inp) + else: + result.add(inp) + + return frozenset(result) + # }}} @@ -815,6 +930,12 @@ def combine(self, *args: FrozenSet[SizeParam] def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]: return frozenset([expr]) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FrozenSet[SizeParam]: + return self.combine(*[self.rec(ret) + for ret in expr.returns.values()]) + # }}} @@ -831,6 +952,9 @@ class WalkMapper(Mapper): .. automethod:: post_visit """ + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + return type(self)() + def visit(self, expr: Any, *args: Any, **kwargs: Any) -> bool: """ If this method returns *True*, *expr* is traversed during the walk. @@ -983,6 +1107,36 @@ def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> None: self.post_visit(expr, *args, **kwargs) + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee() + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + self.map_function_definition(expr.function) + for bnd in expr.bindings.values(): + self.rec(bnd) + + self.post_visit(expr) + + def map_named_call_result(self, expr: NamedCallResult, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr._container, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + # }}} @@ -1020,6 +1174,11 @@ class TopoSortMapper(CachedWalkMapper): """A mapper that 
creates a list of nodes in topological order. :members: topological_order + + .. note:: + + Does not consider the nodes inside a + :class:`~pytato.function.FunctionDefinition`. """ def __init__(self) -> None: @@ -1034,6 +1193,12 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] def post_visit(self, expr: Any) -> None: # type: ignore[override] self.topological_order.append(expr) + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def map_function_definition(self, # type: ignore[override] + expr: FunctionDefinition) -> None: + # do nothing as it includes arrays from a different namespace. + return + # }}} @@ -1049,6 +1214,11 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: super().__init__() self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ and does not have map_fn + return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] + # type-ignore-reason:incompatible with Mapper.rec() def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] if expr in self._cache: @@ -1103,7 +1273,14 @@ def _materialize_if_mpms(expr: Array, class MPMSMaterializer(Mapper): - """See :func:`materialize_with_mpms` for an explanation.""" + """ + See :func:`materialize_with_mpms` for an explanation. + + .. attribute:: nsuccessors + + A mapping from a node in the expression graph (i.e. an + :class:`~pytato.Array`) to its number of successors. 
+ """ def __init__(self, nsuccessors: Mapping[Array, int]): super().__init__() self.nsuccessors = nsuccessors @@ -1253,6 +1430,75 @@ def map_distributed_recv(self, expr: DistributedRecv ) -> MPMSMaterializerAccumulator: return MPMSMaterializerAccumulator(frozenset([expr]), expr) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FunctionDefinition: + # spawn a new traversal here. + from pytato.analysis import get_nusers + + returns_dict_of_named_arys = DictOfNamedArrays(expr.returns) + func_nsuccessors = get_nusers(returns_dict_of_named_arys) + new_mapper = MPMSMaterializer(func_nsuccessors) + new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} + return FunctionDefinition(expr.parameters, + expr.return_type, + Map(new_returns), + tags=expr.tags + ) + + @memoize_method + def map_call(self, expr: Call) -> Call: + from functools import reduce + from typing import cast + new_function = self.map_function_definition(expr.function) + rec_bindings = {name: self.rec(bnd) + for name, bnd in expr.bindings.items()} + + materialized_predecessors = reduce(frozenset.union, + (rec_bnd.materialized_predecessors + for rec_bnd in rec_bindings.values()), + cast(FrozenSet[Array], frozenset())) + new_expr = Call(new_function, + Map({name: rec_bnd.expr + for name, rec_bnd in rec_bindings.items()}), + result_tags=expr.result_tags, + result_axes=expr.result_axes, + tags=expr.tags, + ) + + for name, named_result in expr.items(): + nsuccs = self.nsuccessors[named_result] + if nsuccs > 1 and len(materialized_predecessors) > 1: + new_expr = new_expr.with_result_tagged(name, ImplStored()) + else: + pass + + return new_expr + + def map_named_call_result(self, expr: NamedCallResult + ) -> MPMSMaterializerAccumulator: + assert isinstance(expr._container, Call) + new_call = self.map_call(expr._container) + new_result = new_call[expr.name] + + assert isinstance(new_result, NamedCallResult) + assert isinstance(new_result._container, Call) + + 
# do not use _materialize_if_mpms as tagging a NamedArray is illegal. + if new_result.tags_of_type(ImplStored): + return MPMSMaterializerAccumulator(frozenset([new_result]), + new_result) + else: + from functools import reduce + materialized_predecessors: FrozenSet[Array] = ( + reduce(frozenset.union, + (self.rec(bnd).materialized_predecessors + for bnd in new_result._container.bindings.values()), + frozenset()) + ) + return MPMSMaterializerAccumulator(materialized_predecessors, + new_result) + # }}} @@ -1488,9 +1734,26 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) + def map_function_definition(self, expr: FunctionDefinition, *args: Any + ) -> None: + raise AssertionError("Control shouldn't reach at this point." + " Instantiate another UsersCollector to" + " traverse the callee function.") + + def map_call(self, expr: Call, *args: Any) -> None: + for bnd in expr.bindings.values(): + self.rec(bnd) + + def map_named_call(self, expr: NamedCallResult, *args: Any) -> None: + assert isinstance(expr._container, Call) + for bnd in expr._container.bindings.values(): + self.node_to_users.setdefault(bnd, set()).add(expr) + + self.rec(expr._container) + def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, - Set[ArrayOrNames]]: + Set[ArrayOrNames]]: """ Returns a mapping from node in *expr* to its direct users. """ @@ -1740,6 +2003,10 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \ # }}} + def map_named_call_result(self, expr: NamedCallResult, *args: Any) -> Any: + # TODO + raise NotImplementedError() + # }}} diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py new file mode 100644 index 000000000..865174b5d --- /dev/null +++ b/pytato/transform/calls.py @@ -0,0 +1,1540 @@ +from __future__ import annotations + +__doc__ = """ +.. currentmodule:: pytato.transform.calls + +.. autofunction:: inline_calls +.. 
autofunction:: concatenate_calls +.. autofunction:: tag_all_calls_to_be_inlined + +.. autoclass:: CallSiteLocation +""" +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import itertools +import attrs +import pymbolic.primitives as prim +import pytato.scalar_expr as scalar_expr + +from functools import partialmethod +from immutables import Map +from typing import (Tuple, FrozenSet, Collection, Mapping, Any, List, Dict, + TYPE_CHECKING, Sequence, Callable, Set, Generator) +from pytools import memoize_method, memoize_on_first_arg +from pytato.transform import (ArrayOrNames, CopyMapper, CombineMapper, Mapper, + CachedMapper, _SelfMapper, + CachedWalkMapper) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.array import (AbstractResultWithNamedArrays, Array, + DictOfNamedArrays, IndexBase, Placeholder, + SizeParam, InputArgumentBase, concatenate, + IndexLambda, Roll, Stack, Concatenate, + Einsum, AxisPermutation, + Reshape, BasicIndex, DataWrapper, ShapeComponent, + ShapeType) + +from pytato.function import Call, NamedCallResult, FunctionDefinition +from pytato.tags import InlineCallTag +from pytato.utils import are_shape_components_equal +import logging +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from pytato.loopy import LoopyCallResult + +ArrayOnStackT = Tuple[Tuple[Call, ...], Array] + + +# {{{ inlining + +class PlaceholderSubstitutor(CopyMapper): + """ + .. attribute:: substitutions + + A mapping from the placeholder name to the array that it is to be + substituted with. + """ + def __init__(self, substitutions: Map[str, Array]) -> None: + super().__init__() + self.substitutions = substitutions + + def map_placeholder(self, expr: Placeholder) -> Array: + return self.substitutions[expr.name] + + +class Inliner(CopyMapper): + """ + Primary mapper for :func:`inline_calls`. + """ + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + # inline call-sites within the callee. 
+        new_expr = super().map_call(expr)
+        assert isinstance(new_expr, Call)
+
+        if expr.tags_of_type(InlineCallTag):
+            substitutor = PlaceholderSubstitutor(expr.bindings)
+
+            return DictOfNamedArrays(
+                {name: substitutor(ret)
+                 for name, ret in new_expr.function.returns.items()},
+                tags=expr.tags
+            )
+        else:
+            return new_expr
+
+    def map_named_call_result(self, expr: NamedCallResult) -> Array:
+        new_call = self.rec(expr._container)
+        assert isinstance(new_call, AbstractResultWithNamedArrays)
+        return new_call[expr.name]
+
+
+class InlineMarker(CopyMapper):
+    """
+    Primary mapper for :func:`tag_all_calls_to_be_inlined`.
+    """
+    def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+        return super().map_call(expr).tagged(InlineCallTag())
+
+
+def inline_calls(expr: ArrayOrNames) -> ArrayOrNames:
+    """
+    Returns a copy of *expr* with call-sites tagged with
+    :class:`pytato.tags.InlineCallTag` inlined into the expression graph.
+    """
+    inliner = Inliner()
+    # type-ignore-reason: CopyMapper.__call__ does not specify types
+    return inliner(expr)  # type: ignore[no-any-return]
+
+
+def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames:
+    """
+    Returns a copy of *expr* with all reachable instances of
+    :class:`pytato.function.Call` nodes tagged with
+    :class:`pytato.tags.InlineCallTag`.
+
+    .. note::
+
+        This routine does NOT inline calls; to inline the calls
+        use :func:`inline_calls` on this routine's
+        output.
+    """
+    # type-ignore-reason: CopyMapper.__call__ does not specify types
+    return InlineMarker()(expr)  # type: ignore[no-any-return]
+
+# }}}
+
+
+# {{{ Concatenatability
+
+@attrs.define(frozen=True)
+class Concatenatability:
+    """
+    Describes how a particular array expression can be concatenated.
+    """
+
+
+@attrs.define(frozen=True)
+class ConcatableAlongAxis(Concatenatability):
+    """
+    Used to describe an array expression that is concatenatable along *axis*.
+ """ + axis: int + + +@attrs.define(frozen=True) +class ConcatableIfConstant(Concatenatability): + """ + Used to describe an array expression in a function body that can be + concatenated only if the expression is the same across call-sites. + """ + +# }}} + + +# {{{ concatenate_calls + +@attrs.define(frozen=True) +class CallSiteLocation: + r""" + Records a call-site's location in a :mod:`pytato` expression. + + .. attribute:: call + + The instance of :class:`~pytato.function.Call` being called at this + location. + + .. attribute:: stack + + The call sites within which this particular call is called. + For eg. if ``stack = (c1, c2)``, then :attr:`call` is called within + ``c2``\ 's function body which itself is called from ``c1``\ 's + function body. + """ + call: Call + stack: Tuple[Call, ...] + + +class CallsiteCollector(CombineMapper[FrozenSet[CallSiteLocation]]): + r""" + Collects all the call sites in a :mod:`pytato` expression. + + .. attribute:: stack + + The stack of calls at which the calls are being collected. This + attribute is used to specify :attr:`CallSiteLocation.stack` in the + :class:`CallSiteLocation`\ s being built. Must be altered (by creating + a new instance of the mapper) before entering the function body of a + new :class:`~pytato.function.Call`. 
+ """ + def __init__(self, stack: Tuple[Call, ...]) -> None: + self.stack = stack + super().__init__() + + def combine(self, *args: FrozenSet[CallSiteLocation] + ) -> FrozenSet[CallSiteLocation]: + from functools import reduce + return reduce(lambda a, b: a | b, args, frozenset()) + + def map_size_param(self, expr: SizeParam) -> FrozenSet[CallSiteLocation]: + return frozenset() + + def map_call(self, expr: Call) -> FrozenSet[CallSiteLocation]: + new_mapper_for_fn = CallsiteCollector(stack=self.stack + (expr,)) + return self.combine(frozenset([CallSiteLocation(expr, self.stack)]), + *[new_mapper_for_fn(ret) + for ret in expr.function.returns.values()]) + + +class _NamedCallResultReplacerPostConcatenate(CopyMapper): + """ + Mapper to replace instances of :class:`~pytato.function.NamedCallResult` as + per :attr:`replacement_map`. + + .. attribute:: stack_to_replace_on + + The stack onto which the replacement must be performed. + + .. attribute:: current_stack + + Records the stack to track which function body the mapper is + traversing. Must be altered (by creating a new instance) before + entering the function body of a new :class:`~pytato.function.Call`. + """ + def __init__(self, + replacement_map: Mapping[NamedCallResult, Array], + current_stack: Tuple[Call, ...], + stack_to_replace_on: Tuple[Call, ...]) -> None: + self.replacement_map = replacement_map + self.current_stack = current_stack + self.stack_to_replace_on = stack_to_replace_on + super().__init__() + + @memoize_method + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + raise AssertionError("Control should not reach here." + " Call clone_with_new_call_on_stack instead.") + + @memoize_method + def clone_with_new_call_on_stack(self: _SelfMapper, expr: Call) -> _SelfMapper: + # type-ignore-reason: Mapper class does not define these attributes. 
+ return type(self)( # type: ignore[call-arg] + self.replacement_map, # type: ignore[attr-defined] + self.current_stack + (expr,), # type: ignore[attr-defined] + self.stack_to_replace_on, # type: ignore[attr-defined] + ) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + new_stack_to_enter = self.current_stack + (expr,) + if self.stack_to_replace_on[:len(new_stack_to_enter)] == new_stack_to_enter: + # leading call-sites on the stack match the stack on which + # replacement must be performed. + new_mapper = self.clone_with_new_call_on_stack(expr) + + return Call(new_mapper.map_function_definition(expr.function), + Map({name: self.rec(bnd) + for name, bnd in expr.bindings.items()}), + result_axes=expr.result_axes, + result_tags=expr.result_tags, + tags=expr.tags) + else: + return Call(expr.function, # do not map the exprs in function's body. + Map({name: self.rec(bnd) + for name, bnd in expr.bindings.items()}), + result_axes=expr.result_axes, + result_tags=expr.result_tags, + tags=expr.tags) + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + if self.current_stack == self.stack_to_replace_on: + try: + return self.replacement_map[expr] + except KeyError: + return super().map_named_call_result(expr) + else: + return super().map_named_call_result(expr) + + +def _have_same_axis_length(arrays: Collection[Array], + iaxis: int) -> bool: + """ + Returns *True* only if every array in *arrays* have the same axis length + along *iaxis*. + """ + axis_length = next(iter(arrays)).shape[iaxis] + return all(are_shape_components_equal(other_ary.shape[iaxis], + axis_length) + for other_ary in arrays) + + +def _have_same_axis_length_except(arrays: Collection[Array], + iaxis: int) -> bool: + """ + Returns *True* only if every array in *arrays* have the same + dimensionality and have axes with the same lengths except along the + *iaxis*-axis. 
+ """ + ndim = next(iter(arrays)).ndim + return (all(ary.ndim == ndim for ary in arrays) + and all(_have_same_axis_length(arrays, idim) + for idim in range(ndim) + if idim != iaxis)) + + +@attrs.define(frozen=True) +class _InputConcatabilityGetterAcc: + r""" + Return type for :class:`_InputConcatabilityGetter`. An instance of this class is + returned after mapping a :class:`~pytato.Array` expression. + + .. attribute:: seen_inputs + + A :class:`frozenset` of all :class:`pytato.InputArgumentBase` + predecessors of a node. + + .. attribute:: input_concatability + + Records the constraints that come along with concatenating the array + being mapped. The constraints are recorded as a mapping from the axes + of the array being mapped to the axes of the input arguments. This + mapping informs us which axes in the :class:`InputArgumentBase`\ s' + must be concatenated to soundly concatenate a particular axis in the + array being mapped. The axes in this mapping are represented using + :class:`int`. If certain axes are missing in this mapping, then + concatenation cannot be performed along those axes for the mapped + array. + """ + seen_inputs: FrozenSet[InputArgumentBase] + input_concatability: Mapping[Concatenatability, + Mapping[InputArgumentBase, Concatenatability]] + + def __post_init__(self) -> None: + assert all( + frozenset(input_concat.keys()) == self.seen_inputs + for input_concat in self.input_concatability.values()) + + __attrs_post_init__ = __post_init__ + + +class NonConcatableExpression(RuntimeError): + """ + Used internally by :class:`_ScalarExprConcatabilityMapper`. + """ + + +class _InvalidConcatenatability(RuntimeError): + """ + Used internally by :func:`_get_ary_to_concatenatabilities`. + """ + + +class _ScalarExprConcatabilityMapper(scalar_expr.CombineMapper): + """ + Maps :attr:`~pytato.array.IndexLambda.expr` to the axes of the bindings + that must be concatenated to concatenate the IndexLambda's + :attr:`iaxis`-axis. + + .. 
attribute:: allow_indirect_addr + + If *True* indirect access are allowed. However, concatenating along the + iaxis-axis would be sound only if the binding which is being indexed + into is same for all the expressions to be concatenated. + """ + def __init__(self, iaxis: int, allow_indirect_addr: bool) -> None: + self.iaxis = iaxis + self.allow_indirect_addr = allow_indirect_addr + super().__init__() + + def combine(self, values: Collection[Mapping[str, Concatenatability]] + ) -> Mapping[str, Concatenatability]: + result: Dict[str, Concatenatability] = {} + for value in values: + for bnd_name, iaxis in value.items(): + try: + if result[bnd_name] != iaxis: + # only one axis of a particular binding can be + # concatenated. If multiple axes must be concatenated + # that means the index lambda is not + # iaxis-concatenatable. + raise NonConcatableExpression + except KeyError: + result[bnd_name] = iaxis + + return Map(result) + + def map_variable(self, expr: prim.Variable) -> Mapping[str, Concatenatability]: + if expr.name == f"_{self.iaxis}": + raise NonConcatableExpression + else: + return Map() + + def map_constant(self, expr: Any) -> Mapping[str, Concatenatability]: + return Map() + + map_nan = map_constant + + def map_subscript(self, expr: prim.Subscript + ) -> Mapping[str, Concatenatability]: + name: str = expr.aggregate.name + rec_indices: List[Mapping[str, Concatenatability]] = [] + for iaxis, idx in enumerate(expr.index_tuple): + if idx == prim.Variable(f"_{self.iaxis}"): + rec_indices.append({name: ConcatableAlongAxis(iaxis)}) + else: + rec_idx = self.rec(idx) + if rec_idx: + if not self.allow_indirect_addr: + raise NonConcatableExpression + else: + # indirect accesses cannot be concatenated in the general + # case unless the indexee is the same for the + # expression graphs being concatenated. 
+ pass + rec_indices.append(rec_idx) + + combined_rec_indices = dict(self.combine(rec_indices)) + + if name not in combined_rec_indices: + combined_rec_indices[name] = ConcatableIfConstant() + + return Map(combined_rec_indices) + + +@memoize_on_first_arg +def _get_binding_to_concatenatability(expr: scalar_expr.ScalarExpression, + iaxis: int, + allow_indirect_addr: bool, + ) -> Mapping[str, Concatenatability]: + """ + Maps *expr* using :class:`_ScalarExprConcatabilityMapper`. + """ + mapper = _ScalarExprConcatabilityMapper(iaxis, allow_indirect_addr) + return mapper(expr) # type: ignore[no-any-return] + + +def _combine_input_accs( + operand_accs: Tuple[_InputConcatabilityGetterAcc, ...], + concat_to_operand_concatabilities: Mapping[Concatenatability, + Tuple[Concatenatability, ...] + ], +) -> _InputConcatabilityGetterAcc: + """ + For an index lambda ``I`` with operands ``I1, I2, .. IN`` that specify their + concatenatability constraints using *operand_accs*, this routine returns + the axes concatenation constaints of ``I``. + + :arg concat_to_operand_concatabilities: Mapping of the form ``concat_I -> + (C_I1, C_I2, ..., C_IN)`` specifying the concatabilities of the + operands ``I1, I2, .., IN`` in order to concatenate the + ``I`` axis via the criterion ``conncat_I``. + """ + + from functools import reduce + + input_concatabilities: Dict[Concatenatability, Map[InputArgumentBase, + Concatenatability]] = {} + seen_inputs: FrozenSet[InputArgumentBase] = reduce( + frozenset.union, + (operand_acc.seen_inputs for operand_acc in operand_accs), + frozenset()) + + # The core logic here is to filter the iaxis in out_axis_to_operand_axes + # so that all the operands agree on how the input arguments must be + # concatenated. 
+ + for out_concat, operand_concatabilities in (concat_to_operand_concatabilities + .items()): + is_i_out_axis_concatenatable = True + input_concatability: Dict[InputArgumentBase, Concatenatability] = {} + + for operand_concatability, operand_acc in zip(operand_concatabilities, + operand_accs, + strict=True): + if operand_concatability not in ( + operand_acc.input_concatability): + # required operand concatability cannot be achieved + # => out_concat cannot be concatenated + is_i_out_axis_concatenatable = False + break + + for input_arg, input_concat in ( + operand_acc + .input_concatability[operand_concatability] + .items()): + try: + if input_concatability[input_arg] != input_concat: + is_i_out_axis_concatenatable = False + break + except KeyError: + input_concatability[input_arg] = input_concat + if not is_i_out_axis_concatenatable: + break + + if is_i_out_axis_concatenatable: + input_concatabilities[out_concat] = Map(input_concatability) + + return _InputConcatabilityGetterAcc(seen_inputs, + Map(input_concatabilities)) + + +@attrs.define(frozen=True) +class FunctionConcatenability: + r""" + Records a valid concatenatability criterion for a + :class:`pytato.function.FunctionDefinition`. + + .. attribute:: output_to_concatenatability + + A mapping from the name of a + :class:`FunctionDefinition`\ 's returned array to how it should be + concatenated. + + .. attribute:: input_to_concatenatability + + A mapping from a :class:`FunctionDefinition`\ 's parameter to how it + should be concatenated. + + .. note:: + + A :class:`FunctionDefinition` typically has multiple valid + concatenability constraints. This class only records one of those valid + constraints. 
+ """ + output_to_concatenatability: Mapping[str, Concatenatability] + input_to_concatenatability: Mapping[str, Concatenatability] + + def __str__(self) -> str: + outputs = [] + for name, concat in self.output_to_concatenatability.items(): + outputs.append(f"{name} => {concat}") + + inputs = [] + for name, concat in self.input_to_concatenatability.items(): + inputs.append(f"{name} => {concat}") + + output_str = "\n".join(outputs) + input_str = "\n".join(inputs) + + return (f"Outputs:\n--------\n{output_str}\n" + f"========\nInputs:\n-------\n{input_str}\n" + "========") + + +def _combine_named_result_accs( + named_result_accs: Mapping[str, _InputConcatabilityGetterAcc] +) -> Tuple[FunctionConcatenability, ...]: + """ + Combines the concantenatability constraints of named results of a + :class:`FunctionDefinition` and returns a :class:`tuple` of the valid + concatenatable constraints. + """ + potential_concatenatable_ouptut_axes = itertools.product(*[ + [(name, concat) for concat in acc.input_concatability] + for name, acc in named_result_accs.items()]) + + valid_concatenatabilities: List[FunctionConcatenability] = [] + + for output_concats in potential_concatenatable_ouptut_axes: + is_concatenatable = True + input_concatability: Dict[InputArgumentBase, Concatenatability] = {} + + for result_name, iresult_axis in output_concats: + for input_arg, i_input_axis in ( + named_result_accs[result_name] + .input_concatability[iresult_axis] + .items()): + try: + if input_concatability[input_arg] != i_input_axis: + is_concatenatable = False + break + except KeyError: + input_concatability[input_arg] = i_input_axis + + if not is_concatenatable: + break + + if is_concatenatable: + pl_concatabilities = {pl.name: concat + for pl, concat in input_concatability.items() + if isinstance(pl, Placeholder)} + valid_concatenatabilities.append( + FunctionConcatenability(Map(output_concats), + Map(pl_concatabilities)) + ) + + return tuple(valid_concatenatabilities) + + +class 
_InputConcatabilityGetter(CachedMapper[ArrayOrNames]): + """ + Maps :class:`pytato.array.Array` expressions to + :class:`_InputConcatenatabilityGetterAcc` that summarizes constraints + induced on the concatenatability of the inputs of the expression by the + expression's concatenatability. + """ + + def _map_input_arg_base(self, expr: InputArgumentBase + ) -> _InputConcatabilityGetterAcc: + input_concatenatability: Dict[Concatenatability, + Map[InputArgumentBase, + Concatenatability]] = {} + for idim in range(expr.ndim): + input_concatenatability[ConcatableAlongAxis(idim)] = Map( + {expr: ConcatableAlongAxis(idim)}) + + input_concatenatability[ConcatableIfConstant()] = Map( + {expr: ConcatableIfConstant()}) + + return _InputConcatabilityGetterAcc(frozenset([expr]), + Map(input_concatenatability)) + + map_placeholder = _map_input_arg_base + map_data_wrapper = _map_input_arg_base + + def _map_index_lambda_like(self, expr: Array, + allow_indirect_addr: bool + ) -> _InputConcatabilityGetterAcc: + expr = to_index_lambda(expr) + input_accs = tuple(self.rec(expr.bindings[name]) + for name in sorted(expr.bindings.keys())) + expr_concat_to_input_concats: Dict[Concatenatability, + Tuple[Concatenatability, ...]] = {} + + for iaxis in range(expr.ndim): + try: + bnd_name_to_concat = _get_binding_to_concatenatability( + expr.expr, iaxis, allow_indirect_addr) + expr_concat_to_input_concats[ConcatableAlongAxis(iaxis)] = ( + tuple(concat + for _, concat in sorted(bnd_name_to_concat.items(), + key=lambda x: x[0])) + ) + except NonConcatableExpression: + pass + + expr_concat_to_input_concats[ConcatableIfConstant()] = tuple( + ConcatableIfConstant() for _ in expr.bindings) + + return _combine_input_accs(input_accs, expr_concat_to_input_concats) + + map_index_lambda = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_einsum = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_basic_index = partialmethod(_map_index_lambda_like, + 
allow_indirect_addr=False) + map_roll = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_stack = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_concatenate = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_axis_permutation = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + map_reshape = partialmethod(_map_index_lambda_like, + allow_indirect_addr=False) + + map_contiguous_advanced_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=True) + map_non_contiguous_advanced_index = partialmethod(_map_index_lambda_like, + allow_indirect_addr=True) + + def map_named_call_result(self, expr: NamedCallResult + ) -> _InputConcatabilityGetterAcc: + assert isinstance(expr._container, Call) + valid_concatenatabilities = _get_valid_concatenatability_constraints( + expr._container.function) + + expr_concat_possibilities = { + valid_concatenability.output_to_concatenatability[expr.name] + for valid_concatenability in valid_concatenatabilities + } + + input_concatenatabilities: Dict[Concatenatability, + Mapping[InputArgumentBase, + Concatenatability]] = {} + rec_bindings = {bnd_name: self.rec(binding) + for bnd_name, binding in expr._container.bindings.items()} + callee_acc = self.rec(expr._container.function.returns[expr.name]) + seen_inputs: Set[InputArgumentBase] = set() + + for seen_input in callee_acc.seen_inputs: + if isinstance(seen_input, Placeholder): + seen_inputs.update(rec_bindings[seen_input.name].seen_inputs) + elif isinstance(seen_input, (DataWrapper, SizeParam)): + seen_inputs.add(seen_input) + else: + raise NotImplementedError(type(seen_input)) + + for concat_possibility in expr_concat_possibilities: + caller_input_concatabilities: Dict[InputArgumentBase, + Concatenatability] = {} + + is_concat_possibility_valid = True + for callee_input_arg, callee_input_concat in ( + callee_acc.input_concatability[concat_possibility].items()): + caller_acc = 
rec_bindings[callee_input_arg.name] + if isinstance(callee_input_arg, Placeholder): + if callee_input_concat in caller_acc.input_concatability: + for caller_input_arg, caller_input_concat in ( + caller_acc + .input_concatability[callee_input_concat] + .items()): + try: + if (caller_input_concatabilities[caller_input_arg] + != caller_input_concat): + is_concat_possibility_valid = False + break + except KeyError: + caller_input_concatabilities[callee_input_arg] = ( + caller_input_concat) + if not is_concat_possibility_valid: + break + else: + is_concat_possibility_valid = False + break + elif isinstance(callee_input_arg, (DataWrapper, SizeParam)): + try: + if (caller_input_concatabilities[callee_input_arg] + != callee_input_concat): + is_concat_possibility_valid = False + break + except KeyError: + caller_input_concatabilities[callee_input_arg] = ( + callee_input_concat) + else: + raise NotImplementedError(type(callee_input_arg)) + + if is_concat_possibility_valid: + input_concatenatabilities[concat_possibility] = Map( + caller_input_concatabilities) + + return _InputConcatabilityGetterAcc(frozenset(seen_inputs), + Map(input_concatenatabilities)) + + def map_loopy_call_result(self, expr: "LoopyCallResult" + ) -> _InputConcatabilityGetterAcc: + raise ValueError("Loopy Calls are illegal to concatenate. Maybe" + " rewrite the operation as array operations?") + + +def _verify_arrays_can_be_concated_along_axis( + arrays: Collection[Array], + fields_that_must_be_same: Collection[str], + iaxis: int) -> None: + """ + Performs some common checks if *arrays* from different function bodies can be + concatenated. + + .. attribute:: arrays + + Corresponding expressions from the function bodies for call-site that + are being checked for concatenation along *iaxis*. 
+ """ + if not _have_same_axis_length_except(arrays, iaxis): + raise _InvalidConcatenatability("Cannot be concatenate the calls.") + if len({ary.__class__ for ary in arrays}) != 1: + raise _InvalidConcatenatability("Cannot be concatenate the calls.") + for field in fields_that_must_be_same: + if len({getattr(ary, field) for ary in arrays}) != 1: + raise _InvalidConcatenatability("Cannot be concatenate the calls.") + + +def _verify_arrays_same(arrays: Collection[Array]) -> None: + if len(set(arrays)) != 1: + raise _InvalidConcatenatability("Cannot be concatenated as arrays across " + " functions is not the same.") + + +def _get_concatenated_shape(arrays: Collection[Array], iaxis: int) -> ShapeType: + # type-ignore-reason: mypy expects 'ary.shape[iaxis]' as 'int' since the + # 'start' is an 'int' + concatenated_axis_length = sum(ary.shape[iaxis] # type: ignore[misc] + for ary in arrays) + template_ary = next(iter(arrays)) + + return tuple(dim + if idim != iaxis + else concatenated_axis_length + for idim, dim in enumerate(template_ary.shape) + ) + + +class _ConcatabilityCollector(CachedWalkMapper): + def __init__(self, current_stack: Tuple[Call, ...]) -> None: + self.ary_to_concatenatability: Dict[ArrayOnStackT, Concatenatability] = {} + self.current_stack = current_stack + self.call_sites_on_hold: Set[Call] = set() + super().__init__() + + # type-ignore-reason: CachedWalkMaper takes variadic `*args, **kwargs`. + def get_cache_key(self, # type: ignore[override] + expr: ArrayOrNames, + *args: Any, + ) -> Tuple[ArrayOrNames, Any]: + return (expr, args) + + def _record_concatability(self, expr: Array, + concatenatability: Concatenatability, + ) -> None: + key = (self.current_stack, expr) + assert key not in self.ary_to_concatenatability + self.ary_to_concatenatability[key] = concatenatability + + @memoize_method + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + raise AssertionError("Control should not reach here." 
+ " Call clone_with_new_call_on_stack instead.") + + @memoize_method + def clone_with_new_call_on_stack(self: _SelfMapper, expr: Call) -> _SelfMapper: + # type-ignore-reason: Mapper class does not define these attributes. + return type(self)( # type: ignore[call-arg] + self.current_stack + (expr,), # type: ignore[attr-defined] + ) + + def _map_input_arg_base(self, + expr: InputArgumentBase, + concatenatability: Concatenatability, + exprs_from_other_calls: Tuple[Array, ...], + ) -> None: + if isinstance(concatenatability, ConcatableIfConstant): + _verify_arrays_same((expr,) + exprs_from_other_calls) + elif isinstance(concatenatability, ConcatableAlongAxis): + _verify_arrays_can_be_concated_along_axis( + (expr,) + exprs_from_other_calls, + ["dtype", "name"], + concatenatability.axis) + else: + raise NotImplementedError(type(concatenatability)) + + self._record_concatability(expr, concatenatability) + + map_placeholder = _map_input_arg_base # type: ignore[assignment] + map_data_wrapper = _map_input_arg_base # type: ignore[assignment] + + def _map_index_lambda_like(self, + expr: Array, + concatenatability: Concatenatability, + exprs_from_other_calls: Tuple[Array, ...], + allow_indirect_addr: bool, + ) -> None: + self._record_concatability(expr, concatenatability) + + idx_lambda = to_index_lambda(expr) + idx_lambdas_from_other_calls = tuple(to_index_lambda(ary) + for ary in exprs_from_other_calls) + + if isinstance(concatenatability, ConcatableIfConstant): + _verify_arrays_same((idx_lambda,) + idx_lambdas_from_other_calls) + for ary in idx_lambda.bindings.values(): + self.rec(ary, concatenatability) + elif isinstance(concatenatability, ConcatableAlongAxis): + _verify_arrays_can_be_concated_along_axis( + (idx_lambda, ) + idx_lambdas_from_other_calls, + ["dtype", "expr"], + concatenatability.axis) + bnd_name_to_concat = _get_binding_to_concatenatability( + idx_lambda.expr, concatenatability.axis, allow_indirect_addr) + for bnd_name, bnd_concat in 
bnd_name_to_concat.items(): + self.rec(idx_lambda.bindings[bnd_name], bnd_concat, + tuple(ary.bindings[bnd_name] + for ary in idx_lambdas_from_other_calls)) + else: + raise NotImplementedError(type(concatenatability)) + + map_index_lambda = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_einsum = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_basic_index = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_roll = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_stack = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_concatenate = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_axis_permutation = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + map_reshape = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=False) + + map_contiguous_advanced_index = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=True) + map_non_contiguous_advanced_index = partialmethod( # type: ignore[assignment] + _map_index_lambda_like, allow_indirect_addr=True) + + # type-ignore-reason: CachedWalkMapper.map_call takes in variadic args, kwargs + def map_call(self, # type: ignore[override] + expr: Call, + exprs_from_other_calls: Tuple[Call, ...]) -> None: + if not all( + (self.current_stack, named_result) in self.ary_to_concatenatability + for named_result in expr.values()): + self.call_sites_on_hold.add(expr) + else: + self.call_sites_on_hold -= {expr} + new_mapper = self.clone_with_new_call_on_stack(expr) + for name, val_in_callee in expr.function.returns.items(): + new_mapper(val_in_callee, + self.ary_to_concatenatability[(self.current_stack, + expr[name])], + 
tuple(other_call.function.returns[name] + for other_call in exprs_from_other_calls) + ) + + if new_mapper.call_sites_on_hold: + raise NotImplementedError("Call sites that do not all use all" + " the returned values not yet" + " supported for concatenation.") + + for ary, concat in new_mapper.ary_to_concatenatability.items(): + assert ary not in self.ary_to_concatenatability + self.ary_to_concatenatability[ary] = concat + + for name, binding in expr.bindings.items(): + concat = ( + new_mapper + .ary_to_concatenatability[(self.current_stack + (expr,), + expr.function.get_placeholder(name))] + ) + self.rec(binding, + concat, + tuple(other_call.bindings[name] + for other_call in exprs_from_other_calls)) + + # type-ignore-reason: CachedWalkMapper's method takes in variadic args, kwargs + def map_named_call_result(self, expr: NamedCallResult, # type: ignore[override] + concatenatability: Concatenatability, + exprs_from_other_calls: Tuple[Array, ...], + ) -> None: + self._record_concatability(expr, concatenatability) + if any(not isinstance(ary, NamedCallResult) + for ary in exprs_from_other_calls): + raise _InvalidConcatenatability() + + # type-ignore-reason: mypy does not respect the conditional which + # asserts that all arrays in `exprs_from_other_calls` are + # NamedCallResult. + self.rec(expr._container, + tuple(ary._container # type: ignore[attr-defined] + for ary in exprs_from_other_calls) + ) + + def map_loopy_call_result(self, expr: "LoopyCallResult" + ) -> None: + raise ValueError("Loopy Calls are illegal to concatenate. 
Maybe" + " rewrite the operation as array operations?") + + +class _Concatenator(Mapper): + def __init__(self, + current_stack: Tuple[Call, ...], + ary_to_concatenatability: Map[ArrayOnStackT, Concatenatability], + ) -> None: + self.current_stack = current_stack + self.ary_to_concatenatability = ary_to_concatenatability + + self._cache: Dict[Tuple[Array, Tuple[Array, ...]], Array] = {} + + # type-ignore-reason: super-type Mapper does not allow the extra args. + def rec(self, expr: Array, # type: ignore[override] + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + key = (expr, exprs_from_other_calls) + try: + return self._cache[key] + except KeyError: + result: Array = super().rec(expr, + exprs_from_other_calls) + self._cache[key] = result + return result + + @memoize_method + def clone_with_new_call_on_stack(self, expr: Call) -> _Concatenator: + return _Concatenator( + self.current_stack + (expr,), + self.ary_to_concatenatability, + ) + + def _get_concatenatability(self, expr: Array) -> Concatenatability: + return self.ary_to_concatenatability[(self.current_stack, expr)] + + def map_placeholder(self, + expr: Placeholder, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + new_shape = _get_concatenated_shape( + (expr,) + exprs_from_other_calls, concat.axis) + return Placeholder(name=expr.name, + dtype=expr.dtype, + shape=new_shape, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + def map_data_wrapper(self, + expr: DataWrapper, + exprs_from_other_calls: Tuple[Array, ...] 
+ ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + return concatenate((expr,) + exprs_from_other_calls, concat.axis) + else: + raise NotImplementedError(type(concat)) + + def map_index_lambda(self, + expr: IndexLambda, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + concat = self._get_concatenatability(expr) + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, IndexLambda) + for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are IndexLambda. + new_bindings = { + bnd_name: self.rec( + subexpr, + tuple(ary.bindings[bnd_name] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + for bnd_name, subexpr in expr.bindings.items() + } + new_shape = _get_concatenated_shape((expr,) + exprs_from_other_calls, + concat.axis) + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=Map(new_bindings), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + def map_einsum(self, expr: Einsum, + exprs_from_other_calls: Tuple[Array, ...]) -> Array: + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Einsum) for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Einsum. 
+ new_args = [self.rec(arg, + tuple(ary.args[iarg] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + for iarg, arg in enumerate(expr.args)] + + return Einsum(expr.access_descriptors, + tuple(new_args), + expr.redn_axis_to_redn_descr, + expr.index_to_access_descr, + tags=expr.tags, + axes=expr.axes, + ) + else: + raise NotImplementedError(type(concat)) + + def _map_index_base(self, expr: IndexBase, + exprs_from_other_calls: Tuple[Array, ...]) -> Array: + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, IndexBase) for ary in exprs_from_other_calls) + + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are IndexBase. + new_indices = [ + self.rec(idx, + tuple(ary.indices[i_idx] # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + if isinstance(idx, Array) + else idx + for i_idx, idx in enumerate(expr.indices) + ] + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + + return type(expr)(array=new_array, + indices=tuple(new_indices), + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + map_contiguous_advanced_index = _map_index_base + map_non_contiguous_advanced_index = _map_index_base + map_basic_index = _map_index_base + + def map_roll(self, + expr: Roll, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert concat.axis != expr.axis + assert all(isinstance(ary, Roll) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Roll. 
+ return Roll(self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)), + shift=expr.shift, + axis=expr.axis, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + def map_stack(self, expr: Stack, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Stack) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Stack. + if any(len(ary.arrays) != len(expr.arrays) # type: ignore[attr-defined] + for ary in exprs_from_other_calls): + raise ValueError("Cannot concatenate stack expressions" + " with different number of arrays.") + + new_arrays = tuple( + self.rec(array, + tuple(subexpr.arrays[iarray] # type: ignore[attr-defined] + for subexpr in exprs_from_other_calls) + ) + for iarray, array in enumerate(expr.arrays)) + + return Stack(new_arrays, + expr.axis, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + def map_concatenate(self, expr: Concatenate, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, Concatenate) + for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Concatenate. 
+ if any(len(ary.arrays) != len(expr.arrays) # type: ignore[attr-defined] + for ary in exprs_from_other_calls): + raise ValueError("Cannot concatenate concatenate-expressions" + " with different number of arrays.") + + new_arrays = tuple( + self.rec(array, + tuple(subexpr.arrays[iarray] # type: ignore[attr-defined] + for subexpr in exprs_from_other_calls) + ) + for iarray, array in enumerate(expr.arrays) + ) + + return Concatenate(new_arrays, + expr.axis, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + def map_axis_permutation(self, expr: AxisPermutation, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, AxisPermutation) + for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are AxisPermutation. + new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return AxisPermutation(new_array, + expr.axis_permutation, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + def map_reshape(self, expr: Reshape, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + new_newshape = _get_concatenated_shape( + (expr,) + exprs_from_other_calls, concat.axis) + + assert all(isinstance(ary, Reshape) for ary in exprs_from_other_calls) + # type-ignore-reason: mypy does not respect the assertion that all + # other exprs are Reshape. 
+ new_array = self.rec(expr.array, + tuple(ary.array # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return Reshape(new_array, + new_newshape, + expr.order, + tags=expr.tags, + axes=expr.axes) + else: + raise NotImplementedError(type(concat)) + + @memoize_method + def map_call(self, expr: Call, other_callsites: Tuple[Call, ...]) -> Call: + new_bindings = {name: self.rec(bnd, + tuple(callsite.bindings[name] + for callsite in other_callsites)) + for name, bnd in expr.bindings.items()} + new_mapper = self.clone_with_new_call_on_stack(expr) + fn_defn = expr.function + new_fn_defn = FunctionDefinition( + fn_defn.parameters, + fn_defn.return_type, + Map({ret: new_mapper(ret_val, + tuple(other_call.function.returns[ret] + for other_call in other_callsites) + ) + for ret, ret_val in fn_defn.returns.items()}), + tags=fn_defn.tags, + ) + return Call(new_fn_defn, + Map(new_bindings), + result_axes=expr.result_axes, + result_tags=expr.result_tags, + tags=expr.tags) + + def map_named_call_result(self, + expr: NamedCallResult, + exprs_from_other_calls: Tuple[Array, ...] + ) -> Array: + + concat = self._get_concatenatability(expr) + + if isinstance(concat, ConcatableIfConstant): + return expr + elif isinstance(concat, ConcatableAlongAxis): + assert all(isinstance(ary, NamedCallResult) + for ary in exprs_from_other_calls) + assert isinstance(expr._container, Call) + new_call = self.map_call( + expr._container, + tuple(ary._container # type: ignore[attr-defined] + for ary in exprs_from_other_calls)) + return new_call[expr.name] + else: + raise NotImplementedError(type(concat)) + + def map_loopy_call_result(self, expr: "LoopyCallResult", + exprs_from_other_calls: Tuple[Array, ...], + ) -> _InputConcatabilityGetterAcc: + raise ValueError("Loopy Calls are illegal to concatenate. 
Maybe" + " rewrite the operation as array operations?") + + +@memoize_on_first_arg +def _get_valid_concatenatability_constraints(fn: FunctionDefinition, + ) -> Tuple[FunctionConcatenability, + ...]: + mapper = _InputConcatabilityGetter() + output_accs = {name: mapper(output) + for name, output in fn.returns.items()} + + return _combine_named_result_accs(output_accs) + + +def _get_ary_to_concatenatabilities(call_sites: Sequence[Call], + ) -> Generator[Map[ArrayOnStackT, + Concatenatability], + None, + None]: + """ + Generates a :class:`Concatenatability` criterion for each array in the + expression graph of *call_sites*'s function body if they traverse identical + function bodies. + """ + fn_body = next(iter(call_sites)).function + fn_concatenatabilities = _get_valid_concatenatability_constraints(fn_body) + + for fn_concatenatability in fn_concatenatabilities: + collector = _ConcatabilityCollector(current_stack=()) + + # select a template call site to start the traversal. + template_call, *other_calls = call_sites + + try: + # verify the constraints on parameters are satisfied + for name, input_concat in (fn_concatenatability + .input_to_concatenatability + .items()): + if isinstance(input_concat, ConcatableIfConstant): + _verify_arrays_same([cs.bindings[name] for cs in call_sites]) + elif isinstance(input_concat, ConcatableAlongAxis): + _verify_arrays_can_be_concated_along_axis( + [cs.bindings[name] for cs in call_sites], + [], + input_concat.axis) + else: + raise NotImplementedError(type(input_concat)) + + # verify the constraints on function bodies are satisfied + for name, output_concat in (fn_concatenatability + .output_to_concatenatability + .items()): + collector(template_call.function.returns[name], + output_concat, + tuple(other_call.function.returns[name] + for other_call in other_calls)) + except _InvalidConcatenatability: + pass + else: + if collector.call_sites_on_hold: + raise NotImplementedError("Expressions that use part of" + " function's returned 
values are not" + " yet supported.") + + logger.info("Found a valid concatenatability --\n" + f"{fn_concatenatability}") + + yield Map(collector.ary_to_concatenatability) + + +def _get_replacement_map_post_concatenating(call_sites: Sequence[Call], + ) -> Mapping[NamedCallResult, + Array]: + """ + .. note:: + + We require *call_sites* to be ordered to determine the concatenation + order. + """ + assert call_sites, "Empty `call_sites`." + + ary_to_concatenatabilities = _get_ary_to_concatenatabilities(call_sites) + + try: + ary_to_concatenatability = next(ary_to_concatenatabilities) + except StopIteration: + raise ValueError("No valid concatenatibilities found.") + else: + if __debug__: + try: + next(ary_to_concatenatabilities) + except StopIteration: + # unique concatenatibility + pass + else: + from warnings import warn + # TODO: Take some input from the user to resolve this ambiguity. + warn("Multiple concatenation possibilities found. This may" + " lead to non-deterministic transformed expression graphs.") + + # {{{ actually perform the concatenation + + template_call_site, *other_call_sites = call_sites + + concatenator = _Concatenator(current_stack=(), + ary_to_concatenatability=ary_to_concatenatability) + + # new_returns: concatenated function body + new_returns: Dict[str, Array] = {} + for output_name in template_call_site.keys(): + new_returns[output_name] = concatenator( + template_call_site.function.returns[output_name], + tuple(csite.function.returns[output_name] + for csite in other_call_sites)) + + # }}} + + # construct new function body + new_function = FunctionDefinition( + template_call_site.function.parameters, + template_call_site.function.return_type, + Map(new_returns), + tags=template_call_site.function.tags, + ) + + result: Dict[NamedCallResult, Array] = {} + + new_call_bindings: Dict[str, Array] = {} + + # construct new bindings + for param_name in template_call_site.bindings: + param_placeholder = 
template_call_site.function.get_placeholder(param_name) + param_concat = ary_to_concatenatability[((), param_placeholder)] + if isinstance(param_concat, ConcatableAlongAxis): + new_binding = concatenate([csite.bindings[param_name] + for csite in call_sites], + param_concat.axis) + elif isinstance(param_concat, ConcatableIfConstant): + _verify_arrays_same([csite.bindings[param_name] + for csite in call_sites]) + new_binding = template_call_site.bindings[param_name] + else: + raise NotImplementedError(type(param_concat)) + new_call_bindings[param_name] = new_binding + + # construct new call + new_call = Call( + function=new_function, + bindings=Map(new_call_bindings), + result_tags=template_call_site.result_tags, + result_axes=template_call_site.result_axes, + tags=template_call_site.tags) + + # slice into new_call's outputs to replace the old expressions. + for output_name, output_ary in (template_call_site + .function + .returns + .items()): + start_idx: ShapeComponent = 0 + for cs in call_sites: + concat = ary_to_concatenatability[((), output_ary)] + if isinstance(concat, ConcatableIfConstant): + result[cs[output_name]] = new_call[output_name] + elif isinstance(concat, ConcatableAlongAxis): + ndim = output_ary.ndim + indices = [slice(None) for i in range(ndim)] + indices[concat.axis] = slice( + start_idx, start_idx+cs[output_name].shape[concat.axis]) + + sliced_output = new_call[output_name][tuple(indices)] + assert isinstance(sliced_output, BasicIndex) + result[cs[output_name]] = sliced_output + start_idx = start_idx + cs[output_name].shape[concat.axis] + else: + raise NotImplementedError(type(concat)) + + return Map(result) + + +def concatenate_calls(expr: ArrayOrNames, + call_site_filter: Callable[[CallSiteLocation], bool], + *, + warn_if_no_calls: bool = True, + err_if_no_calls: bool = False, + ) -> ArrayOrNames: + r""" + Returns a copy of *expr* after concatenating all call-sites ``C`` such that + ``call_site_filter(C) is True``. 
+ + :arg call_site_filter: A callable to select which instances of + :class:`~pytato.function.Call`\ s must be concatenated. + """ + all_call_sites = CallsiteCollector(stack=())(expr) + + filtered_call_sites = {callsite + for callsite in all_call_sites + if call_site_filter(callsite)} + + if not filtered_call_sites: + if err_if_no_calls: + raise ValueError("No calls to concatenate.") + elif warn_if_no_calls: + from warnings import warn + warn("No calls to concatenate.", stacklevel=2) + else: + pass + return expr + elif len({csite.stack for csite in filtered_call_sites}) == 1: + pass + else: + raise ValueError("Call-sites to concatenate are called" + " at multiple stack frames. This is not allowed.") + + old_expr_to_new_expr_map = ( + _get_replacement_map_post_concatenating([csite.call + for csite in filtered_call_sites])) + + stack, = {csite.stack for csite in filtered_call_sites} + + result = _NamedCallResultReplacerPostConcatenate( + replacement_map=old_expr_to_new_expr_map, + current_stack=(), + stack_to_replace_on=stack)(expr) + + assert isinstance(result, (Array, AbstractResultWithNamedArrays)) + return result + +# }}} + +# vim:foldmethod=marker diff --git a/pytato/visualization.py b/pytato/visualization.py index 4bdb103ee..1d13780ec 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -37,6 +37,7 @@ from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall +from pytato.function import Call, NamedCallResult from pytato.array import ( Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase, @@ -245,6 +246,28 @@ def map_distributed_send_ref_holder( self.nodes[expr] = info + def map_call(self, expr: Call) -> None: + for bnd in expr.bindings.values(): + self.rec(bnd) + + self.nodes[expr] = DotNodeInfo( + title=expr.__class__.__name__, + edges=dict(expr.bindings), + fields={ + "addr": hex(id(expr)), + "tags": stringify_tags(expr.tags), + } + ) + + def 
map_named_call_result(self, expr: NamedCallResult) -> None: + self.rec(expr._container) + self.nodes[expr] = DotNodeInfo( + title=expr.__class__.__name__, + edges={"": expr._container}, + fields={"addr": hex(id(expr)), + "name": expr.name}, + ) + def dot_escape(s: str) -> str: # "\" and HTML are significant in graphviz. @@ -638,4 +661,7 @@ def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None: """ print(get_ascii_graph(result, use_color=True)) + # }}} + +# vim:fdm=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index f0e4469c9..fa5739fab 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1921,6 +1921,205 @@ def test_pad(ctx_factory): np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) +def test_function_call(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + def f(x): + return 2*x + + def g(x): + return 2*x, 3*x + + def h(x, y): + return {"twice": 2*x+y, "thrice": 3*x+y} + + def build_expression(tracer): + x = pt.arange(500, dtype=np.float32) + twice_x = tracer(f, x) + twice_x_2, thrice_x_2 = tracer(g, x) + + result = tracer(h, x, 2*x) + twice_x_3 = result["twice"] + thrice_x_3 = result["thrice"] + + return {"foo": 3.14 + twice_x_3, + "bar": 4 * thrice_x_3, + "baz": 65 * twice_x, + "quux": 7 * twice_x_2} + + result1 = pt.tag_all_calls_to_be_inlined( + pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) + result2 = pt.make_dict_of_named_arrays( + build_expression(lambda fn, *args: fn(*args))) + + _, outputs = pt.generate_loopy(result1)(cq, out_host=True) + _, expected = pt.generate_loopy(result2)(cq, out_host=True) + + assert len(outputs) == len(expected) + + for key in outputs.keys(): + np.testing.assert_allclose(outputs[key], expected[key]) + + +@pytest.mark.parametrize("should_concatenate_bar", (False, True)) +def test_nested_function_calls(ctx_factory, should_concatenate_bar): + from functools import partial + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + rng 
= np.random.default_rng(0) + x_np = rng.random((10,)) + + x = pt.make_placeholder("x", 10, np.float64).tagged(pt.tags.ImplStored()) + prg = pt.generate_loopy({"out1": 3*x, "out2": x}) + _, out = prg(cq, x=x_np) + np.testing.assert_allclose(out["out1"], 3*x_np) + np.testing.assert_allclose(out["out2"], x_np) + ref_tracer = lambda f, *args, identifier: f(*args) # noqa: E731 + + def foo(tracer, x, y): + return 2*x + 3*y + + def bar(tracer, x, y): + foo_x_y = tracer(partial(foo, tracer), x, y, identifier="foo") + return foo_x_y * x * y + + def call_bar(tracer, x, y): + return tracer(partial(bar, tracer), x, y, identifier="bar") + + x1_np, y1_np = rng.random((2, 13, 29)) + x2_np, y2_np = rng.random((2, 4, 29)) + x1, y1 = pt.make_data_wrapper(x1_np), pt.make_data_wrapper(y1_np) + x2, y2 = pt.make_data_wrapper(x2_np), pt.make_data_wrapper(y2_np) + result = pt.make_dict_of_named_arrays({"out1": call_bar(pt.trace_call, x1, y1), + "out2": call_bar(pt.trace_call, x2, y2)} + ) + result = pt.tag_all_calls_to_be_inlined(result) + if should_concatenate_bar: + from pytato.transform.calls import CallsiteCollector + assert len(CallsiteCollector(())(result)) == 4 + result = pt.concatenate_calls( + result, + lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags) + assert len(CallsiteCollector(())(result)) == 2 + + expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1), + "out2": call_bar(ref_tracer, x2, y2)} + ) + + _, result_out = pt.generate_loopy(result)(cq) + _, expect_out = pt.generate_loopy(expect)(cq) + + assert result_out.keys() == expect_out.keys() + for k in expect_out: + np.testing.assert_allclose(result_out[k], expect_out[k]) + + +def test_concatenate_calls_no_nested(ctx_factory): + rng = np.random.default_rng(0) + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + def foo(x, y): + return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y) + + x1 = pt.make_placeholder("x1", (10, 4), np.float64) + x2 = pt.make_placeholder("x2", 
(10, 4), np.float64) + + y1 = pt.make_placeholder("y1", (10, 4), np.float64) + y2 = pt.make_placeholder("y2", (10, 4), np.float64) + + z1 = pt.make_placeholder("z1", (10, 4), np.float64) + z2 = pt.make_placeholder("z2", (10, 4), np.float64) + + result = pt.make_dict_of_named_arrays({"out1": 2*pt.trace_call(foo, 2*x1, 3*x2), + "out2": 4*pt.trace_call(foo, 4*y1, 9*y2), + "out3": 6*pt.trace_call(foo, 7*z1, 8*z2) + }) + + concatenated_result = pt.concatenate_calls( + result, lambda x: pt.tags.FunctionIdentifier("foo") in x.call.function.tags) + + result = pt.tag_all_calls_to_be_inlined(result) + concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result) + + assert (pt.analysis.get_num_nodes(pt.inline_calls(result)) + > pt.analysis.get_num_nodes(pt.inline_calls(concatenated_result))) + + x1_np, x2_np, y1_np, y2_np, z1_np, z2_np = rng.random((6, 10, 4)) + + _, out_dict1 = pt.generate_loopy(result)(cq, + x1=x1_np, x2=x2_np, + y1=y1_np, y2=y2_np, + z1=z1_np, z2=z2_np) + + _, out_dict2 = pt.generate_loopy(concatenated_result)(cq, + x1=x1_np, x2=x2_np, + y1=y1_np, y2=y2_np, + z1=z1_np, z2=z2_np) + assert out_dict1.keys() == out_dict2.keys() + + for key in out_dict1: + np.testing.assert_allclose(out_dict1[key], out_dict2[key]) + + +def test_concatenation_via_constant_expressions(ctx_factory): + + from pytato.transform.calls import CallsiteCollector + + rng = np.random.default_rng(0) + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + def resampling(coords, iels): + return coords[iels] + + n_el = 1000 + n_dof = 20 + n_dim = 3 + + n_left_els = 17 + n_right_els = 29 + + coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64) + left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els) + right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els) + + coords_dofs = pt.make_data_wrapper(coords_dofs_np) + left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np) + right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np) + + lcoords = 
pt.trace_call(resampling, coords_dofs, left_bnd_iels) + rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels) + + result = pt.make_dict_of_named_arrays({"lcoords": lcoords, + "rcoords": rcoords}) + result = pt.tag_all_calls_to_be_inlined(result) + + assert len(CallsiteCollector(())(result)) == 2 + concated_result = pt.concatenate_calls( + result, + lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags + ) + assert len(CallsiteCollector(())(concated_result)) == 1 + + _, out_result = pt.generate_loopy(result)(cq) + np.testing.assert_allclose(out_result["lcoords"], + coords_dofs_np[left_bnd_iels_np]) + np.testing.assert_allclose(out_result["rcoords"], + coords_dofs_np[right_bnd_iels_np]) + + _, out_concated_result = pt.generate_loopy(result)(cq) + np.testing.assert_allclose(out_concated_result["lcoords"], + coords_dofs_np[left_bnd_iels_np]) + np.testing.assert_allclose(out_concated_result["rcoords"], + coords_dofs_np[right_bnd_iels_np]) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])