diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 24f2c684b..982737fa5 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -142,7 +142,8 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: IDX_LAMBDA_RE = re.compile("_r?(0|([1-9][0-9]*))") - +IDX_LAMBDA_INAME = re.compile("^(_(0|([1-9][0-9]*)))$") +IDX_LAMBDA_JUST_REDUCTIONS = re.compile("^(_r(0|([1-9][0-9]*)))$") class DependencyMapper(DependencyMapperBase): def __init__(self, *, diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 6cae6bd25..74f756484 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -121,6 +121,7 @@ .. automodule:: pytato.transform.lower_to_index_lambda .. automodule:: pytato.transform.remove_broadcasts_einsum .. automodule:: pytato.transform.einsum_distributive_law +.. automodule:: pytato.transform.parameter_study .. currentmodule:: pytato.transform Dict representation of DAGs diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 319611107..2afcafb8e 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -99,6 +99,34 @@ GraphNodeT = TypeVar("GraphNodeT") +from pytato.scalar_expr import ( + IDX_LAMBDA_INAME, + IDX_LAMBDA_JUST_REDUCTIONS, + IdentityMapper as ScalarMapper, +) + +class AxesUsedMapper(ScalarMapper): + """ + Determine which axes are used in the scalar expression and which ones just flow + through the expression. + """ + + def __init__(self, var_names_in_use: list[str]): + self.var_names_in_use: list[str] = var_names_in_use + + self.usage_dict: Mapping[str, list[tuple[prim.Expression, ...]]] = {vname: [] + for vname in + self.var_names_in_use} + + def map_subscript(self, expr: prim.Subscript) -> None: + + name = expr.aggregate.name + if name in self.var_names_in_use: + self.usage_dict[name].append(expr.index_tuple) + + self.rec(expr.index) + + # {{{ AxesTagsEquationCollector @@ -237,65 +265,36 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - from warnings import warn - warn(f"'{expr}' is an unknown index lambda type" - " no tags were propagated across it.", stacklevel=1) - # no propagation semantics implemented for such cases - return - - if isinstance(hlo, BinaryOp): - subexprs: tuple[ArrayOrScalar, ...] = (hlo.x1, hlo.x2) - elif isinstance(hlo, WhereOp): - subexprs = (hlo.condition, hlo.then, hlo.else_) - elif isinstance(hlo, FullOp): - # A full-op does not impose any equations - subexprs = () - elif isinstance(hlo, BroadcastOp): - subexprs = (hlo.x,) - elif isinstance(hlo, C99CallOp): - subexprs = hlo.args - elif isinstance(hlo, ReduceOp): - - # {{{ ReduceOp doesn't quite involve broadcasting - - i_out_axis = 0 - for i_in_axis in range(hlo.x.ndim): - if i_in_axis not in hlo.axes: - self.record_equation( - self.get_var_for_axis(hlo.x, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - i_out_axis += 1 - assert i_out_axis == expr.ndim + keys = list(expr.bindings.keys()) - # }}} + mymapper = AxesUsedMapper(keys) - return + mymapper(expr.expr) - else: - raise NotImplementedError(type(hlo)) + out_shape = expr.shape + assert len(out_shape) == expr.ndim - for subexpr in subexprs: - if isinstance(subexpr, Array): - for i_in_axis, i_out_axis in zip( - range(subexpr.ndim), - range(expr.ndim-subexpr.ndim, expr.ndim)): - in_dim = subexpr.shape[i_in_axis] - out_dim = expr.shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - self.record_equation( - self.get_var_for_axis(subexpr, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - else: - # i_in_axis is broadcasted => do not propagate - assert are_shape_components_equal(in_dim, 1) - else: - assert isinstance(subexpr, SCALAR_CLASSES) + for key in keys: + if len(mymapper.usage_dict[key]) > 0: + for tup_ind in range(len(mymapper.usage_dict[key][0])): + vname = mymapper.usage_dict[key][0][tup_ind] + if isinstance(vname, prim.Variable): + if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(vname.name): + # Reduction axis. Pass for now. + pass + elif vname.name[:3] == "_in": + # Array used as part of the index. Pass for now. + pass + elif IDX_LAMBDA_INAME.fullmatch(vname.name): + # Matched with an output axis. + inum = int(vname.name[1:]) + val = (self.get_var_for_axis(expr.bindings[key], tup_ind), + self.get_var_for_axis(expr, inum)) + self.equations.append(val) + else: + raise ValueError(f"Unknown index name used in {vname}") + return def map_stack(self, expr: Stack) -> None: """ diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py new file mode 100644 index 000000000..f88ccccbe --- /dev/null +++ b/pytato/transform/parameter_study.py @@ -0,0 +1,485 @@ +from __future__ import annotations + + +__doc__ = """ +.. currentmodule:: pytato.transform.parameter_study + +.. autoclass:: ParameterStudyAxisTag +.. autoclass:: ParameterStudyVectorizer +.. autoclass:: IndexLambdaScalarExpressionVectorizer +""" + +__copyright__ = "Copyright (C) 2024 Nicholas Koskelo" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from dataclasses import dataclass +from typing import ( + Mapping, + Sequence, +) + +from immutabledict import immutabledict + +import pymbolic.primitives as prim +from pytools.tag import UniqueTag + +from pytato.array import ( + AbstractResultWithNamedArrays, + Array, + AxesT, + Axis, + AxisPermutation, + Concatenate, + Einsum, + EinsumAxisDescriptor, + EinsumElementwiseAxis, + IndexBase, + IndexLambda, + NormalizedSlice, + Placeholder, + Reshape, + Roll, + Stack, +) +from pytato.distributed.nodes import ( + DistributedRecv, + DistributedSendRefHolder, +) +from pytato.function import ( + Call, + FunctionDefinition, + NamedCallResult, +) +from pytato.scalar_expr import IdentityMapper, IntegralT +from pytato.transform import CopyMapper + + +KnownShapeType = tuple[IntegralT, ...] + + +@dataclass(frozen=True) +class ParameterStudyAxisTag(UniqueTag): + """ + A tag to indicate that the axis is being used + for independent trials like in a parameter study. + If you want to vary multiple input variables in the + same study then you need to have the same type of + :class:`ParameterStudyAxisTag`. + """ + size: int + + +class IndexLambdaScalarExpressionVectorizer(IdentityMapper): + """ + The goal of this mapper is to convert a single instance scalar expression + into a single instruction multiple data scalar expression. We assume that any + array variables in the original scalar expression will be indexed completely. + Also, new axes for the studies are assumed to be on the end of + those array variables. + """ + + def __init__(self, varname_to_studies_num: Mapping[str, tuple[int, ...]], + num_orig_elem_inds: int): + """ + `arg` varname_to_studies_num: is a mapping from the variable name used + in the scalar expression to the studies present in the multiple instance + expression. Note that the varnames must be for the array variables only. + + `arg` num_orig_elem_inds: is the number of element axes in the result of + the single instance expression. + """ + + super().__init__() + self.varname_to_studies_num = varname_to_studies_num + self.num_orig_elem_inds = num_orig_elem_inds + + def map_subscript(self, expr: prim.Subscript) -> prim.Subscript: + # We only need to modify the indexing which is stored in the index member. + assert isinstance(expr.aggregate, prim.Variable) + + name = expr.aggregate.name + if name in self.varname_to_studies_num.keys(): + # These are the single instance information. + + index = self.rec(expr.index) + + additional_inds = (prim.Variable(f"_{self.num_orig_elem_inds + num}") for + num in self.varname_to_studies_num[name]) + + return type(expr)(aggregate=expr.aggregate, + index=(*index, *additional_inds,)) + + return super().map_subscript(expr) + + def map_variable(self, expr: prim.Variable) -> prim.Expression: + # We know that a variable is a leaf node. So we only need to update it + # if the variable is part of a study. + if expr.name in self.varname_to_studies_num.keys(): + # The variable may need to be updated. + + my_studies: tuple[int, ...] = self.varname_to_studies_num[expr.name] + + if len(my_studies) == 0: + # No studies + return expr + + assert my_studies + assert len(my_studies) > 0 + + new_vars = (prim.Variable(f"_{self.num_orig_elem_inds + num}") # noqa E501 + for num in my_studies) + + return prim.Subscript(aggregate=expr, index=tuple(new_vars)) + + # Since the variable is not in a study we can just return the answer. + return super().map_variable(expr) + + +def _param_study_to_index(tag: ParameterStudyAxisTag) -> str: + """ + Get the canonical index string associated with the input tag. + """ + return str(tag.__class__) # Update to use the qualname or name. + + +class ParameterStudyVectorizer(CopyMapper): + r""" + This mapper will expand a DAG into a DAG for parameter studies. An array is part + of a parameter study if one of its axes is tagged with a with a tag from + :class:`ParameterStudyAxisTag` and all of the axes after that axis are also tagged + with a distinct :class:`ParameterStudyAxisTag` tag. An array may be a member of + multiple parameter studies. The new DAG for parameter studies which will be + equivalent to running your original DAG once for each input in the parameter study + space. The parameter study space is defined as the Cartesian product of all + the input parameter studies. When calling this mapper you must specify which input + :class:`~pytato.array.Placeholder` arrays are part of what parameter study. + + To maintain the equivalence with repeated calling the single instance DAG, the + DAG for parameter studies will not create any expressions which depend on the + specific instance of a parameter study. + + It is not required that each input be part of a parameter study as we will + broadcast the input to the appropriate size. + + The mapper does not support distributed programming or function definitions. + + .. note:: + + Any new axes used for parameter studies will be added to the end of the arrays. + """ + + def __init__(self, + place_name_to_parameter_studies: Mapping[str, + tuple[ParameterStudyAxisTag, ...]], + study_to_size: Mapping[ParameterStudyAxisTag, int]): + super().__init__() + self.place_name_to_parameter_studies = place_name_to_parameter_studies # E501 + self.study_to_size = study_to_size + + def _get_canonical_ordered_studies( + self, mapped_preds: tuple[Array, ...]) -> Sequence[ParameterStudyAxisTag]: + active_studies: set[ParameterStudyAxisTag] = set() + for arr in mapped_preds: + for axis in arr.axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + assert len(tags) == 1 # only one study per axis. + active_studies = active_studies.union(tags) + + return sorted(active_studies, key=_param_study_to_index) + + def _canonical_ordered_studies_to_shapes_and_axes( + self, studies: Sequence[ParameterStudyAxisTag]) -> tuple[list[int], + list[Axis]]: + """ + Get the shapes and axes in the canonical ordering. + """ + + return [self.study_to_size[study] for study in studies], \ + [Axis(tags=frozenset((study,))) for study in studies] + + def map_placeholder(self, expr: Placeholder) -> Array: + + if expr.name in self.place_name_to_parameter_studies.keys(): + canonical_studies = sorted(self.place_name_to_parameter_studies[expr.name], + key=_param_study_to_index) # noqa E501 + + # We know that there are no predecessors and we know the studies to add. + # We need to get them in the right order. + end_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + return Placeholder(name=expr.name, + shape=self.rec_idx_or_size_tuple((*expr.shape, + *end_shape,)), + dtype=expr.dtype, + axes=(*expr.axes, *end_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + return super().map_placeholder(expr) + + def map_roll(self, expr: Roll) -> Array: + new_predecessor = self.rec(expr.array) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + _, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + return Roll(array=new_predecessor, + shift=expr.shift, + axis=expr.axis, + axes=(*expr.axes, *end_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_axis_permutation(self, expr: AxisPermutation) -> Array: + new_predecessor = self.rec(expr.array) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + end_shapes, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + # Include the axes we are adding to the system. + n_single_inst_axes: int = len(expr.axis_permutation) + axis_permute_gen = (i + n_single_inst_axes for i in range(len(end_shapes))) + + return AxisPermutation(array=new_predecessor, + axis_permutation=(*expr.axis_permutation, + *axis_permute_gen,), + axes=(*expr.axes, *end_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def _map_index_base(self, expr: IndexBase) -> Array: + new_predecessor = self.rec(expr.array) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + end_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + # Update the indices. + end_indices = (NormalizedSlice(0, shape, 1) for shape in end_shape) + + return type(expr)(new_predecessor, + indices=self.rec_idx_or_size_tuple((*expr.indices, + *end_indices)), + axes=(*expr.axes, *end_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_reshape(self, expr: Reshape) -> Array: + new_predecessor = self.rec(expr.array) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + end_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + return Reshape(new_predecessor, + newshape=self.rec_idx_or_size_tuple((*expr.shape, + *end_shape,)), + order=expr.order, + axes=(*expr.axes, *end_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + # {{{ Operations with multiple predecessors. + + def _broadcast_predecessors_to_same_shape(self, expr: Stack | Concatenate) \ + -> tuple[tuple[Array, ...], AxesT]: + + """ + This method will convert predecessors who were originally the same + shape in a single instance program to the same shape in this multiple + instance program. + """ + + new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) + + canonical_studies = self._get_canonical_ordered_studies(new_predecessors) + studies_shape, new_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + if not studies_shape: + # We do not need to do anything as the expression we have is unmodified. + return new_predecessors, tuple(new_axes) + + correct_shape_preds: tuple[Array, ...] = () + for iarr, array in enumerate(new_predecessors): + # Broadcast out to the right shape. + n_single_inst_axes = len(expr.arrays[iarr].shape) + + # We assume there is at most one axis + # tag of type ParameterStudyAxisTag per axis. + + index = tuple(prim.Variable(f"_{ind}") if not \ + array.axes[ind].tags_of_type(ParameterStudyAxisTag) \ + else prim.Variable(f"_{n_single_inst_axes + canonical_studies.index(tuple(array.axes[ind].tags_of_type(ParameterStudyAxisTag))[0])}") # noqa E501 + for ind in range(len(array.shape))) + + assert len(index) == len(array.axes) + + new_array = IndexLambda(expr=prim.Subscript(prim.Variable("_in0"), + index=index), + bindings=immutabledict({"_in0": array}), + tags=array.tags, + non_equality_tags=array.non_equality_tags, + dtype=array.dtype, + var_to_reduction_descr=immutabledict({}), + shape=(*expr.arrays[iarr].shape, *studies_shape,), + axes=(*expr.arrays[iarr].axes, *new_axes,)) + + correct_shape_preds = (*correct_shape_preds, new_array,) + + for arr in correct_shape_preds: + assert arr.shape == correct_shape_preds[0].shape + + return correct_shape_preds, tuple(new_axes) + + def map_stack(self, expr: Stack) -> Array: + new_arrays, new_axes = self._broadcast_predecessors_to_same_shape(expr) + out = Stack(arrays=new_arrays, + axis=expr.axis, + axes=(*expr.axes, *new_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + assert out.ndim == len(out.shape) + assert len(out.shape) == len(out.arrays[0].shape) + 1 + + return out + + def map_concatenate(self, expr: Concatenate) -> Array: + new_arrays, new_axes = self._broadcast_predecessors_to_same_shape(expr) + + return Concatenate(arrays=new_arrays, + axis=expr.axis, + axes=(*expr.axes, *new_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_index_lambda(self, expr: IndexLambda) -> Array: + # Update bindings first. + + new_binds: dict[str, Array] = {name: self.rec(bnd) + for name, bnd in + sorted(expr.bindings.items())} + new_arrays = (*new_binds.values(),) + + # The arrays may be the same for a predecessors. + # However, the index will be unique. + + array_num_to_bnd_name: dict[int, str] = {ind: name for ind, (name, _) + in enumerate(sorted(new_binds.items()))} # noqa E501 + + # Determine the new parameter studies that are being conducted. + canonical_studies = self._get_canonical_ordered_studies(new_arrays) + postpend_shapes, post_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + varname_to_studies_nums: dict[str, tuple[int, ...]] = {bnd_name: () for + _, bnd_name + in array_num_to_bnd_name.items()} + + for iarr, array in enumerate(new_arrays): + for axis in array.axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + assert len(tags) == 1 + study: ParameterStudyAxisTag = next(iter(tags)) + name: str = array_num_to_bnd_name[iarr] + varname_to_studies_nums[name] = (*varname_to_studies_nums[name], + canonical_studies.index(study),) + + #varname_to_studies_nums = {array_num_to_bnd_name[iarr]: studies for iarr, + # studies in arr_num_to_study_nums.items()} + + assert all(vn_key in new_binds.keys() for + vn_key in varname_to_studies_nums.keys()) + + assert all(vn_key in varname_to_studies_nums.keys() for + vn_key in new_binds.keys()) + + # Now we need to update the expressions. + scalar_expr_mapper = IndexLambdaScalarExpressionVectorizer(varname_to_studies_nums, len(expr.shape)) # noqa E501 + + return IndexLambda(expr=scalar_expr_mapper(expr.expr), + bindings=immutabledict(new_binds), + shape=(*expr.shape, *postpend_shapes,), + var_to_reduction_descr=expr.var_to_reduction_descr, + dtype=expr.dtype, + axes=(*expr.axes, *post_axes,), + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_einsum(self, expr: Einsum) -> Array: + + new_predecessors = tuple(self.rec(arg) for arg in expr.args) + canonical_studies = self._get_canonical_ordered_studies(new_predecessors) + studies_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...] = () + for ival, array in enumerate(new_predecessors): + one_descr = expr.access_descriptors[ival] + for axis in array.axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + # Need to append a descriptor + assert len(tags) == 1 + study: ParameterStudyAxisTag = next(iter(tags)) + one_descr = (*one_descr, + EinsumElementwiseAxis(dim=len(expr.shape) + + canonical_studies.index(study))) + + access_descriptors = (*access_descriptors, one_descr) + + # One descriptor per axis. + assert len(one_descr) == len(array.shape) + + return Einsum(access_descriptors, + new_predecessors, + axes=(*expr.axes, *end_axes,), + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + # }}} Operations with multiple predecessors. + + # {{{ Function definitions + def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: + raise NotImplementedError(" Expanding functions is not yet supported.") + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + raise NotImplementedError(" Expanding functions is not yet supported.") + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + raise NotImplementedError(" Expanding functions is not yet supported.") + + # }}} + + # {{{ Distributed Programming Constructs + def map_distributed_recv(self, expr: DistributedRecv) -> DistributedRecv: + # This data will depend solely on the rank sending it to you. + raise NotImplementedError(" Expanding distributed programming constructs is" + " not yet supported.") + + def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder) \ + -> Array: + # We are sending the data. This data may increase in size due to the + # parameter studies. + raise NotImplementedError(" Expanding distributed programming constructs is" + " not yet supported.") + + # }}} diff --git a/test/test_vectorizer.py b/test/test_vectorizer.py new file mode 100644 index 000000000..8ab15a589 --- /dev/null +++ b/test/test_vectorizer.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python +from __future__ import annotations + + +__copyright__ = """Copyright (C) 2020 Andreas Kloeckner +Copyright (C) 2021 Kaushik Kulkarni +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +import numpy as np +from testlib import RandomDAGContext, make_random_dag + +import pytato as pt +from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, +) + + +class Study1(ParameterStudyAxisTag): + """First parameter study.""" + pass + + +class Study2(ParameterStudyAxisTag): + """Second parameter study.""" + pass + + +GLOBAL_ARRAY_NAME1 = "foo" +GLOBAL_ARRAY_NAME2 = "bar" +GLOBAL_SHAPE1: int = 1234 +GLOBAL_SHAPE2: int = 5678 +GLOBAL_STUDY1 = Study1(GLOBAL_SHAPE1) +GLOBAL_STUDY2 = Study2(GLOBAL_SHAPE2) +GLOBAL_NAME_TO_STUDIES = {GLOBAL_ARRAY_NAME1: frozenset((GLOBAL_STUDY1,)), + GLOBAL_ARRAY_NAME2: frozenset((GLOBAL_STUDY2,)), } +GLOBAL_STUDY_TO_SHAPES = {GLOBAL_STUDY1: GLOBAL_SHAPE1, GLOBAL_STUDY2: GLOBAL_SHAPE2} + +# {{{ Expansion Mapper tests. +def test_vectorize_mapper_placeholder(): + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int) + assert expr.shape == (15, 5) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (15, 5, GLOBAL_SHAPE1) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i == 2: + assert tags + else: + assert not tags + + +def test_vectorize_mapper_basic_index(): + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)[14, 0] + + assert expr.shape == () + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (GLOBAL_SHAPE1,) + assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) + + +def test_vectorize_mapper_advanced_index_contiguous_axes(): + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int) + expr = expr[pt.arange(10, dtype=int)] + + assert expr.shape == (10, 5) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (10, 5, GLOBAL_SHAPE1) + assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) + + assert isinstance(new_expr, pt.AdvancedIndexInContiguousAxes) + assert isinstance(expr, type(new_expr)) + + +def test_vectorize_mapper_advanced_index_non_contiguous_axes(): + ind0 = pt.arange(10, dtype=int).reshape(10, 1) + ind1 = pt.arange(2, dtype=int).reshape(1, 2) + + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 1000, 5), dtype=int) + expr = expr[ind0, :, ind1] + + assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) + assert expr.shape == (10, 2, 1000) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (10, 2, 1000, GLOBAL_SHAPE1) + assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) + + assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) + assert isinstance(expr, type(new_expr)) + + +def test_vectorize_mapper_index_lambda(): + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)[14, 0] \ + + pt.ones(100) + + assert expr.shape == (100,) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (100, GLOBAL_SHAPE1) + assert isinstance(new_expr, pt.IndexLambda) + + scalar_expr = new_expr.expr + + assert len(scalar_expr.children) == len(expr.expr.children) + assert scalar_expr != expr.expr + # We modified it so that we have the new axis. + assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) + + +def test_vectorize_mapper_roll(): + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)[14, 0] \ + + pt.ones(100) + + expr = pt.roll(expr, axis=0, shift=22) + + assert expr.shape == (100,) + assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (100, GLOBAL_SHAPE1,) + assert isinstance(new_expr, pt.Roll) + assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) + + +def test_vectorize_mapper_axis_permutation(): + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)) + assert expr.shape == (5, 15) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(expr) + assert new_expr.shape == (5, 15, GLOBAL_SHAPE1) + assert isinstance(new_expr, pt.AxisPermutation) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i == 2: + assert tags + else: + assert not tags + + +def test_vectorize_mapper_reshape(): + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME2, + (15, 5), dtype=int)) + + out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) + assert out_expr.shape == (10, 15) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (10, 15, GLOBAL_SHAPE1, GLOBAL_SHAPE2) + assert isinstance(new_expr, pt.Reshape) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 1: + assert tags + else: + assert not tags + + assert not new_expr.axes[2].tags_of_type(Study2) + assert not new_expr.axes[3].tags_of_type(Study1) + + +def test_vectorize_mapper_stack(): + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME2, + (15, 5), dtype=int)) + + out_expr = pt.stack([expr, expr2], axis=0) + assert out_expr.shape == (2, 5, 15) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (2, 5, 15, GLOBAL_SHAPE1, GLOBAL_SHAPE2) + assert isinstance(new_expr, pt.Stack) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 2: + assert tags + else: + assert not tags + + assert not new_expr.axes[3].tags_of_type(Study2) + assert not new_expr.axes[4].tags_of_type(Study1) + + +def test_vectorize_mapper_concatenate(): + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME2, + (15, 5), dtype=int)) + + out_expr = pt.concatenate([expr, expr2], axis=0) + assert out_expr.shape == (10, 15) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (10, 15, GLOBAL_SHAPE1, GLOBAL_SHAPE2) + assert isinstance(new_expr, pt.Concatenate) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 1: + assert tags + else: + assert not tags + + assert not new_expr.axes[2].tags_of_type(Study2) + assert not new_expr.axes[3].tags_of_type(Study1) + + +def test_vectorize_mapper_einsum_matmul(): + # Matmul gets expanded correctly. + a = pt.make_placeholder(GLOBAL_ARRAY_NAME1, + (47, 42), dtype=int) + b = pt.make_placeholder(GLOBAL_ARRAY_NAME2, + (42, 5), dtype=int) + + c = pt.matmul(a, b) + assert isinstance(c, pt.Einsum) + assert c.shape == (47, 5) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + updated_c = my_mapper(c) + + assert updated_c.shape == (47, 5, GLOBAL_SHAPE1, GLOBAL_SHAPE2) + + +def test_vectorize_mapper_does_nothing_if_tags_not_there(): + + from pytools import UniqueNameGenerator + axis_len = 5 + + for i in range(50): + print(i) # progress indicator + + seed = 120 + i + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + dag = pt.make_dict_of_named_arrays({"out": make_random_dag(rdagc_pt)}) + + # {{{ convert data-wrappers to placeholders + + vng = UniqueNameGenerator() + + def make_dws_placeholder(expr): + if isinstance(expr, pt.DataWrapper): + return pt.make_placeholder(vng("_pt_ph"), # noqa: B023 + expr.shape, expr.dtype) + else: + return expr + + dag = pt.transform.map_and_copy(dag, make_dws_placeholder) + + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) + new_dag = my_mapper(dag) + + assert new_dag == dag + + # }}} +# }}}