From a049355b648b33131db18556ccf2404a12a65d47 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 22 Jul 2024 15:35:15 -0500 Subject: [PATCH 01/27] Move the transform mapper from arraycontext to pytato. --- pytato/transform/parameter_study.py | 405 ++++++++++++++++++++++++++++ 1 file changed, 405 insertions(+) create mode 100644 pytato/transform/parameter_study.py diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py new file mode 100644 index 000000000..86d8abcd6 --- /dev/null +++ b/pytato/transform/parameter_study.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +""" +.. currentmodule:: pytato.transform + +TODO: +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. automodule:: pytato.transform.parameter_study +""" +__copyright__ = """ +Copyright (C) 2020-1 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +from immutabledict import immutabledict +from dataclasses import dataclass +from typing import ( + Dict, + FrozenSet, + Iterable, + Mapping, + Sequence, + Set, + Tuple, + Union, +) + +from pytato.array import ( + Array, + AxesT, + Axis, + AxisPermutation, + Concatenate, + Einsum, + EinsumElementwiseAxis, + IndexBase, + IndexLambda, + NormalizedSlice, + Placeholder, + Reshape, + Roll, + ShapeType, + Stack, +) + +from pytato.scalar_expr import IdentityMapper, IntegralT + +import pymbolic.primitives as prim + +from pytools.tag import UniqueTag, Tag + +from pytato.transform import CopyMapper + + +@dataclass(frozen=True) +class ParameterStudyAxisTag(UniqueTag): + """ + A tag for acting on axes of arrays. + To enable multiple parameter studies on the same variable name + specify a different axis number and potentially a different size. + + Currently does not allow multiple variables of different names to be in + the same parameter study. + """ + axis_num: int + axis_size: int + + +StudiesT = Tuple[ParameterStudyAxisTag, ...] +ArraysT = Tuple[Array, ...] +KnownShapeType = Tuple[IntegralT, ...] + + +class ExpansionMapper(CopyMapper): + + def __init__(self, placeholder_name_to_parameter_studies: Mapping[str, StudiesT]): + super().__init__() + self.placeholder_name_to_parameter_studies = placeholder_name_to_parameter_studies # noqa + + def _shapes_and_axes_from_predecessor(self, curr_expr: Array, + mapped_preds: ArraysT) -> \ + Tuple[KnownShapeType, + AxesT, + Dict[Array, Tuple[int, ...]]]: + # Initialize with something for the typing. + + assert not any(axis.tags_of_type(ParameterStudyAxisTag) for + axis in curr_expr.axes) + + # We are post pending the axes we are using for parameter studies. 
+ new_shape: KnownShapeType = () + studies_axes: AxesT = () + + study_to_arrays: Dict[FrozenSet[ParameterStudyAxisTag], ArraysT] = {} + + active_studies: Set[ParameterStudyAxisTag] = set() + + for arr in mapped_preds: + for axis in arr.axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + active_studies = active_studies.union(tags) + if tags in study_to_arrays.keys(): + study_to_arrays[tags] = (*study_to_arrays[tags], arr) + else: + study_to_arrays[tags] = (arr,) + + return self._studies_to_shape_and_axes_and_arrays_in_canonical_order(active_studies, # noqa + new_shape, + studies_axes, + study_to_arrays) + + def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, + studies: Iterable[ParameterStudyAxisTag], + new_shape: KnownShapeType, new_axes: AxesT, + study_to_arrays: Dict[FrozenSet[ParameterStudyAxisTag], ArraysT]) \ + -> Tuple[KnownShapeType, AxesT, Dict[Array, Tuple[int, ...]]]: + + # This is where we specify the canonical ordering. + + array_to_canonical_ordered_studies: Dict[Array, Tuple[int, ...]] = {} + studies_axes = new_axes + + for ind, study in enumerate(sorted(studies, + key=lambda x: str(x.__class__))): + new_shape = (*new_shape, study.axis_size) + studies_axes = (*studies_axes, Axis(tags=frozenset((study,)))) + for arr in study_to_arrays[frozenset((study,))]: + if arr in array_to_canonical_ordered_studies.keys(): + array_to_canonical_ordered_studies[arr] = (*array_to_canonical_ordered_studies[arr], ind) # noqa + else: + array_to_canonical_ordered_studies[arr] = (ind,) + + return new_shape, studies_axes, array_to_canonical_ordered_studies + + def map_placeholder(self, expr: Placeholder) -> Array: + # This is where we could introduce extra axes. 
+ if expr.name in self.placeholder_name_to_parameter_studies.keys(): + new_axes = expr.axes + studies = self.placeholder_name_to_parameter_studies[expr.name] + new_shape, new_axes, _ = self._studies_to_shape_and_axes_and_arrays_in_canonical_order( # noqa + studies, + (), + expr.axes, + {}) + + return Placeholder(name=expr.name, + shape=self.rec_idx_or_size_tuple((*expr.shape, + *new_shape,)), + dtype=expr.dtype, + axes=new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + return super().map_placeholder(expr) + + def map_roll(self, expr: Roll) -> Array: + new_predecessor = self.rec(expr.array) + _, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + (new_predecessor,)) + + return Roll(array=new_predecessor, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes + new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_axis_permutation(self, expr: AxisPermutation) -> Array: + new_predecessor = self.rec(expr.array) + postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + (new_predecessor,)) + # Include the axes we are adding to the system. + axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) + for i in range(len(postpend_shape))]) + + return AxisPermutation(array=new_predecessor, + axis_permutation=axis_permute, + axes=expr.axes + new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def _map_index_base(self, expr: IndexBase) -> Array: + new_predecessor = self.rec(expr.array) + postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + (new_predecessor,)) + # Update the indicies. 
+ new_indices = expr.indices + for shape in postpend_shape: + new_indices = (*new_indices, NormalizedSlice(0, shape, 1)) + + return type(expr)(new_predecessor, + indices=self.rec_idx_or_size_tuple(new_indices), + axes=expr.axes + new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_reshape(self, expr: Reshape) -> Array: + new_predecessor = self.rec(expr.array) + postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + (new_predecessor,)) + return Reshape(new_predecessor, + newshape=self.rec_idx_or_size_tuple(expr.newshape + + postpend_shape), + order=expr.order, + axes=expr.axes + new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + # {{{ Operations with multiple predecessors. + + def map_stack(self, expr: Stack) -> Array: + new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) + + return Stack(arrays=new_arrays, + axis=expr.axis, + axes=expr.axes + new_axes_for_end, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_concatenate(self, expr: Concatenate) -> Array: + new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) + + return Concatenate(arrays=new_arrays, + axis=expr.axis, + axes=expr.axes + new_axes_for_end, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT, + AxesT]: + + new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) + + studies_shape, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessor(expr, new_predecessors) # noqa + + # This is going to be expensive. + + # Now we need to update the expressions. + # Now that we have the appropriate shape, + # we need to update the input arrays to match. + + cp_map = CopyMapper() + corrected_new_arrays: ArraysT = () + for iarr, array in enumerate(new_predecessors): + tmp = cp_map(array) # Get a copy of the array. 
+ if len(array.axes) < len(new_axes): + # We need to grow the array to the new size. + studies_present = arrays_to_study_num_present[array] + for ind, size in enumerate(studies_shape): + if ind not in studies_present: + build: ArraysT = tuple([cp_map(tmp) for _ in range(size)]) + + # Note we are stacking the arrays into the appropriate shape. + tmp = Stack(arrays=build, + axis=len(expr.arrays[iarr].axes) + ind, + axes=new_axes[:ind], + tags=tmp.tags, + non_equality_tags=tmp.non_equality_tags) + + return corrected_new_arrays, new_axes + + def map_index_lambda(self, expr: IndexLambda) -> Array: + # Update bindings first. + new_bindings: Dict[str, Array] = {name: self.rec(bnd) + for name, bnd in + sorted(expr.bindings.items())} + + # Determine the new parameter studies that are being conducted. + from pytools import unique + + all_axis_tags: StudiesT = () + varname_to_studies: Dict[str, Dict[UniqueTag, bool]] = {} + for name, bnd in sorted(new_bindings.items()): + axis_tags_for_bnd: Set[Tag] = set() + varname_to_studies[name] = {} + for i in range(len(bnd.axes)): + axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) # noqa + for tag in axis_tags_for_bnd: + if isinstance(tag, ParameterStudyAxisTag): + # Defense + varname_to_studies[name][tag] = True + all_axis_tags = *all_axis_tags, tag, + + cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} + + new_shape = expr.shape + new_axes = expr.axes + + for study in cur_studies: + if isinstance(study, ParameterStudyAxisTag): + # Just defensive programming + # The active studies are added to the end of the bindings. + study_to_axis_number[study] = len(new_shape) + new_shape = (*new_shape, study.axis_size,) + new_axes = (*new_axes, Axis(tags=frozenset((study,))),) + # This assumes that the axis only has 1 tag, + # because there should be no dependence across instances. 
+ + # Now we need to update the expressions. + scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies, + study_to_axis_number) + + return IndexLambda(expr=scalar_expr, + bindings=immutabledict(new_bindings), + shape=new_shape, + var_to_reduction_descr=expr.var_to_reduction_descr, + dtype=expr.dtype, + axes=new_axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_einsum(self, expr: Einsum) -> Array: + + return super().map_einsum(expr) + + """ + new_arrays = tuple([self.rec(arg) for arg in expr.args]) + studies_shape, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessor(expr, new_predecessors) # noqa + new_axes_for_end, cur_studies, _ = self._studies_from_multiple_pred(new_arrays) + + + # Access Descriptors hold the Einsum notation. + new_access_descriptors = list(expr.access_descriptors) + study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} + + new_shape = expr.shape + + for study in cur_studies: + if isinstance(study, ParameterStudyAxisTag): + # Just defensive programming + # The active studies are added to the end. + study_to_axis_number[study] = len(new_shape) + new_shape = *new_shape, study.axis_size, + + for ind, array in enumerate(new_arrays): + for _, axis in enumerate(array.axes): + axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) + if axis_tags: + assert len(axis_tags) == 1 + new_access_descriptors[ind] = new_access_descriptors[ind] + \ + (EinsumElementwiseAxis(dim=study_to_axis_number[axis_tags[0]]),) + + return Einsum(tuple(new_access_descriptors), new_arrays, + axes=expr.axes + new_axes_for_end, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + index_to_access_descr=expr.index_to_access_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + """ + + # }}} Operations with multiple predecessors. 
+ + +class ParamAxisExpander(IdentityMapper): + + def map_subscript(self, expr: prim.Subscript, + varname_to_studies: Mapping[str, + Mapping[ParameterStudyAxisTag, bool]], + study_to_axis_number: Mapping[ParameterStudyAxisTag, int]) -> \ + prim.Subscript: + # We know that we are not changing the variable that we are indexing into. + # This is stored in the aggregate member of the class Subscript. + + # We only need to modify the indexing which is stored in the index member. + name = expr.aggregate.name + if name in varname_to_studies.keys(): + # These are the single instance information. + index = self.rec(expr.index, varname_to_studies, + study_to_axis_number) + + new_vars: Tuple[prim.Variable, ...] = () + + for key, num in sorted(study_to_axis_number.items(), + key=lambda item: item[1]): + if key in varname_to_studies[name]: + new_vars = *new_vars, prim.Variable(f"_{num}"), + + if isinstance(index, tuple): + index = index + new_vars + else: + index = tuple(index) + new_vars + return type(expr)(aggregate=expr.aggregate, index=index) + return expr From d745688460974eda41eb45c9878071e5b4b9e165 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 22 Jul 2024 15:39:19 -0500 Subject: [PATCH 02/27] Revert einsum to have just the copymapper. This will be updated in a future commit. 
--- pytato/transform/parameter_study.py | 42 +++-------------------------- 1 file changed, 3 insertions(+), 39 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 86d8abcd6..5317627a2 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -52,14 +52,12 @@ AxisPermutation, Concatenate, Einsum, - EinsumElementwiseAxis, IndexBase, IndexLambda, NormalizedSlice, Placeholder, Reshape, Roll, - ShapeType, Stack, ) @@ -134,7 +132,8 @@ def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, studies: Iterable[ParameterStudyAxisTag], new_shape: KnownShapeType, new_axes: AxesT, study_to_arrays: Dict[FrozenSet[ParameterStudyAxisTag], ArraysT]) \ - -> Tuple[KnownShapeType, AxesT, Dict[Array, Tuple[int, ...]]]: + -> Tuple[KnownShapeType, AxesT, Dict[Array, + Tuple[int, ...]]]: # This is where we specify the canonical ordering. @@ -332,43 +331,8 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: non_equality_tags=expr.non_equality_tags) def map_einsum(self, expr: Einsum) -> Array: - - return super().map_einsum(expr) - - """ - new_arrays = tuple([self.rec(arg) for arg in expr.args]) - studies_shape, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessor(expr, new_predecessors) # noqa - new_axes_for_end, cur_studies, _ = self._studies_from_multiple_pred(new_arrays) - - - # Access Descriptors hold the Einsum notation. - new_access_descriptors = list(expr.access_descriptors) - study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - - new_shape = expr.shape - - for study in cur_studies: - if isinstance(study, ParameterStudyAxisTag): - # Just defensive programming - # The active studies are added to the end. 
- study_to_axis_number[study] = len(new_shape) - new_shape = *new_shape, study.axis_size, - - for ind, array in enumerate(new_arrays): - for _, axis in enumerate(array.axes): - axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag)) - if axis_tags: - assert len(axis_tags) == 1 - new_access_descriptors[ind] = new_access_descriptors[ind] + \ - (EinsumElementwiseAxis(dim=study_to_axis_number[axis_tags[0]]),) - return Einsum(tuple(new_access_descriptors), new_arrays, - axes=expr.axes + new_axes_for_end, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - """ + return super().map_einsum(expr) # }}} Operations with multiple predecessors. From fe17b51333521e660f118fa0e636180bafe5639f Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 22 Jul 2024 16:25:19 -0500 Subject: [PATCH 03/27] Add some test cases for the mapper. --- pytato/transform/parameter_study.py | 12 +-- test/test_pytato.py | 132 ++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 5 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 5317627a2..0f388661f 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -144,11 +144,13 @@ def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, key=lambda x: str(x.__class__))): new_shape = (*new_shape, study.axis_size) studies_axes = (*studies_axes, Axis(tags=frozenset((study,)))) - for arr in study_to_arrays[frozenset((study,))]: - if arr in array_to_canonical_ordered_studies.keys(): - array_to_canonical_ordered_studies[arr] = (*array_to_canonical_ordered_studies[arr], ind) # noqa - else: - array_to_canonical_ordered_studies[arr] = (ind,) + print(study_to_arrays) + if study_to_arrays: + for arr in study_to_arrays[frozenset((study,))]: + if arr in array_to_canonical_ordered_studies.keys(): + array_to_canonical_ordered_studies[arr] = 
(*array_to_canonical_ordered_studies[arr], ind) # noqa + else: + array_to_canonical_ordered_studies[arr] = (ind,) return new_shape, studies_axes, array_to_canonical_ordered_studies diff --git a/test/test_pytato.py b/test/test_pytato.py index d76c65d62..f1f1aef71 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1048,6 +1048,138 @@ def test_lower_to_index_lambda(): assert isinstance(binding, Reshape) +# {{{ Expansion Mapper tests. +def test_expansion_mapper_placeholder(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int) + assert expr.shape == (15, 5) + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (15, 5, 10) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i == 2: + assert tags + else: + assert not tags + + +def test_expansion_mapper_basic_index(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + + assert expr.shape == () + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (10,) + assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) + + +def test_expansion_mapper_index_lambda(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) + + assert expr.shape == (100,) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(expr) + assert 
new_expr.shape == (100, 10) + assert isinstance(new_expr, pt.IndexLambda) + assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) + + +def test_expansion_mapper_roll(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) + expr = pt.roll(expr, axis=0, shift=22) + + assert expr.shape == (100,) + assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (100, 10,) + assert isinstance(new_expr, pt.Roll) + assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) + + +def test_expansion_mapper_axis_permutation(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) + assert expr.shape == (5, 15) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (5, 15, 10) + assert isinstance(new_expr, pt.AxisPermutation) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i == 2: + assert tags + else: + assert not tags + + +def test_expansion_mapper_stack(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + class Study2(ParameterStudyAxisTag): + pass + + class Study1(ParameterStudyAxisTag): + pass + name = "my_array" + study1 = Study1(0, 10) + arr2 = "foo" + study2 = Study2(0, 1000) + name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(arr2, (15, 
5), dtype=int)) + + out_expr = pt.stack([expr, expr2], axis=0) + assert out_expr.shape == (2, 5, 15) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (2, 5, 15, 10, 1000) + assert isinstance(new_expr, pt.Stack) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 2: + assert tags + else: + assert not tags + + assert not new_expr.axes[3].tags_of_type(Study2) + assert not new_expr.axes[4].tags_of_type(Study1) + +# }}} + + def test_cached_walk_mapper_with_extra_args(): from testlib import RandomDAGContext, make_random_dag From 50fa04720b55b51c5d8ec47598e1a2790d63e9bb Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 22 Jul 2024 17:25:48 -0500 Subject: [PATCH 04/27] Add a test case in for the einsum and ensure that stack returns an updated array. --- pytato/transform/parameter_study.py | 6 +- test/test_pytato.py | 184 ++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 3 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 0f388661f..582988d2d 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -232,7 +232,6 @@ def map_reshape(self, expr: Reshape) -> Array: def map_stack(self, expr: Stack) -> Array: new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end, @@ -265,8 +264,8 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Arrays corrected_new_arrays: ArraysT = () for iarr, array in enumerate(new_predecessors): tmp = cp_map(array) # Get a copy of the array. - if len(array.axes) < len(new_axes): - # We need to grow the array to the new size. + # We need to grow the array to the new size. 
+ if arrays_to_study_num_present: studies_present = arrays_to_study_num_present[array] for ind, size in enumerate(studies_shape): if ind not in studies_present: @@ -278,6 +277,7 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Arrays axes=new_axes[:ind], tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) + corrected_new_arrays = (*corrected_new_arrays, tmp) return corrected_new_arrays, new_axes diff --git a/test/test_pytato.py b/test/test_pytato.py index f1f1aef71..ca4f41609 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1084,6 +1084,44 @@ def test_expansion_mapper_basic_index(): assert new_expr.shape == (10,) assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) +def test_expansion_mapper_advanced_index_contiguous_axes(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[pt.arange(10, dtype=int)] + + assert expr.shape == (10,5) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (10, 5, 10) + assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) + + assert isinstance(new_expr, pt.AdvancedIndexInContiguousAxes) + assert isinstance(expr, type(new_expr)) + +def test_expansion_mapper_advanced_index_non_contiguous_axes(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + name = "my_array" + my_study = ParameterStudyAxisTag(0, 10) + name_to_studies = {name: frozenset((my_study,))} + ind0 = pt.arange(10, dtype=int).reshape(10,1) + ind1 = pt.arange(2, dtype=int).reshape(1,2) + expr = pt.make_placeholder(name, (15, 1000, 5), dtype=int)[ind0, :, ind1] + + assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) + assert expr.shape == (10, 2, 1000) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = 
my_mapper(expr) + assert new_expr.shape == (10, 2, 1000, 10) + assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) + + assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) + assert isinstance(expr, type(new_expr)) def test_expansion_mapper_index_lambda(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1099,6 +1137,9 @@ def test_expansion_mapper_index_lambda(): new_expr = my_mapper(expr) assert new_expr.shape == (100, 10) assert isinstance(new_expr, pt.IndexLambda) + + scalar_expr = new_expr.expr + assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) @@ -1142,6 +1183,40 @@ def test_expansion_mapper_axis_permutation(): else: assert not tags +def test_expansion_mapper_reshape(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + class Study2(ParameterStudyAxisTag): + pass + + class Study1(ParameterStudyAxisTag): + pass + name = "my_array" + study1 = Study1(0, 10) + arr2 = "foo" + study2 = Study2(0, 1000) + name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(arr2, (15, 5), dtype=int)) + + out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) + assert out_expr.shape == (10, 15) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (10, 15, 10, 1000) + assert isinstance(new_expr, pt.Reshape) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 1: + assert tags + else: + assert not tags + + assert not new_expr.axes[2].tags_of_type(Study2) + assert not new_expr.axes[3].tags_of_type(Study1) + def test_expansion_mapper_stack(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1177,6 +1252,115 @@ class Study1(ParameterStudyAxisTag): assert not new_expr.axes[3].tags_of_type(Study2) assert not 
new_expr.axes[4].tags_of_type(Study1) +def test_expansion_mapper_concatenate(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + class Study2(ParameterStudyAxisTag): + pass + + class Study1(ParameterStudyAxisTag): + pass + name = "my_array" + study1 = Study1(0, 10) + arr2 = "foo" + study2 = Study2(0, 1000) + name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(arr2, (15, 5), dtype=int)) + + out_expr = pt.concatenate([expr, expr2], axis=0) + assert out_expr.shape == (10, 15) + + my_mapper = ExpansionMapper(name_to_studies) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (10, 15, 10, 1000) + assert isinstance(new_expr, pt.Concatenate) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 1: + assert tags + else: + assert not tags + + assert not new_expr.axes[2].tags_of_type(Study2) + assert not new_expr.axes[3].tags_of_type(Study1) + +def test_expansion_mapper_einsum_matmul(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + class Study2(ParameterStudyAxisTag): + pass + + class Study1(ParameterStudyAxisTag): + pass + name = "my_array" + study1 = Study1(0, 10) + arr2 = "foo" + study2 = Study2(0, 1000) + name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + + # Matmul gets expanded correctly. 
+ a = pt.make_placeholder(name, (47, 42), dtype=int) + b = pt.make_placeholder(arr2, (42, 5), dtype=int) + + c = pt.matmul(a,b) + assert isinstance(c, pt.Einsum) + assert c.shape == (47, 5) + + my_mapper = ExpansionMapper(name_to_studies) + updated_c = my_mapper(c) + + assert updated_c.shape == (47, 5, 10, 1000) + + + +def test_expansion_mapper_does_nothing_if_tags_not_there(): + from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + + class Study2(ParameterStudyAxisTag): + pass + + class Study1(ParameterStudyAxisTag): + pass + name = "my_array" + study1 = Study1(0, 10) + arr2 = "foo" + study2 = Study2(0, 1000) + name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + from testlib import RandomDAGContext, make_random_dag + import pickle + from pytools import UniqueNameGenerator + axis_len = 5 + + for i in range(50): + print(i) # progress indicator + + seed = 120 + i + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + dag = pt.make_dict_of_named_arrays({"out": make_random_dag(rdagc_pt)}) + + # {{{ convert data-wrappers to placeholders + + vng = UniqueNameGenerator() + + def make_dws_placeholder(expr): + if isinstance(expr, pt.DataWrapper): + return pt.make_placeholder(vng("_pt_ph"), # noqa: B023 + expr.shape, expr.dtype) + else: + return expr + + dag = pt.transform.map_and_copy(dag, make_dws_placeholder) + + my_mapper = ExpansionMapper(name_to_studies) + new_dag = my_mapper(dag) + + assert new_dag == dag + + # }}} # }}} From cf7634e6e55f3186e6c2a50af214e571a0f19922 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 23 Jul 2024 14:26:44 -0500 Subject: [PATCH 05/27] Save of data before branching. 
--- pytato/transform/parameter_study.py | 22 ++++++++++++++++++++++ test/test_pytato.py | 18 +++++++++++------- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 582988d2d..d66087f8a 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -52,6 +52,7 @@ AxisPermutation, Concatenate, Einsum, + EinsumElementwiseAxis, IndexBase, IndexLambda, NormalizedSlice, @@ -334,6 +335,27 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: def map_einsum(self, expr: Einsum) -> Array: + new_predecessors = tuple(self.rec(arg) for arg in expr.args) + _, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessor(expr, new_predecessors) # noqa + + access_descriptors = () + for ival, array in enumerate(new_predecessors): + one_descr = expr.access_descriptors[ival] + if arrays_to_study_num_present: + for ind in arrays_to_study_num_present[array]: + one_descr = (*one_descr, + # Adding in new element axes to the end of the arrays. + EinsumElementwiseAxis(dim=len(expr.shape) + ind)) + access_descriptors = (*access_descriptors, one_descr) + out = Einsum(access_descriptors, + new_predecessors, + axes=expr.axes + new_axes, + redn_axis_to_redn_descr = expr.redn_axis_to_redn_descr, + index_to_access_descr = expr.index_to_access_descr, + tags = expr.tags, + non_equality_tags = expr.non_equality_tags) + breakpoint() + return super().map_einsum(expr) # }}} Operations with multiple predecessors. 
diff --git a/test/test_pytato.py b/test/test_pytato.py index ca4f41609..346097be8 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1084,6 +1084,7 @@ def test_expansion_mapper_basic_index(): assert new_expr.shape == (10,) assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) + def test_expansion_mapper_advanced_index_contiguous_axes(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1092,7 +1093,7 @@ def test_expansion_mapper_advanced_index_contiguous_axes(): name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int)[pt.arange(10, dtype=int)] - assert expr.shape == (10,5) + assert expr.shape == (10, 5) my_mapper = ExpansionMapper(name_to_studies) new_expr = my_mapper(expr) @@ -1102,14 +1103,15 @@ def test_expansion_mapper_advanced_index_contiguous_axes(): assert isinstance(new_expr, pt.AdvancedIndexInContiguousAxes) assert isinstance(expr, type(new_expr)) + def test_expansion_mapper_advanced_index_non_contiguous_axes(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper name = "my_array" my_study = ParameterStudyAxisTag(0, 10) name_to_studies = {name: frozenset((my_study,))} - ind0 = pt.arange(10, dtype=int).reshape(10,1) - ind1 = pt.arange(2, dtype=int).reshape(1,2) + ind0 = pt.arange(10, dtype=int).reshape(10, 1) + ind1 = pt.arange(2, dtype=int).reshape(1, 2) expr = pt.make_placeholder(name, (15, 1000, 5), dtype=int)[ind0, :, ind1] assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) @@ -1123,6 +1125,7 @@ def test_expansion_mapper_advanced_index_non_contiguous_axes(): assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) assert isinstance(expr, type(new_expr)) + def test_expansion_mapper_index_lambda(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1183,6 +1186,7 @@ def test_expansion_mapper_axis_permutation(): else: assert not tags + def test_expansion_mapper_reshape(): 
from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1252,6 +1256,7 @@ class Study1(ParameterStudyAxisTag): assert not new_expr.axes[3].tags_of_type(Study2) assert not new_expr.axes[4].tags_of_type(Study1) + def test_expansion_mapper_concatenate(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1286,6 +1291,7 @@ class Study1(ParameterStudyAxisTag): assert not new_expr.axes[2].tags_of_type(Study2) assert not new_expr.axes[3].tags_of_type(Study1) + def test_expansion_mapper_einsum_matmul(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1299,12 +1305,12 @@ class Study1(ParameterStudyAxisTag): arr2 = "foo" study2 = Study2(0, 1000) name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} - + # Matmul gets expanded correctly. a = pt.make_placeholder(name, (47, 42), dtype=int) b = pt.make_placeholder(arr2, (42, 5), dtype=int) - c = pt.matmul(a,b) + c = pt.matmul(a, b) assert isinstance(c, pt.Einsum) assert c.shape == (47, 5) @@ -1314,7 +1320,6 @@ class Study1(ParameterStudyAxisTag): assert updated_c.shape == (47, 5, 10, 1000) - def test_expansion_mapper_does_nothing_if_tags_not_there(): from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper @@ -1329,7 +1334,6 @@ class Study1(ParameterStudyAxisTag): study2 = Study2(0, 1000) name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} from testlib import RandomDAGContext, make_random_dag - import pickle from pytools import UniqueNameGenerator axis_len = 5 From 134f372b0b33d7802af063edac7783c0178550e0 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 00:48:29 -0500 Subject: [PATCH 06/27] Each parameter study only needs to know how many items are being tested in it. Not the axis as well. 
--- pytato/transform/parameter_study.py | 76 ++++++++++++----------------- 1 file changed, 30 insertions(+), 46 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index d66087f8a..8235f8f81 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -74,15 +74,14 @@ @dataclass(frozen=True) class ParameterStudyAxisTag(UniqueTag): """ - A tag for acting on axes of arrays. - To enable multiple parameter studies on the same variable name - specify a different axis number and potentially a different size. - - Currently does not allow multiple variables of different names to be in - the same parameter study. + A tag to indicate that the axis is being used + for independent trials like in a parameter study. + If you want to vary multiple input variables in the + same study then you need to have the same type of + class: 'ParameterStudyAxisTag'. """ - axis_num: int axis_size: int + assert axis_size > 0 StudiesT = Tuple[ParameterStudyAxisTag, ...] @@ -96,12 +95,11 @@ def __init__(self, placeholder_name_to_parameter_studies: Mapping[str, StudiesT] super().__init__() self.placeholder_name_to_parameter_studies = placeholder_name_to_parameter_studies # noqa - def _shapes_and_axes_from_predecessor(self, curr_expr: Array, + def _shapes_and_axes_from_predecessors(self, curr_expr: Array, mapped_preds: ArraysT) -> \ Tuple[KnownShapeType, AxesT, Dict[Array, Tuple[int, ...]]]: - # Initialize with something for the typing. assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in curr_expr.axes) @@ -118,6 +116,7 @@ def _shapes_and_axes_from_predecessor(self, curr_expr: Array, for axis in arr.axes: tags = axis.tags_of_type(ParameterStudyAxisTag) if tags: + assert len(tags) == 1 # only one study per axis. 
active_studies = active_studies.union(tags) if tags in study_to_arrays.keys(): study_to_arrays[tags] = (*study_to_arrays[tags], arr) @@ -136,7 +135,7 @@ def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, -> Tuple[KnownShapeType, AxesT, Dict[Array, Tuple[int, ...]]]: - # This is where we specify the canonical ordering. + # This is where we specify the canonical ordering of the studies. array_to_canonical_ordered_studies: Dict[Array, Tuple[int, ...]] = {} studies_axes = new_axes @@ -178,7 +177,7 @@ def map_placeholder(self, expr: Placeholder) -> Array: def map_roll(self, expr: Roll) -> Array: new_predecessor = self.rec(expr.array) - _, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + _, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, (new_predecessor,)) return Roll(array=new_predecessor, @@ -190,7 +189,7 @@ def map_roll(self, expr: Roll) -> Array: def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_predecessor = self.rec(expr.array) - postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, (new_predecessor,)) # Include the axes we are adding to the system. axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) @@ -204,7 +203,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: def _map_index_base(self, expr: IndexBase) -> Array: new_predecessor = self.rec(expr.array) - postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, (new_predecessor,)) # Update the indicies. 
new_indices = expr.indices @@ -219,7 +218,7 @@ def _map_index_base(self, expr: IndexBase) -> Array: def map_reshape(self, expr: Reshape) -> Array: new_predecessor = self.rec(expr.array) - postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessor(expr, + postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, (new_predecessor,)) return Reshape(new_predecessor, newshape=self.rec_idx_or_size_tuple(expr.newshape + @@ -251,9 +250,16 @@ def map_concatenate(self, expr: Concatenate) -> Array: def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT, AxesT]: + """ + This method will convert predecessors who were originally the same + shape in a single instance program to the same shape in this multiple + instance program. + """ + assert isinstance(expr, [Stack, Concatenate]) + new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) - studies_shape, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessor(expr, new_predecessors) # noqa + studies_shape, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa # This is going to be expensive. @@ -278,6 +284,8 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Arrays axes=new_axes[:ind], tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) + + assert tmp.shape[-(1 + len(studies_shape)):] == studies_shape corrected_new_arrays = (*corrected_new_arrays, tmp) return corrected_new_arrays, new_axes @@ -290,35 +298,12 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # Determine the new parameter studies that are being conducted. 
from pytools import unique + postpend_shape, new_axes, array_to_studies = self._shapes_and_axes_from_predecessors(expr, + (new_bindings,)) + + varname_to_studies = { array.name: studies for array, + studies in array_to_studies.items()} - all_axis_tags: StudiesT = () - varname_to_studies: Dict[str, Dict[UniqueTag, bool]] = {} - for name, bnd in sorted(new_bindings.items()): - axis_tags_for_bnd: Set[Tag] = set() - varname_to_studies[name] = {} - for i in range(len(bnd.axes)): - axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag)) # noqa - for tag in axis_tags_for_bnd: - if isinstance(tag, ParameterStudyAxisTag): - # Defense - varname_to_studies[name][tag] = True - all_axis_tags = *all_axis_tags, tag, - - cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags)) - study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {} - - new_shape = expr.shape - new_axes = expr.axes - - for study in cur_studies: - if isinstance(study, ParameterStudyAxisTag): - # Just defensive programming - # The active studies are added to the end of the bindings. - study_to_axis_number[study] = len(new_shape) - new_shape = (*new_shape, study.axis_size,) - new_axes = (*new_axes, Axis(tags=frozenset((study,))),) - # This assumes that the axis only has 1 tag, - # because there should be no dependence across instances. # Now we need to update the expressions. 
scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies, @@ -336,7 +321,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: def map_einsum(self, expr: Einsum) -> Array: new_predecessors = tuple(self.rec(arg) for arg in expr.args) - _, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessor(expr, new_predecessors) # noqa + _, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa access_descriptors = () for ival, array in enumerate(new_predecessors): @@ -364,8 +349,7 @@ def map_einsum(self, expr: Einsum) -> Array: class ParamAxisExpander(IdentityMapper): def map_subscript(self, expr: prim.Subscript, - varname_to_studies: Mapping[str, - Mapping[ParameterStudyAxisTag, bool]], + varname_to_studies: Mapping[str,StudiesT], study_to_axis_number: Mapping[ParameterStudyAxisTag, int]) -> \ prim.Subscript: # We know that we are not changing the variable that we are indexing into. From a0f608b6455cb2e0f9c3b41414a3e8cd8387a5ff Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 15:29:52 -0500 Subject: [PATCH 07/27] Update the scalar expression mapper to use the appropriate indicies on the variable even if it is not initially subscripted. --- pytato/transform/parameter_study.py | 128 +++++++++++++++------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 8235f8f81..bd7dbd95a 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -1,5 +1,6 @@ from __future__ import annotations + """ .. currentmodule:: pytato.transform @@ -32,19 +33,17 @@ THE SOFTWARE. 
""" -from immutabledict import immutabledict from dataclasses import dataclass from typing import ( - Dict, - FrozenSet, Iterable, Mapping, - Sequence, - Set, - Tuple, - Union, ) +from immutabledict import immutabledict + +import pymbolic.primitives as prim +from pytools.tag import UniqueTag + from pytato.array import ( Array, AxesT, @@ -61,13 +60,7 @@ Roll, Stack, ) - from pytato.scalar_expr import IdentityMapper, IntegralT - -import pymbolic.primitives as prim - -from pytools.tag import UniqueTag, Tag - from pytato.transform import CopyMapper @@ -81,12 +74,11 @@ class ParameterStudyAxisTag(UniqueTag): class: 'ParameterStudyAxisTag'. """ axis_size: int - assert axis_size > 0 -StudiesT = Tuple[ParameterStudyAxisTag, ...] -ArraysT = Tuple[Array, ...] -KnownShapeType = Tuple[IntegralT, ...] +StudiesT = tuple[ParameterStudyAxisTag, ...] +ArraysT = tuple[Array, ...] +KnownShapeType = tuple[IntegralT, ...] class ExpansionMapper(CopyMapper): @@ -97,9 +89,9 @@ def __init__(self, placeholder_name_to_parameter_studies: Mapping[str, StudiesT] def _shapes_and_axes_from_predecessors(self, curr_expr: Array, mapped_preds: ArraysT) -> \ - Tuple[KnownShapeType, + tuple[KnownShapeType, AxesT, - Dict[Array, Tuple[int, ...]]]: + dict[Array, tuple[int, ...]]]: assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in curr_expr.axes) @@ -108,15 +100,15 @@ def _shapes_and_axes_from_predecessors(self, curr_expr: Array, new_shape: KnownShapeType = () studies_axes: AxesT = () - study_to_arrays: Dict[FrozenSet[ParameterStudyAxisTag], ArraysT] = {} + study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT] = {} - active_studies: Set[ParameterStudyAxisTag] = set() + active_studies: set[ParameterStudyAxisTag] = set() for arr in mapped_preds: for axis in arr.axes: tags = axis.tags_of_type(ParameterStudyAxisTag) if tags: - assert len(tags) == 1 # only one study per axis. + assert len(tags) == 1 # only one study per axis. 
active_studies = active_studies.union(tags) if tags in study_to_arrays.keys(): study_to_arrays[tags] = (*study_to_arrays[tags], arr) @@ -131,20 +123,19 @@ def _shapes_and_axes_from_predecessors(self, curr_expr: Array, def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, studies: Iterable[ParameterStudyAxisTag], new_shape: KnownShapeType, new_axes: AxesT, - study_to_arrays: Dict[FrozenSet[ParameterStudyAxisTag], ArraysT]) \ - -> Tuple[KnownShapeType, AxesT, Dict[Array, - Tuple[int, ...]]]: + study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT]) \ + -> tuple[KnownShapeType, AxesT, dict[Array, + tuple[int, ...]]]: # This is where we specify the canonical ordering of the studies. - array_to_canonical_ordered_studies: Dict[Array, Tuple[int, ...]] = {} + array_to_canonical_ordered_studies: dict[Array, tuple[int, ...]] = {} studies_axes = new_axes for ind, study in enumerate(sorted(studies, key=lambda x: str(x.__class__))): new_shape = (*new_shape, study.axis_size) studies_axes = (*studies_axes, Axis(tags=frozenset((study,)))) - print(study_to_arrays) if study_to_arrays: for arr in study_to_arrays[frozenset((study,))]: if arr in array_to_canonical_ordered_studies.keys(): @@ -247,15 +238,14 @@ def map_concatenate(self, expr: Concatenate) -> Array: tags=expr.tags, non_equality_tags=expr.non_equality_tags) - def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT, - AxesT]: + def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, + AxesT]: """ This method will convert predecessors who were originally the same shape in a single instance program to the same shape in this multiple instance program. 
""" - assert isinstance(expr, [Stack, Concatenate]) new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) @@ -285,36 +275,38 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Arrays tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) - assert tmp.shape[-(1 + len(studies_shape)):] == studies_shape + assert tmp.shape[-(len(studies_shape)):] == studies_shape corrected_new_arrays = (*corrected_new_arrays, tmp) return corrected_new_arrays, new_axes def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. - new_bindings: Dict[str, Array] = {name: self.rec(bnd) + new_bindings: dict[str, Array] = {name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} + new_arrays = list(new_bindings.values()) - # Determine the new parameter studies that are being conducted. - from pytools import unique - postpend_shape, new_axes, array_to_studies = self._shapes_and_axes_from_predecessors(expr, - (new_bindings,)) + array_to_bnd_name: dict[Array, str] = {bnd: name for name, bnd + in sorted(new_bindings.items())} - varname_to_studies = { array.name: studies for array, - studies in array_to_studies.items()} + # Determine the new parameter studies that are being conducted. + postpend_shape, new_axes, array_to_studies_num = self._shapes_and_axes_from_predecessors(expr, # noqa + new_arrays) + varname_to_studies_nums = {array_to_bnd_name[array]: studies for array, + studies in array_to_studies_num.items()} # Now we need to update the expressions. 
- scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies, - study_to_axis_number) + scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies_nums, + len(expr.shape)) return IndexLambda(expr=scalar_expr, bindings=immutabledict(new_bindings), - shape=new_shape, + shape=(*expr.shape, *postpend_shape,), var_to_reduction_descr=expr.var_to_reduction_descr, dtype=expr.dtype, - axes=new_axes, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -322,7 +314,7 @@ def map_einsum(self, expr: Einsum) -> Array: new_predecessors = tuple(self.rec(arg) for arg in expr.args) _, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa - + access_descriptors = () for ival, array in enumerate(new_predecessors): one_descr = expr.access_descriptors[ival] @@ -335,13 +327,11 @@ def map_einsum(self, expr: Einsum) -> Array: out = Einsum(access_descriptors, new_predecessors, axes=expr.axes + new_axes, - redn_axis_to_redn_descr = expr.redn_axis_to_redn_descr, - index_to_access_descr = expr.index_to_access_descr, - tags = expr.tags, - non_equality_tags = expr.non_equality_tags) - breakpoint() + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - return super().map_einsum(expr) + return out # }}} Operations with multiple predecessors. @@ -349,25 +339,23 @@ def map_einsum(self, expr: Einsum) -> Array: class ParamAxisExpander(IdentityMapper): def map_subscript(self, expr: prim.Subscript, - varname_to_studies: Mapping[str,StudiesT], - study_to_axis_number: Mapping[ParameterStudyAxisTag, int]) -> \ - prim.Subscript: + varname_to_studies_num: Mapping[str, tuple[int, ...]], + num_original_inds: int) -> prim.Subscript: # We know that we are not changing the variable that we are indexing into. # This is stored in the aggregate member of the class Subscript. 
# We only need to modify the indexing which is stored in the index member. name = expr.aggregate.name - if name in varname_to_studies.keys(): + if name in varname_to_studies_num.keys(): # These are the single instance information. - index = self.rec(expr.index, varname_to_studies, - study_to_axis_number) + index = self.rec(expr.index, varname_to_studies_num, + num_original_inds) - new_vars: Tuple[prim.Variable, ...] = () + new_vars: tuple[prim.Variable, ...] = () + my_studies: tuple[int, ...] = varname_to_studies_num[expr.name] - for key, num in sorted(study_to_axis_number.items(), - key=lambda item: item[1]): - if key in varname_to_studies[name]: - new_vars = *new_vars, prim.Variable(f"_{num}"), + for num in my_studies: + new_vars = *new_vars, prim.Variable(f"_{num_original_inds + num}"), if isinstance(index, tuple): index = index + new_vars @@ -375,3 +363,23 @@ def map_subscript(self, expr: prim.Subscript, index = tuple(index) + new_vars return type(expr)(aggregate=expr.aggregate, index=index) return expr + + def map_variable(self, expr: prim.Variable, + varname_to_studies: Mapping[str, tuple[int, ...]], + num_original_inds: int) -> prim.Expression: + # We know that a variable is a leaf node. So we only need to update it + # if the variable is part of a study. + + breakpoint() + if expr.name in varname_to_studies.keys(): + # These are the single instance information. + # In the multiple instance we will need to index into the variable. + + new_vars: tuple[prim.Variable, ...] = () + my_studies: tuple[int, ...] = varname_to_studies[expr.name] + + for num in my_studies: + new_vars = *new_vars, prim.Variable(f"_{num_original_inds + num}"), + + return prim.Subscript(aggregate=expr, index=new_vars) + return expr From 9e7830bef04b37f14a2b29cef974b486905c3e45 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 15:38:41 -0500 Subject: [PATCH 08/27] Corrected mypy issues. 
--- pytato/transform/parameter_study.py | 6 +- test/test_pytato.py | 145 +++++++++++++--------------- 2 files changed, 72 insertions(+), 79 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index bd7dbd95a..bf4654851 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -51,6 +51,7 @@ AxisPermutation, Concatenate, Einsum, + EinsumAxisDescriptor, EinsumElementwiseAxis, IndexBase, IndexLambda, @@ -285,7 +286,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: new_bindings: dict[str, Array] = {name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} - new_arrays = list(new_bindings.values()) + new_arrays = (*new_bindings.values(),) array_to_bnd_name: dict[Array, str] = {bnd: name for name, bnd in sorted(new_bindings.items())} @@ -315,7 +316,7 @@ def map_einsum(self, expr: Einsum) -> Array: new_predecessors = tuple(self.rec(arg) for arg in expr.args) _, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa - access_descriptors = () + access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...] = () for ival, array in enumerate(new_predecessors): one_descr = expr.access_descriptors[ival] if arrays_to_study_num_present: @@ -370,7 +371,6 @@ def map_variable(self, expr: prim.Variable, # We know that a variable is a leaf node. So we only need to update it # if the variable is part of a study. - breakpoint() if expr.name in varname_to_studies.keys(): # These are the single instance information. # In the multiple instance we will need to index into the variable. 
diff --git a/test/test_pytato.py b/test/test_pytato.py index b04373aa8..bef2f23b1 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -28,6 +28,9 @@ """ import sys +from typing import ( + Mapping, +) import attrs import numpy as np @@ -40,6 +43,7 @@ import pytato as pt from pytato.array import _SuppliedAxesAndTagsMixin +from pytato.transform.parameter_study import ParameterStudyAxisTag def test_matmul_input_validation(): @@ -1064,10 +1068,10 @@ def test_lower_to_index_lambda(): # {{{ Expansion Mapper tests. def test_expansion_mapper_placeholder(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int) assert expr.shape == (15, 5) @@ -1084,10 +1088,10 @@ def test_expansion_mapper_placeholder(): def test_expansion_mapper_basic_index(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] @@ -1100,10 +1104,10 @@ def test_expansion_mapper_basic_index(): def test_expansion_mapper_advanced_index_contiguous_axes(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int)[pt.arange(10, dtype=int)] @@ -1119,10 +1123,10 @@ 
def test_expansion_mapper_advanced_index_contiguous_axes(): def test_expansion_mapper_advanced_index_non_contiguous_axes(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} ind0 = pt.arange(10, dtype=int).reshape(10, 1) ind1 = pt.arange(2, dtype=int).reshape(1, 2) @@ -1141,10 +1145,10 @@ def test_expansion_mapper_advanced_index_non_contiguous_axes(): def test_expansion_mapper_index_lambda(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) @@ -1157,14 +1161,17 @@ def test_expansion_mapper_index_lambda(): scalar_expr = new_expr.expr + assert len(scalar_expr.children) == len(expr.expr.children) + assert scalar_expr != expr.expr + # We modified it so that we have the new axis. 
assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) def test_expansion_mapper_roll(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) expr = pt.roll(expr, axis=0, shift=22) @@ -1180,10 +1187,10 @@ def test_expansion_mapper_roll(): def test_expansion_mapper_axis_permutation(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag name = "my_array" - my_study = ParameterStudyAxisTag(0, 10) + my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) assert expr.shape == (5, 15) @@ -1202,20 +1209,13 @@ def test_expansion_mapper_axis_permutation(): def test_expansion_mapper_reshape(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper - class Study2(ParameterStudyAxisTag): - pass - - class Study1(ParameterStudyAxisTag): - pass - name = "my_array" - study1 = Study1(0, 10) - arr2 = "foo" - study2 = Study2(0, 1000) - name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} - expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(arr2, (15, 5), dtype=int)) + name_to_studies, studies, names = _set_up_expansion_mapper_tests() + expr = pt.transpose(pt.make_placeholder(names[0], + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(names[1], + (15, 5), dtype=int)) out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) assert 
out_expr.shape == (10, 15) @@ -1232,25 +1232,19 @@ class Study1(ParameterStudyAxisTag): else: assert not tags - assert not new_expr.axes[2].tags_of_type(Study2) - assert not new_expr.axes[3].tags_of_type(Study1) + assert not new_expr.axes[2].tags_of_type(studies[1]) + assert not new_expr.axes[3].tags_of_type(studies[0]) def test_expansion_mapper_stack(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper - class Study2(ParameterStudyAxisTag): - pass + name_to_studies, studies, names = _set_up_expansion_mapper_tests() - class Study1(ParameterStudyAxisTag): - pass - name = "my_array" - study1 = Study1(0, 10) - arr2 = "foo" - study2 = Study2(0, 1000) - name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} - expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(arr2, (15, 5), dtype=int)) + expr = pt.transpose(pt.make_placeholder(names[0], + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(names[1], + (15, 5), dtype=int)) out_expr = pt.stack([expr, expr2], axis=0) assert out_expr.shape == (2, 5, 15) @@ -1267,25 +1261,38 @@ class Study1(ParameterStudyAxisTag): else: assert not tags - assert not new_expr.axes[3].tags_of_type(Study2) - assert not new_expr.axes[4].tags_of_type(Study1) + assert not new_expr.axes[3].tags_of_type(studies[1]) + assert not new_expr.axes[4].tags_of_type(studies[0]) -def test_expansion_mapper_concatenate(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper +def _set_up_expansion_mapper_tests() -> tuple[Mapping[str, + frozenset[ParameterStudyAxisTag]], + tuple[ParameterStudyAxisTag, ...], + tuple[str, ...]]: class Study2(ParameterStudyAxisTag): pass class Study1(ParameterStudyAxisTag): pass - name = "my_array" - study1 = Study1(0, 10) - arr2 = "foo" - study2 = Study2(0, 1000) + name = "a" + study1 = Study1(10) + arr2 = "b" + 
study2 = Study2(1000) name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} - expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(arr2, (15, 5), dtype=int)) + + return name_to_studies, (Study1, Study2,), (name, arr2,) + + +def test_expansion_mapper_concatenate(): + from pytato.transform.parameter_study import ExpansionMapper + + name_to_studies, studies, names = _set_up_expansion_mapper_tests() + + expr = pt.transpose(pt.make_placeholder(names[0], + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(names[1], + (15, 5), dtype=int)) out_expr = pt.concatenate([expr, expr2], axis=0) assert out_expr.shape == (10, 15) @@ -1302,27 +1309,20 @@ class Study1(ParameterStudyAxisTag): else: assert not tags - assert not new_expr.axes[2].tags_of_type(Study2) - assert not new_expr.axes[3].tags_of_type(Study1) + assert not new_expr.axes[2].tags_of_type(studies[1]) + assert not new_expr.axes[3].tags_of_type(studies[0]) def test_expansion_mapper_einsum_matmul(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper - class Study2(ParameterStudyAxisTag): - pass - - class Study1(ParameterStudyAxisTag): - pass - name = "my_array" - study1 = Study1(0, 10) - arr2 = "foo" - study2 = Study2(0, 1000) - name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + name_to_studies, _, names = _set_up_expansion_mapper_tests() # Matmul gets expanded correctly. 
- a = pt.make_placeholder(name, (47, 42), dtype=int) - b = pt.make_placeholder(arr2, (42, 5), dtype=int) + a = pt.make_placeholder(names[0], + (47, 42), dtype=int) + b = pt.make_placeholder(names[1], + (42, 5), dtype=int) c = pt.matmul(a, b) assert isinstance(c, pt.Einsum) @@ -1335,19 +1335,12 @@ class Study1(ParameterStudyAxisTag): def test_expansion_mapper_does_nothing_if_tags_not_there(): - from pytato.transform.parameter_study import ParameterStudyAxisTag, ExpansionMapper + from pytato.transform.parameter_study import ExpansionMapper - class Study2(ParameterStudyAxisTag): - pass + name_to_studies, _, _ = _set_up_expansion_mapper_tests() - class Study1(ParameterStudyAxisTag): - pass - name = "my_array" - study1 = Study1(0, 10) - arr2 = "foo" - study2 = Study2(0, 1000) - name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} from testlib import RandomDAGContext, make_random_dag + from pytools import UniqueNameGenerator axis_len = 5 From c86930c5b876e3fb95a9323eac7e8ba7c1119976 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 16:19:26 -0500 Subject: [PATCH 09/27] Add a not implemented error for the function objects and the distributed programming constructs. 
--- pytato/transform/parameter_study.py | 44 ++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index bf4654851..06734f41a 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -45,6 +45,7 @@ from pytools.tag import UniqueTag from pytato.array import ( + AbstractResultWithNamedArrays, Array, AxesT, Axis, @@ -61,6 +62,15 @@ Roll, Stack, ) +from pytato.distributed.nodes import ( + DistributedRecv, + DistributedSendRefHolder, +) +from pytato.function import ( + Call, + FunctionDefinition, + NamedCallResult, +) from pytato.scalar_expr import IdentityMapper, IntegralT from pytato.transform import CopyMapper @@ -197,7 +207,7 @@ def _map_index_base(self, expr: IndexBase) -> Array: new_predecessor = self.rec(expr.array) postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, (new_predecessor,)) - # Update the indicies. + # Update the indices. new_indices = expr.indices for shape in postpend_shape: new_indices = (*new_indices, NormalizedSlice(0, shape, 1)) @@ -325,17 +335,43 @@ def map_einsum(self, expr: Einsum) -> Array: # Adding in new element axes to the end of the arrays. EinsumElementwiseAxis(dim=len(expr.shape) + ind)) access_descriptors = (*access_descriptors, one_descr) - out = Einsum(access_descriptors, + + return Einsum(access_descriptors, new_predecessors, axes=expr.axes + new_axes, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) - return out - # }}} Operations with multiple predecessors. 
+ # {{{ Function definitions + def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinition: + raise NotImplementedError(" Expanding functions is not yet supported.") + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + raise NotImplementedError(" Expanding functions is not yet suppported.") + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + raise NotImplementedError(" Expanding functions is not yet suppported.") + + # }}} + + # {{{ Distributed Programming Constructs + def map_distributed_recv(self, expr: DistributedRecv) -> DistributedRecv: + # This data will depend solely on the rank sending it to you. + raise NotImplementedError(" Expanding distributed programming constructs is" + " not yet supported.") + + def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder) \ + -> Array: + # We are sending the data. This data may increase in size due to the + # parameter studies. + raise NotImplementedError(" Expanding distributed programming constructs is" + " not yet supported.") + + # }}} + class ParamAxisExpander(IdentityMapper): From 0c941bef53044b3f4c0aea01c0cb2d101370346c Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 30 Jul 2024 16:20:52 -0500 Subject: [PATCH 10/27] Fix a typo in a comment. 
--- pytato/transform/parameter_study.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 06734f41a..89a6f9c0f 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -350,10 +350,10 @@ def map_function_definition(self, expr: FunctionDefinition) -> FunctionDefinitio raise NotImplementedError(" Expanding functions is not yet supported.") def map_named_call_result(self, expr: NamedCallResult) -> Array: - raise NotImplementedError(" Expanding functions is not yet suppported.") + raise NotImplementedError(" Expanding functions is not yet supported.") def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - raise NotImplementedError(" Expanding functions is not yet suppported.") + raise NotImplementedError(" Expanding functions is not yet supported.") # }}} From e9040212a6c7d32748efa391d40e100c047c1e5a Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 1 Aug 2024 15:53:22 -0500 Subject: [PATCH 11/27] Add the fixes from testing with Mirgecom. 
--- pytato/transform/parameter_study.py | 57 +++++++++++++++++++---------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 89a6f9c0f..bfeaf37a0 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -42,6 +42,7 @@ from immutabledict import immutabledict import pymbolic.primitives as prim +from pytools import unique from pytools.tag import UniqueTag from pytato.array import ( @@ -126,10 +127,22 @@ def _shapes_and_axes_from_predecessors(self, curr_expr: Array, else: study_to_arrays[tags] = (arr,) - return self._studies_to_shape_and_axes_and_arrays_in_canonical_order(active_studies, # noqa + ps, na, arr_to_studies = self._studies_to_shape_and_axes_and_arrays_in_canonical_order(active_studies, # noqa new_shape, studies_axes, study_to_arrays) + # Add in the arrays that are not a part of a parameter study. + # This is done to avoid any KeyErrors later. + + for arr in unique(mapped_preds): # pytools unique maintains the order. + if arr not in arr_to_studies.keys(): + arr_to_studies[arr] = () + else: + assert len(arr_to_studies[arr]) > 0 + + assert len(arr_to_studies) == len(list(unique(mapped_preds))) + + return ps, na, arr_to_studies def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, studies: Iterable[ParameterStudyAxisTag], @@ -232,23 +245,6 @@ def map_reshape(self, expr: Reshape) -> Array: # {{{ Operations with multiple predecessors. 
- def map_stack(self, expr: Stack) -> Array: - new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - return Stack(arrays=new_arrays, - axis=expr.axis, - axes=expr.axes + new_axes_for_end, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_concatenate(self, expr: Concatenate) -> Array: - new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) - - return Concatenate(arrays=new_arrays, - axis=expr.axis, - axes=expr.axes + new_axes_for_end, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, AxesT]: @@ -286,11 +282,32 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, tags=tmp.tags, non_equality_tags=tmp.non_equality_tags) - assert tmp.shape[-(len(studies_shape)):] == studies_shape + if studies_shape: + assert tmp.shape[-(len(studies_shape)):] == studies_shape + else: + assert tmp.shape[-(len(studies_shape)):] == tmp.shape + corrected_new_arrays = (*corrected_new_arrays, tmp) return corrected_new_arrays, new_axes + def map_stack(self, expr: Stack) -> Array: + new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) + return Stack(arrays=new_arrays, + axis=expr.axis, + axes=expr.axes + new_axes_for_end, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + def map_concatenate(self, expr: Concatenate) -> Array: + new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) + + return Concatenate(arrays=new_arrays, + axis=expr.axis, + axes=expr.axes + new_axes_for_end, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. new_bindings: dict[str, Array] = {name: self.rec(bnd) @@ -389,7 +406,7 @@ def map_subscript(self, expr: prim.Subscript, num_original_inds) new_vars: tuple[prim.Variable, ...] = () - my_studies: tuple[int, ...] = varname_to_studies_num[expr.name] + my_studies: tuple[int, ...] 
= varname_to_studies_num[name] for num in my_studies: new_vars = *new_vars, prim.Variable(f"_{num_original_inds + num}"), From ac811bf51e8661c73110ad19731cd2c9c83f03f7 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 6 Aug 2024 15:36:06 -0500 Subject: [PATCH 12/27] Update the expansion mapper. --- pytato/transform/parameter_study.py | 228 +++++++++++++++++----------- 1 file changed, 136 insertions(+), 92 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index bfeaf37a0..62b6a6f97 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -42,7 +42,6 @@ from immutabledict import immutabledict import pymbolic.primitives as prim -from pytools import unique from pytools.tag import UniqueTag from pytato.array import ( @@ -103,14 +102,12 @@ def _shapes_and_axes_from_predecessors(self, curr_expr: Array, mapped_preds: ArraysT) -> \ tuple[KnownShapeType, AxesT, - dict[Array, tuple[int, ...]]]: + dict[int, tuple[int, ...]]]: assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in curr_expr.axes) # We are post pending the axes we are using for parameter studies. - new_shape: KnownShapeType = () - studies_axes: AxesT = () study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT] = {} @@ -127,64 +124,71 @@ def _shapes_and_axes_from_predecessors(self, curr_expr: Array, else: study_to_arrays[tags] = (arr,) - ps, na, arr_to_studies = self._studies_to_shape_and_axes_and_arrays_in_canonical_order(active_studies, # noqa - new_shape, - studies_axes, - study_to_arrays) + ps, na, arr_num_to_study_nums = self._studies_to_shape_and_axes_and_arrays_in_canonical_order(active_studies, # noqa + study_to_arrays, mapped_preds) + # Add in the arrays that are not a part of a parameter study. # This is done to avoid any KeyErrors later. - for arr in unique(mapped_preds): # pytools unique maintains the order. 
- if arr not in arr_to_studies.keys(): - arr_to_studies[arr] = () + for arr_num in range(len(mapped_preds)): + if arr_num not in arr_num_to_study_nums.keys(): + arr_num_to_study_nums[arr_num] = () else: - assert len(arr_to_studies[arr]) > 0 + assert len(arr_num_to_study_nums[arr_num]) > 0 - assert len(arr_to_studies) == len(list(unique(mapped_preds))) + assert len(arr_num_to_study_nums) == len(mapped_preds) - return ps, na, arr_to_studies + for axis in na: + assert axis.tags_of_type(ParameterStudyAxisTag) + + return ps, na, arr_num_to_study_nums def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, studies: Iterable[ParameterStudyAxisTag], - new_shape: KnownShapeType, new_axes: AxesT, - study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT]) \ - -> tuple[KnownShapeType, AxesT, dict[Array, - tuple[int, ...]]]: + study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT], + mapped_preds: ArraysT) -> tuple[KnownShapeType, AxesT, + dict[int, tuple[int, ...]]]: # This is where we specify the canonical ordering of the studies. 
- - array_to_canonical_ordered_studies: dict[Array, tuple[int, ...]] = {} - studies_axes = new_axes + array_num_to_study_nums: dict[int, tuple[int, ...]] = {} + new_shape: KnownShapeType = () + studies_axes: AxesT = () for ind, study in enumerate(sorted(studies, key=lambda x: str(x.__class__))): new_shape = (*new_shape, study.axis_size) studies_axes = (*studies_axes, Axis(tags=frozenset((study,)))) - if study_to_arrays: - for arr in study_to_arrays[frozenset((study,))]: - if arr in array_to_canonical_ordered_studies.keys(): - array_to_canonical_ordered_studies[arr] = (*array_to_canonical_ordered_studies[arr], ind) # noqa + for arr_num, arr in enumerate(mapped_preds): + if arr in study_to_arrays[frozenset((study,))]: + if arr_num in array_num_to_study_nums.keys(): + array_num_to_study_nums[arr_num] = (*array_num_to_study_nums[arr_num], ind) # noqa else: - array_to_canonical_ordered_studies[arr] = (ind,) + array_num_to_study_nums[arr_num] = (ind,) + + assert len(new_shape) == len(studies) + assert len(studies_axes) == len(studies) - return new_shape, studies_axes, array_to_canonical_ordered_studies + return new_shape, studies_axes, array_num_to_study_nums def map_placeholder(self, expr: Placeholder) -> Array: # This is where we could introduce extra axes. if expr.name in self.placeholder_name_to_parameter_studies.keys(): new_axes = expr.axes studies = self.placeholder_name_to_parameter_studies[expr.name] + + new_shape: KnownShapeType = () + new_axes: AxesT = () + + # We know that there are no predecessors and we know the studies to add. + # We need to get them in the right order. 
new_shape, new_axes, _ = self._studies_to_shape_and_axes_and_arrays_in_canonical_order( # noqa - studies, - (), - expr.axes, - {}) + studies, {}, ()) return Placeholder(name=expr.name, shape=self.rec_idx_or_size_tuple((*expr.shape, *new_shape,)), dtype=expr.dtype, - axes=new_axes, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -198,7 +202,7 @@ def map_roll(self, expr: Roll) -> Array: return Roll(array=new_predecessor, shift=expr.shift, axis=expr.axis, - axes=expr.axes + new_axes, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -212,7 +216,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: return AxisPermutation(array=new_predecessor, axis_permutation=axis_permute, - axes=expr.axes + new_axes, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -227,7 +231,7 @@ def _map_index_base(self, expr: IndexBase) -> Array: return type(expr)(new_predecessor, indices=self.rec_idx_or_size_tuple(new_indices), - axes=expr.axes + new_axes, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -239,7 +243,7 @@ def map_reshape(self, expr: Reshape) -> Array: newshape=self.rec_idx_or_size_tuple(expr.newshape + postpend_shape), order=expr.order, - axes=expr.axes + new_axes, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -256,81 +260,93 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) - studies_shape, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa + studies_shape, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa # This is going to be expensive. + correct_shape_preds: ArraysT = () - # Now we need to update the expressions. 
- # Now that we have the appropriate shape, - # we need to update the input arrays to match. - - cp_map = CopyMapper() - corrected_new_arrays: ArraysT = () for iarr, array in enumerate(new_predecessors): - tmp = cp_map(array) # Get a copy of the array. - # We need to grow the array to the new size. - if arrays_to_study_num_present: - studies_present = arrays_to_study_num_present[array] - for ind, size in enumerate(studies_shape): - if ind not in studies_present: - build: ArraysT = tuple([cp_map(tmp) for _ in range(size)]) - - # Note we are stacking the arrays into the appropriate shape. - tmp = Stack(arrays=build, - axis=len(expr.arrays[iarr].axes) + ind, - axes=new_axes[:ind], - tags=tmp.tags, - non_equality_tags=tmp.non_equality_tags) - - if studies_shape: - assert tmp.shape[-(len(studies_shape)):] == studies_shape - else: - assert tmp.shape[-(len(studies_shape)):] == tmp.shape - - corrected_new_arrays = (*corrected_new_arrays, tmp) - - return corrected_new_arrays, new_axes + # Broadcast out to the right shape. 
+ num_single_inst_axes = len(expr.arrays[iarr].shape) + scale_expr = prim.Subscript(prim.Variable("_in0"), + index=tuple([prim.Variable(f"_{ind}") for + ind in range(num_single_inst_axes)])) + new_array = IndexLambda(expr=scale_expr, + bindings=immutabledict({"_in0": array}), + tags=array.tags, + non_equality_tags=array.non_equality_tags, + dtype=array.dtype, + var_to_reduction_descr=immutabledict({}), + shape=(*expr.arrays[iarr].shape, *studies_shape,), + axes=(*expr.arrays[iarr].axes, *new_axes,)) + + correct_shape_preds = (*correct_shape_preds, new_array,) + + for arr in correct_shape_preds: + assert arr.shape == correct_shape_preds[0].shape + + return correct_shape_preds, new_axes def map_stack(self, expr: Stack) -> Array: - new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) + new_arrays, new_axes = self._mult_pred_same_shape(expr) return Stack(arrays=new_arrays, axis=expr.axis, - axes=expr.axes + new_axes_for_end, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: - new_arrays, new_axes_for_end = self._mult_pred_same_shape(expr) + new_arrays, new_axes = self._mult_pred_same_shape(expr) return Concatenate(arrays=new_arrays, axis=expr.axis, - axes=expr.axes + new_axes_for_end, + axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_index_lambda(self, expr: IndexLambda) -> Array: # Update bindings first. - new_bindings: dict[str, Array] = {name: self.rec(bnd) + + new_binds: dict[str, Array] = {name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())} - new_arrays = (*new_bindings.values(),) + new_arrays = (*new_binds.values(),) - array_to_bnd_name: dict[Array, str] = {bnd: name for name, bnd - in sorted(new_bindings.items())} + # The arrays may be the same for a predecessors. + # However, the index will be unique. 
+ + array_num_to_bnd_name: dict[int, str] = {ind: name for ind, (name, _) + in enumerate(sorted(new_binds.items()))} # noqa # Determine the new parameter studies that are being conducted. - postpend_shape, new_axes, array_to_studies_num = self._shapes_and_axes_from_predecessors(expr, # noqa + postpend_shape, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, # noqa new_arrays) - varname_to_studies_nums = {array_to_bnd_name[array]: studies for array, - studies in array_to_studies_num.items()} + varname_to_studies_nums = {array_num_to_bnd_name[iarr]: studies for iarr, + studies in arr_num_to_study_nums.items()} + + for vn_key in varname_to_studies_nums.keys(): + assert vn_key in new_binds.keys() + + for vn_key in new_binds.keys(): + assert vn_key in varname_to_studies_nums.keys() # Now we need to update the expressions. scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies_nums, len(expr.shape)) + + + # Data dump the index lambda to a file which I can read. + with open("expansion_map_indexlambda.txt", "a+") as my_file: + my_file.write("\n") + my_file.write("\n") + my_file.write(str(scalar_expr)) + my_file.write("\n") + my_file.write(str({name: len(bnd.axes) for name, bnd in sorted(new_binds.items())})) + return IndexLambda(expr=scalar_expr, - bindings=immutabledict(new_bindings), + bindings=immutabledict(new_binds), shape=(*expr.shape, *postpend_shape,), var_to_reduction_descr=expr.var_to_reduction_descr, dtype=expr.dtype, @@ -341,21 +357,24 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: def map_einsum(self, expr: Einsum) -> Array: new_predecessors = tuple(self.rec(arg) for arg in expr.args) - _, new_axes, arrays_to_study_num_present = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa + _, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...] 
= () for ival, array in enumerate(new_predecessors): one_descr = expr.access_descriptors[ival] - if arrays_to_study_num_present: - for ind in arrays_to_study_num_present[array]: + if arr_num_to_study_nums: + for ind in arr_num_to_study_nums[ival]: one_descr = (*one_descr, # Adding in new element axes to the end of the arrays. EinsumElementwiseAxis(dim=len(expr.shape) + ind)) access_descriptors = (*access_descriptors, one_descr) + # One descriptor per axis. + assert len(one_descr) == len(array.shape) + return Einsum(access_descriptors, new_predecessors, - axes=expr.axes + new_axes, + axes=(*expr.axes, *new_axes,), redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -394,45 +413,70 @@ class ParamAxisExpander(IdentityMapper): def map_subscript(self, expr: prim.Subscript, varname_to_studies_num: Mapping[str, tuple[int, ...]], - num_original_inds: int) -> prim.Subscript: + num_orig_elem_inds: int) -> prim.Subscript: + """ + `arg' num_orig_elem_inds specifies the maximum number of indices + in a scalar expression that can be used to index into an array provided + that index is not used as part of a reduction. + """ # We know that we are not changing the variable that we are indexing into. # This is stored in the aggregate member of the class Subscript. # We only need to modify the indexing which is stored in the index member. + assert isinstance(expr.aggregate, prim.Variable) + name = expr.aggregate.name if name in varname_to_studies_num.keys(): # These are the single instance information. + index = self.rec(expr.index, varname_to_studies_num, - num_original_inds) + num_orig_elem_inds) new_vars: tuple[prim.Variable, ...] = () my_studies: tuple[int, ...] 
= varname_to_studies_num[name] for num in my_studies: - new_vars = *new_vars, prim.Variable(f"_{num_original_inds + num}"), + new_vars = *new_vars, prim.Variable(f"_{num_orig_elem_inds + num}"), if isinstance(index, tuple): index = index + new_vars else: index = tuple(index) + new_vars + return type(expr)(aggregate=expr.aggregate, index=index) - return expr + + return super().map_subscript(expr, varname_to_studies_num, num_orig_elem_inds) def map_variable(self, expr: prim.Variable, varname_to_studies: Mapping[str, tuple[int, ...]], - num_original_inds: int) -> prim.Expression: + num_orig_elem_inds: int) -> prim.Expression: # We know that a variable is a leaf node. So we only need to update it # if the variable is part of a study. if expr.name in varname_to_studies.keys(): - # These are the single instance information. - # In the multiple instance we will need to index into the variable. + # The variable may need to be updated. - new_vars: tuple[prim.Variable, ...] = () my_studies: tuple[int, ...] = varname_to_studies[expr.name] + if len(my_studies) == 0: + # No studies + return expr + + new_vars: tuple[prim.Variable, ...] = () + + assert my_studies + assert len(my_studies) > 0 + for num in my_studies: - new_vars = *new_vars, prim.Variable(f"_{num_original_inds + num}"), + new_vars = *new_vars, prim.Variable(f"_{num_orig_elem_inds + num}"), return prim.Subscript(aggregate=expr, index=new_vars) - return expr + + # Since the variable is not in a study we can just return the answer. + return super().map_variable(expr, varname_to_studies, num_orig_elem_inds) + + def map_substitution(self, expr: prim.Substitution, varname_to_studies, + num_orig_elem_inds): + + breakpoint() + raise NotImplementedError("Substitution needs to be expanded.") From 9dc03b109fb847d2edf7d850090154ab27c2a132 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 7 Aug 2024 16:29:25 -0500 Subject: [PATCH 13/27] When you broadcast to a new shape you need an index for each axis in the original array. 
--- pytato/transform/parameter_study.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 62b6a6f97..ed18ad222 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -262,16 +262,24 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, studies_shape, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa + if not arr_num_to_study_nums: + # We do not need to do anything as the expression we have is unmodified. + return new_predecessors, new_axes + # This is going to be expensive. correct_shape_preds: ArraysT = () for iarr, array in enumerate(new_predecessors): # Broadcast out to the right shape. num_single_inst_axes = len(expr.arrays[iarr].shape) - scale_expr = prim.Subscript(prim.Variable("_in0"), - index=tuple([prim.Variable(f"_{ind}") for - ind in range(num_single_inst_axes)])) - new_array = IndexLambda(expr=scale_expr, + index = tuple(prim.Variable(f"_{ind}") for + ind in range(num_single_inst_axes)) + # Add in the axes from the studies we have in the predecessor. 
+ for study_num in arr_num_to_study_nums[iarr]: + index = (*index, prim.Variable(f"_{num_single_inst_axes + study_num}")) + + new_array = IndexLambda(expr=prim.Subscript(prim.Variable("_in0"), + index=index), bindings=immutabledict({"_in0": array}), tags=array.tags, non_equality_tags=array.non_equality_tags, @@ -289,12 +297,18 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, def map_stack(self, expr: Stack) -> Array: new_arrays, new_axes = self._mult_pred_same_shape(expr) - return Stack(arrays=new_arrays, + breakpoint() + out = Stack(arrays=new_arrays, axis=expr.axis, axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) + assert out.ndim == len(out.shape) + assert len(out.shape) == len(out.arrays[0].shape) + 1 + + return out + def map_concatenate(self, expr: Concatenate) -> Array: new_arrays, new_axes = self._mult_pred_same_shape(expr) From a69da839eef23e1727d32022a5780d04cbec2128 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 7 Aug 2024 16:48:05 -0500 Subject: [PATCH 14/27] Correct typing information. 
--- pytato/transform/parameter_study.py | 31 +++++++---------------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index ed18ad222..df3507cc2 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -154,6 +154,7 @@ def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, new_shape: KnownShapeType = () studies_axes: AxesT = () + num_studies: int = 0 for ind, study in enumerate(sorted(studies, key=lambda x: str(x.__class__))): new_shape = (*new_shape, study.axis_size) @@ -164,21 +165,18 @@ def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, array_num_to_study_nums[arr_num] = (*array_num_to_study_nums[arr_num], ind) # noqa else: array_num_to_study_nums[arr_num] = (ind,) + num_studies += 1 - assert len(new_shape) == len(studies) - assert len(studies_axes) == len(studies) + assert len(new_shape) == num_studies + assert len(new_shape) == len(studies_axes) return new_shape, studies_axes, array_num_to_study_nums def map_placeholder(self, expr: Placeholder) -> Array: # This is where we could introduce extra axes. if expr.name in self.placeholder_name_to_parameter_studies.keys(): - new_axes = expr.axes studies = self.placeholder_name_to_parameter_studies[expr.name] - new_shape: KnownShapeType = () - new_axes: AxesT = () - # We know that there are no predecessors and we know the studies to add. # We need to get them in the right order. 
new_shape, new_axes, _ = self._studies_to_shape_and_axes_and_arrays_in_canonical_order( # noqa @@ -278,6 +276,8 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, for study_num in arr_num_to_study_nums[iarr]: index = (*index, prim.Variable(f"_{num_single_inst_axes + study_num}")) + assert len(index) == len(array.axes) + new_array = IndexLambda(expr=prim.Subscript(prim.Variable("_in0"), index=index), bindings=immutabledict({"_in0": array}), @@ -291,13 +291,12 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, correct_shape_preds = (*correct_shape_preds, new_array,) for arr in correct_shape_preds: - assert arr.shape == correct_shape_preds[0].shape + assert arr.shape == correct_shape_preds[0].shape return correct_shape_preds, new_axes def map_stack(self, expr: Stack) -> Array: new_arrays, new_axes = self._mult_pred_same_shape(expr) - breakpoint() out = Stack(arrays=new_arrays, axis=expr.axis, axes=(*expr.axes, *new_axes,), @@ -348,16 +347,6 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # Now we need to update the expressions. scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies_nums, len(expr.shape)) - - - # Data dump the index lambda to a file which I can read. - with open("expansion_map_indexlambda.txt", "a+") as my_file: - my_file.write("\n") - my_file.write("\n") - my_file.write(str(scalar_expr)) - my_file.write("\n") - my_file.write(str({name: len(bnd.axes) for name, bnd in sorted(new_binds.items())})) - return IndexLambda(expr=scalar_expr, bindings=immutabledict(new_binds), @@ -488,9 +477,3 @@ def map_variable(self, expr: prim.Variable, # Since the variable is not in a study we can just return the answer. 
return super().map_variable(expr, varname_to_studies, num_orig_elem_inds) - - def map_substitution(self, expr: prim.Substitution, varname_to_studies, - num_orig_elem_inds): - - breakpoint() - raise NotImplementedError("Substitution needs to be expanded.") From 524ff55d71a19dc1b9db5f7e0f276751ebf38441 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 7 Aug 2024 17:50:02 -0500 Subject: [PATCH 15/27] Address @inducer's comments. --- pytato/transform/parameter_study.py | 190 ++++++++++++++++------------ test/test_pytato.py | 48 +++---- 2 files changed, 131 insertions(+), 107 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index df3507cc2..c73f6b575 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -78,11 +78,11 @@ @dataclass(frozen=True) class ParameterStudyAxisTag(UniqueTag): """ - A tag to indicate that the axis is being used - for independent trials like in a parameter study. - If you want to vary multiple input variables in the - same study then you need to have the same type of - class: 'ParameterStudyAxisTag'. + A tag to indicate that the axis is being used + for independent trials like in a parameter study. + If you want to vary multiple input variables in the + same study then you need to have the same type of + class: 'ParameterStudyAxisTag'. """ axis_size: int @@ -92,7 +92,98 @@ class ParameterStudyAxisTag(UniqueTag): KnownShapeType = tuple[IntegralT, ...] -class ExpansionMapper(CopyMapper): +class ParamAxisExpander(IdentityMapper): + """ + The goal of this mapper is to convert a single instance scalar expression + into a single instruction multiple data scalar expression. We assume that any + array variables in the original scalar expression will be indexed completely. + Also, new axes for the studies are assumed to be on the end of + those array variables. 
+ """ + + def __init__(self, varname_to_studies_num: Mapping[str, tuple[int, ...]], + num_orig_elem_inds: int): + """ + `arg' varname_to_studies_num: is a mapping from the variable name used + in the scalar expression to the studies present in the multiple instance + expression. Note that the varnames must be for the array variables only. + + `arg' num_orig_elem_inds: is the number of element axes in the result of + the single instance expression. + """ + + super().__init__() + self.varname_to_studies_num = varname_to_studies_num + self.num_orig_elem_inds = num_orig_elem_inds + + def map_subscript(self, expr: prim.Subscript) -> prim.Subscript: + # We only need to modify the indexing which is stored in the index member. + assert isinstance(expr.aggregate, prim.Variable) + + name = expr.aggregate.name + if name in self.varname_to_studies_num.keys(): + # These are the single instance information. + + index = self.rec(expr.index) + + new_vars: tuple[prim.Variable, ...] = () + my_studies: tuple[int, ...] = self.varname_to_studies_num[name] + + for num in my_studies: + new_vars = (*new_vars, + prim.Variable(f"_{self.num_orig_elem_inds + num}"),) + + if isinstance(index, tuple): + index = index + new_vars + else: + index = tuple(index) + new_vars + + return type(expr)(aggregate=expr.aggregate, index=index) + + return super().map_subscript(expr) + + def map_variable(self, expr: prim.Variable) -> prim.Expression: + # We know that a variable is a leaf node. So we only need to update it + # if the variable is part of a study. + if expr.name in self.varname_to_studies.keys(): + # The variable may need to be updated. + + my_studies: tuple[int, ...] 
= self.varname_to_studies[expr.name] + + if len(my_studies) == 0: + # No studies + return expr + + assert my_studies + assert len(my_studies) > 0 + + new_vars = tuple([prim.Variable(f"_{self.num_orig_elem_inds + num}") for + num in my_studies]) + + return prim.Subscript(aggregate=expr, index=new_vars) + + # Since the variable is not in a study we can just return the answer. + return super().map_variable(expr) + + +class ParameterStudyVectorizer(CopyMapper): + """ + This mapper will expand a single instance DAG into a DAG for parameter studies. + It is assumed that the parameter studies cannot interact with each other. + Currently, this only supports DAGs which are made for a single processing unit. + That is we do not support distributed programming right now. + + Any new axes used for parameter studies will be added to the end of the arrays. + Note this will break broadcasting assumptions. Therefore, one needs to be careful + if only a portion of the program is expanded. This decision was made under the + assumption that the generated code would execute faster if the parameter study + axes were the ones with the shortest strides. + + If there are multiple distinct parameter studies then the DAG will be expanded + for the Cartesian product of the input parameter studies. A parameter study is + specified in an array by tagging the corresponding axis with a tag that is a + class: `ParameterStudyAxisTag' or a class which inherits from it. + """ def __init__(self, placeholder_name_to_parameter_studies: Mapping[str, StudiesT]): super().__init__() @@ -247,13 +338,13 @@ def map_reshape(self, expr: Reshape) -> Array: # {{{ Operations with multiple predecessors. 
- def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, - AxesT]: + def _broadcast_predecessors_to_same_shape(self, expr: Stack | Concatenate) \ + -> tuple[ArraysT, AxesT]: """ - This method will convert predecessors who were originally the same - shape in a single instance program to the same shape in this multiple - instance program. + This method will convert predecessors who were originally the same + shape in a single instance program to the same shape in this multiple + instance program. """ new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) @@ -273,6 +364,7 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, index = tuple(prim.Variable(f"_{ind}") for ind in range(num_single_inst_axes)) # Add in the axes from the studies we have in the predecessor. + for study_num in arr_num_to_study_nums[iarr]: index = (*index, prim.Variable(f"_{num_single_inst_axes + study_num}")) @@ -296,7 +388,7 @@ def _mult_pred_same_shape(self, expr: Stack | Concatenate) -> tuple[ArraysT, return correct_shape_preds, new_axes def map_stack(self, expr: Stack) -> Array: - new_arrays, new_axes = self._mult_pred_same_shape(expr) + new_arrays, new_axes = self._broadcast_predecessors_to_same_shape(expr) out = Stack(arrays=new_arrays, axis=expr.axis, axes=(*expr.axes, *new_axes,), @@ -309,7 +401,7 @@ def map_stack(self, expr: Stack) -> Array: return out def map_concatenate(self, expr: Concatenate) -> Array: - new_arrays, new_axes = self._mult_pred_same_shape(expr) + new_arrays, new_axes = self._broadcast_predecessors_to_same_shape(expr) return Concatenate(arrays=new_arrays, axis=expr.axis, @@ -345,10 +437,9 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: assert vn_key in varname_to_studies_nums.keys() # Now we need to update the expressions. 
- scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies_nums, - len(expr.shape)) + scalar_expr_mapper = ParamAxisExpander(varname_to_studies_nums, len(expr.shape)) - return IndexLambda(expr=scalar_expr, + return IndexLambda(expr=scalar_expr_mapper(expr.expr), bindings=immutabledict(new_binds), shape=(*expr.shape, *postpend_shape,), var_to_reduction_descr=expr.var_to_reduction_descr, @@ -410,70 +501,3 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder) \ " not yet supported.") # }}} - - -class ParamAxisExpander(IdentityMapper): - - def map_subscript(self, expr: prim.Subscript, - varname_to_studies_num: Mapping[str, tuple[int, ...]], - num_orig_elem_inds: int) -> prim.Subscript: - """ - `arg' num_orig_elem_inds specifies the maximum number of indices - in a scalar expression that can be used to index into an array provided - that index is not used as part of a reduction. - """ - # We know that we are not changing the variable that we are indexing into. - # This is stored in the aggregate member of the class Subscript. - - # We only need to modify the indexing which is stored in the index member. - assert isinstance(expr.aggregate, prim.Variable) - - name = expr.aggregate.name - if name in varname_to_studies_num.keys(): - # These are the single instance information. - - index = self.rec(expr.index, varname_to_studies_num, - num_orig_elem_inds) - - new_vars: tuple[prim.Variable, ...] = () - my_studies: tuple[int, ...] 
= varname_to_studies_num[name] - - for num in my_studies: - new_vars = *new_vars, prim.Variable(f"_{num_orig_elem_inds + num}"), - - if isinstance(index, tuple): - index = index + new_vars - else: - index = tuple(index) + new_vars - - return type(expr)(aggregate=expr.aggregate, index=index) - - return super().map_subscript(expr, varname_to_studies_num, num_orig_elem_inds) - - def map_variable(self, expr: prim.Variable, - varname_to_studies: Mapping[str, tuple[int, ...]], - num_orig_elem_inds: int) -> prim.Expression: - # We know that a variable is a leaf node. So we only need to update it - # if the variable is part of a study. - - if expr.name in varname_to_studies.keys(): - # The variable may need to be updated. - - my_studies: tuple[int, ...] = varname_to_studies[expr.name] - - if len(my_studies) == 0: - # No studies - return expr - - new_vars: tuple[prim.Variable, ...] = () - - assert my_studies - assert len(my_studies) > 0 - - for num in my_studies: - new_vars = *new_vars, prim.Variable(f"_{num_orig_elem_inds + num}"), - - return prim.Subscript(aggregate=expr, index=new_vars) - - # Since the variable is not in a study we can just return the answer. - return super().map_variable(expr, varname_to_studies, num_orig_elem_inds) diff --git a/test/test_pytato.py b/test/test_pytato.py index d0a0b538c..32fd08ff0 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1219,14 +1219,14 @@ def test_lower_to_index_lambda(): # {{{ Expansion Mapper tests. 
def test_expansion_mapper_placeholder(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) name_to_studies = {name: frozenset((my_study,))} expr = pt.make_placeholder(name, (15, 5), dtype=int) assert expr.shape == (15, 5) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(expr) assert new_expr.shape == (15, 5, 10) @@ -1239,7 +1239,7 @@ def test_expansion_mapper_placeholder(): def test_expansion_mapper_basic_index(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1248,14 +1248,14 @@ def test_expansion_mapper_basic_index(): assert expr.shape == () - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(expr) assert new_expr.shape == (10,) assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) def test_expansion_mapper_advanced_index_contiguous_axes(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1264,7 +1264,7 @@ def test_expansion_mapper_advanced_index_contiguous_axes(): assert expr.shape == (10, 5) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(expr) assert new_expr.shape == (10, 5, 10) assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) @@ -1274,7 +1274,7 @@ def test_expansion_mapper_advanced_index_contiguous_axes(): def 
test_expansion_mapper_advanced_index_non_contiguous_axes(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1286,7 +1286,7 @@ def test_expansion_mapper_advanced_index_non_contiguous_axes(): assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) assert expr.shape == (10, 2, 1000) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(expr) assert new_expr.shape == (10, 2, 1000, 10) assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) @@ -1296,7 +1296,7 @@ def test_expansion_mapper_advanced_index_non_contiguous_axes(): def test_expansion_mapper_index_lambda(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1305,7 +1305,7 @@ def test_expansion_mapper_index_lambda(): assert expr.shape == (100,) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(expr) assert new_expr.shape == (100, 10) assert isinstance(new_expr, pt.IndexLambda) @@ -1319,7 +1319,7 @@ def test_expansion_mapper_index_lambda(): def test_expansion_mapper_roll(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1330,7 +1330,7 @@ def test_expansion_mapper_roll(): assert expr.shape == (100,) assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = 
my_mapper(expr) assert new_expr.shape == (100, 10,) assert isinstance(new_expr, pt.Roll) @@ -1338,7 +1338,7 @@ def test_expansion_mapper_roll(): def test_expansion_mapper_axis_permutation(): - from pytato.transform.parameter_study import ExpansionMapper, ParameterStudyAxisTag + from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1346,7 +1346,7 @@ def test_expansion_mapper_axis_permutation(): expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) assert expr.shape == (5, 15) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(expr) assert new_expr.shape == (5, 15, 10) assert isinstance(new_expr, pt.AxisPermutation) @@ -1360,7 +1360,7 @@ def test_expansion_mapper_axis_permutation(): def test_expansion_mapper_reshape(): - from pytato.transform.parameter_study import ExpansionMapper + from pytato.transform.parameter_study import ParameterStudyVectorizer name_to_studies, studies, names = _set_up_expansion_mapper_tests() expr = pt.transpose(pt.make_placeholder(names[0], @@ -1371,7 +1371,7 @@ def test_expansion_mapper_reshape(): out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) assert out_expr.shape == (10, 15) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(out_expr) assert new_expr.shape == (10, 15, 10, 1000) assert isinstance(new_expr, pt.Reshape) @@ -1388,7 +1388,7 @@ def test_expansion_mapper_reshape(): def test_expansion_mapper_stack(): - from pytato.transform.parameter_study import ExpansionMapper + from pytato.transform.parameter_study import ParameterStudyVectorizer name_to_studies, studies, names = _set_up_expansion_mapper_tests() @@ -1400,7 +1400,7 @@ def test_expansion_mapper_stack(): out_expr = pt.stack([expr, expr2], axis=0) assert out_expr.shape == (2, 5, 15) - my_mapper = 
ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(out_expr) assert new_expr.shape == (2, 5, 15, 10, 1000) assert isinstance(new_expr, pt.Stack) @@ -1436,7 +1436,7 @@ class Study1(ParameterStudyAxisTag): def test_expansion_mapper_concatenate(): - from pytato.transform.parameter_study import ExpansionMapper + from pytato.transform.parameter_study import ParameterStudyVectorizer name_to_studies, studies, names = _set_up_expansion_mapper_tests() @@ -1448,7 +1448,7 @@ def test_expansion_mapper_concatenate(): out_expr = pt.concatenate([expr, expr2], axis=0) assert out_expr.shape == (10, 15) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_expr = my_mapper(out_expr) assert new_expr.shape == (10, 15, 10, 1000) assert isinstance(new_expr, pt.Concatenate) @@ -1465,7 +1465,7 @@ def test_expansion_mapper_concatenate(): def test_expansion_mapper_einsum_matmul(): - from pytato.transform.parameter_study import ExpansionMapper + from pytato.transform.parameter_study import ParameterStudyVectorizer name_to_studies, _, names = _set_up_expansion_mapper_tests() @@ -1479,14 +1479,14 @@ def test_expansion_mapper_einsum_matmul(): assert isinstance(c, pt.Einsum) assert c.shape == (47, 5) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) updated_c = my_mapper(c) assert updated_c.shape == (47, 5, 10, 1000) def test_expansion_mapper_does_nothing_if_tags_not_there(): - from pytato.transform.parameter_study import ExpansionMapper + from pytato.transform.parameter_study import ParameterStudyVectorizer name_to_studies, _, _ = _set_up_expansion_mapper_tests() @@ -1517,7 +1517,7 @@ def make_dws_placeholder(expr): dag = pt.transform.map_and_copy(dag, make_dws_placeholder) - my_mapper = ExpansionMapper(name_to_studies) + my_mapper = ParameterStudyVectorizer(name_to_studies) new_dag = my_mapper(dag) assert new_dag == dag From 
85cc23ed59caf1bce7f1f27506ab2c330ee855c2 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 7 Aug 2024 17:51:39 -0500 Subject: [PATCH 16/27] Address inducer's comments. --- test/test_pytato.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 32fd08ff0..7b3f7235d 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1219,7 +1219,10 @@ def test_lower_to_index_lambda(): # {{{ Expansion Mapper tests. def test_expansion_mapper_placeholder(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1239,7 +1242,10 @@ def test_expansion_mapper_placeholder(): def test_expansion_mapper_basic_index(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1255,7 +1261,10 @@ def test_expansion_mapper_basic_index(): def test_expansion_mapper_advanced_index_contiguous_axes(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1274,7 +1283,10 @@ def test_expansion_mapper_advanced_index_contiguous_axes(): def test_expansion_mapper_advanced_index_non_contiguous_axes(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1296,7 +1308,10 @@ def 
test_expansion_mapper_advanced_index_non_contiguous_axes(): def test_expansion_mapper_index_lambda(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1319,7 +1334,10 @@ def test_expansion_mapper_index_lambda(): def test_expansion_mapper_roll(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) @@ -1338,7 +1356,10 @@ def test_expansion_mapper_roll(): def test_expansion_mapper_axis_permutation(): - from pytato.transform.parameter_study import ParameterStudyVectorizer, ParameterStudyAxisTag + from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, + ) name = "my_array" my_study = ParameterStudyAxisTag(10) From 9b5feabe848456c6897aad1cc50c40bd8cffe4ae Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 8 Aug 2024 16:02:10 -0500 Subject: [PATCH 17/27] Ensure the variable names are consistent. --- pytato/transform/parameter_study.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index c73f6b575..85ef1ee87 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -145,10 +145,10 @@ def map_subscript(self, expr: prim.Subscript) -> prim.Subscript: def map_variable(self, expr: prim.Variable) -> prim.Expression: # We know that a variable is a leaf node. So we only need to update it # if the variable is part of a study. - if expr.name in self.varname_to_studies.keys(): + if expr.name in self.varname_to_studies_num.keys(): # The variable may need to be updated. - my_studies: tuple[int, ...] 
= self.varname_to_studies[expr.name] + my_studies: tuple[int, ...] = self.varname_to_studies_num[expr.name] if len(my_studies) == 0: # No studies From 8675f917e4927e06c7a27dc973ee6bac866bea83 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 15:22:08 -0500 Subject: [PATCH 18/27] Use generators and not tuples of list comprehensions --- pytato/transform/parameter_study.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 85ef1ee87..c1cc4dc2d 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -157,10 +157,10 @@ def map_variable(self, expr: prim.Variable) -> prim.Expression: assert my_studies assert len(my_studies) > 0 - new_vars = tuple([prim.Variable(f"_{self.num_orig_elem_inds + num}") for - num in my_studies]) + new_vars = (prim.Variable(f"_{self.num_orig_elem_inds + num}") # noqa + for num in my_studies) - return prim.Subscript(aggregate=expr, index=new_vars) + return prim.Subscript(aggregate=expr, index=tuple(new_vars)) # Since the variable is not in a study we can just return the answer. return super().map_variable(expr) @@ -300,11 +300,12 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array: postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, (new_predecessor,)) # Include the axes we are adding to the system. 
- axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation) - for i in range(len(postpend_shape))]) + n_single_inst_axes: int = len(expr.axis_permutation) + axis_permute_gen = (i + n_single_inst_axes for i in range(len(postpend_shape))) return AxisPermutation(array=new_predecessor, - axis_permutation=axis_permute, + axis_permutation=(*expr.axis_permutation, + *axis_permute_gen,), axes=(*expr.axes, *new_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -360,13 +361,13 @@ def _broadcast_predecessors_to_same_shape(self, expr: Stack | Concatenate) \ for iarr, array in enumerate(new_predecessors): # Broadcast out to the right shape. - num_single_inst_axes = len(expr.arrays[iarr].shape) + n_single_inst_axes = len(expr.arrays[iarr].shape) index = tuple(prim.Variable(f"_{ind}") for - ind in range(num_single_inst_axes)) + ind in range(n_single_inst_axes)) # Add in the axes from the studies we have in the predecessor. for study_num in arr_num_to_study_nums[iarr]: - index = (*index, prim.Variable(f"_{num_single_inst_axes + study_num}")) + index = (*index, prim.Variable(f"_{n_single_inst_axes + study_num}")) assert len(index) == len(array.axes) From 6b9b4ccc948a5e1c7ef0381301bf174557457e77 Mon Sep 17 00:00:00 2001 From: nkoskelo <129830924+nkoskelo@users.noreply.github.com> Date: Thu, 15 Aug 2024 16:10:53 -0500 Subject: [PATCH 19/27] Update pytato/transform/parameter_study.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/transform/parameter_study.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index c1cc4dc2d..05e7e772e 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -7,7 +7,6 @@ TODO: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. 
automodule:: pytato.transform.parameter_study """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees From b772b261ee064a61a1bbd9b835981896ac456ecf Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 17:21:55 -0500 Subject: [PATCH 20/27] Add to the documentation. --- pytato/transform/__init__.py | 1 + pytato/transform/parameter_study.py | 74 ++++++++++++++++++----------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index ea790e9b6..af67829be 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -122,6 +122,7 @@ .. automodule:: pytato.transform.lower_to_index_lambda .. automodule:: pytato.transform.remove_broadcasts_einsum .. automodule:: pytato.transform.einsum_distributive_law +.. automodule:: pytato.transform.parameter_study .. currentmodule:: pytato.transform Dict representation of DAGs diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 05e7e772e..243305920 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -1,17 +1,16 @@ from __future__ import annotations -""" -.. currentmodule:: pytato.transform - -TODO: -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +__doc__ = """ +.. currentmodule:: pytato.transform.parameter_study -""" -__copyright__ = """ -Copyright (C) 2020-1 University of Illinois Board of Trustees +.. autoclass:: ParameterStudyAxisTag +.. autoclass:: ParameterStudyVectorizer +.. autoclass:: IndexLambdaScalarExpressionVectorizer """ +__copyright__ = "Copyright (C) 2024 Nicholas Koskelo" + __license__ = """ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -81,9 +80,9 @@ class ParameterStudyAxisTag(UniqueTag): for independent trials like in a parameter study. 
If you want to vary multiple input variables in the same study then you need to have the same type of - class: 'ParameterStudyAxisTag'. + :class:`ParameterStudyAxisTag`. """ - axis_size: int + size: int StudiesT = tuple[ParameterStudyAxisTag, ...] @@ -91,7 +90,7 @@ class ParameterStudyAxisTag(UniqueTag): KnownShapeType = tuple[IntegralT, ...] -class ParamAxisExpander(IdentityMapper): +class IndexLambdaScalarExpressionVectorizer(IdentityMapper): """ The goal of this mapper is to convert a single instance scalar expression into a single instruction multiple data scalar expression. We assume that any @@ -103,11 +102,11 @@ class ParamAxisExpander(IdentityMapper): def __init__(self, varname_to_studies_num: Mapping[str, tuple[int, ...]], num_orig_elem_inds: int): """ - `arg' varname_to_studies_num: is a mapping from the variable name used + `arg` varname_to_studies_num: is a mapping from the variable name used in the scalar expression to the studies present in the multiple instance expression. Note that the varnames must be for the array variables only. - `arg' num_orig_elem_inds: is the number of element axes in the result of + `arg` num_orig_elem_inds: is the number of element axes in the result of the single instance expression. """ @@ -168,20 +167,39 @@ def map_variable(self, expr: prim.Variable) -> prim.Expression: class ParameterStudyVectorizer(CopyMapper): """ This mapper will expand a single instance DAG into a DAG for parameter studies. - It is assumed that the parameter studies cannot interact with each other. + The DAG for parameter studies will be equivalent to running the single instance + DAG for each input in the parameter study space, $P$. You must specify which + input :class:`~pytato.array.Placeholder' are part of what parameter study. + + To maintain the equivalence with repeated calling the single instance DAG, the + DAG for parameter studies will not create any expressions which depend on the + specific instance of a parameter study. 
+ + We do allow an input parameter which is specified by a + :class:`~pytato.array.Placeholder` to be a part of multiple distinct parameter + studies. Each distinct parameter study will be a new axis at the end of the array. + + We do NOT require that each input be part of each parameter study. We will broadcast + the input as necessary. + + Consider an binary operation, $z = x + y$. Let $x$ be a part of parameter study + $S1$. Let $y$ be a part of parameter study $S2$. Then, $z$ will be a part of the + parameter studies $S1$ and $S2$. $z_{i,j} = x_{i} + y_{j}$. So, the shape of $z$ + would be (single instance shape, $S1.size$, $S2.size$). + + A parameter study is specified in an array by tagging the corresponding axis + with a tag that is a :class:`ParameterStudyAxisTag` or a class which + inherits from it. + Currently, this only supports DAGs which are made for a single processing unit. - That is we do not support distributed programming right now. - - Any new axes used for parameter studies will be added to the end of the arrays. - Note this will break broadcasting assumptions. Therefore, one needs to be careful - if only a portion of the program is expanded. This decision was made under the - assumption that the generated code would execute faster if the parameter study - axes were the ones with the shortest strides. - - If there are multiple distinct parameter studies then the DAG will be expanded - for the Cartesian product of the input parameter studies. A parameter study is - specified in an array by tagging the corresponding axis with a tag that is a - class: `ParameterStudyAxisTag' or a class which inherits from it. + That is we do not support distributed programming right now. We also do not support + function definitions within the single instance DAG. + + NOTE any new axes used for parameter studies will be added to the end of the arrays. 
+ If only a portion of the program is expanded, broadcasting may break as the single + instance axes will be left aligned instead of right aligned. This decision was + made under the assumption that the generated code would execute faster + if the parameter study axes were the ones with the shortest strides. """ def __init__(self, placeholder_name_to_parameter_studies: Mapping[str, StudiesT]): @@ -247,7 +265,7 @@ def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, num_studies: int = 0 for ind, study in enumerate(sorted(studies, key=lambda x: str(x.__class__))): - new_shape = (*new_shape, study.axis_size) + new_shape = (*new_shape, study.size) studies_axes = (*studies_axes, Axis(tags=frozenset((study,)))) for arr_num, arr in enumerate(mapped_preds): if arr in study_to_arrays[frozenset((study,))]: @@ -437,7 +455,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: assert vn_key in varname_to_studies_nums.keys() # Now we need to update the expressions. - scalar_expr_mapper = ParamAxisExpander(varname_to_studies_nums, len(expr.shape)) + scalar_expr_mapper = IndexLambdaScalarExpressionVectorizer(varname_to_studies_nums, len(expr.shape)) # noqa return IndexLambda(expr=scalar_expr_mapper(expr.expr), bindings=immutabledict(new_binds), From ae6820485f9c5c62068b3733e41ebf922f213e58 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 17:47:10 -0500 Subject: [PATCH 21/27] Update documentation to be more readible. 
--- pytato/transform/parameter_study.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 243305920..b3279e2c9 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -165,11 +165,11 @@ def map_variable(self, expr: prim.Variable) -> prim.Expression: class ParameterStudyVectorizer(CopyMapper): - """ + r""" This mapper will expand a single instance DAG into a DAG for parameter studies. The DAG for parameter studies will be equivalent to running the single instance - DAG for each input in the parameter study space, $P$. You must specify which - input :class:`~pytato.array.Placeholder' are part of what parameter study. + DAG for each input in the parameter study space. You must specify which + input :class:`~pytato.array.Placeholder` are part of what parameter study. To maintain the equivalence with repeated calling the single instance DAG, the DAG for parameter studies will not create any expressions which depend on the @@ -182,10 +182,23 @@ class ParameterStudyVectorizer(CopyMapper): We do NOT require that each input be part of each parameter study. We will broadcast the input as necessary. - Consider an binary operation, $z = x + y$. Let $x$ be a part of parameter study - $S1$. Let $y$ be a part of parameter study $S2$. Then, $z$ will be a part of the - parameter studies $S1$ and $S2$. $z_{i,j} = x_{i} + y_{j}$. So, the shape of $z$ - would be (single instance shape, $S1.size$, $S2.size$). + Ex: + + .. math:: + + \mathbf{Z} = \mathbf{X} + \mathbf{Y}, + + where :math:`\mathbf{X}` is a part of parameter study :math:`\mathbf{S1}` and + :math:`\mathbf{Y}` is a part of parameter study :math:`\mathbf{S2}`. Then, + :math:`\mathbf{Z}` will be a part of both parameter studies :math:`\mathbf{S1}` and + :math:`\mathbf{S2}`. + + .. 
math:: + + \mathbf{Z}_{i,j} = \mathbf{X}_{i} + \mathbf{Y}_{j}, + + + and so the shape of :math:`\mathbf{Z}` will be (orig_shape, :math:`\mathbf{S1}.size`, :math:`\mathbf{S2}.size`). A parameter study is specified in an array by tagging the corresponding axis with a tag that is a :class:`ParameterStudyAxisTag` or a class which From 5949958e8461e41d48abaafd2d13156071769aae Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 15 Aug 2024 17:55:21 -0500 Subject: [PATCH 22/27] Underscores still create subscripts even if you don't want it to. :( --- pytato/transform/parameter_study.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index b3279e2c9..851d6e16f 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -189,16 +189,16 @@ class ParameterStudyVectorizer(CopyMapper): \mathbf{Z} = \mathbf{X} + \mathbf{Y}, where :math:`\mathbf{X}` is a part of parameter study :math:`\mathbf{S1}` and - :math:`\mathbf{Y}` is a part of parameter study :math:`\mathbf{S2}`. Then, - :math:`\mathbf{Z}` will be a part of both parameter studies :math:`\mathbf{S1}` and - :math:`\mathbf{S2}`. + :math:`\mathbf{Y}` is a part of parameter study :math:`\mathbf{S2}` both with single + instance shapes of :math:`\mathbf{orig\_shape}`. Then, :math:`\mathbf{Z}` will be a + part of both parameter studies :math:`\mathbf{S1}` and :math:`\mathbf{S2}`. .. math:: - - \mathbf{Z}_{i,j} = \mathbf{X}_{i} + \mathbf{Y}_{j}, + \mathbf{Z}_{i,j} = \mathbf{X}_{i} + \mathbf{Y}_{j}, - and so the shape of :math:`\mathbf{Z}` will be (orig_shape, :math:`\mathbf{S1}.size`, :math:`\mathbf{S2}.size`). + and so the shape of :math:`\mathbf{Z}` will be + (:math:`\mathbf{orig\_shape}`, :math:`\mathbf{S1}.size`, :math:`\mathbf{S2}.size`). 
A parameter study is specified in an array by tagging the corresponding axis with a tag that is a :class:`ParameterStudyAxisTag` or a class which From 2e97182408dbd33f1ebb023d2669dcc3d587125e Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 3 Sep 2024 16:26:41 -0500 Subject: [PATCH 23/27] Move tests to a new file test_vectorizer.py --- test/test_vectorizer.py | 321 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 test/test_vectorizer.py diff --git a/test/test_vectorizer.py b/test/test_vectorizer.py new file mode 100644 index 000000000..55652c818 --- /dev/null +++ b/test/test_vectorizer.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python +from __future__ import annotations + + +__copyright__ = """Copyright (C) 2020 Andreas Kloeckner +Copyright (C) 2021 Kaushik Kulkarni +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" +import numpy as np +import pytato as pt + +from testlib import RandomDAGContext, make_random_dag +from pytato.transform.parameter_study import ( + ParameterStudyAxisTag, + ParameterStudyVectorizer, +) + +# {{{ Expansion Mapper tests. +def test_vectorize_mapper_placeholder(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int) + assert expr.shape == (15, 5) + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (15, 5, 10) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i == 2: + assert tags + else: + assert not tags + + +def test_vectorize_mapper_basic_index(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + + assert expr.shape == () + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (10,) + assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) + + +def test_vectorize_mapper_advanced_index_contiguous_axes(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[pt.arange(10, dtype=int)] + + assert expr.shape == (10, 5) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (10, 5, 10) + assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) + + assert isinstance(new_expr, pt.AdvancedIndexInContiguousAxes) + assert isinstance(expr, type(new_expr)) + + +def test_vectorize_mapper_advanced_index_non_contiguous_axes(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + ind0 = pt.arange(10, dtype=int).reshape(10, 1) + ind1 = 
pt.arange(2, dtype=int).reshape(1, 2) + expr = pt.make_placeholder(name, (15, 1000, 5), dtype=int)[ind0, :, ind1] + + assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) + assert expr.shape == (10, 2, 1000) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (10, 2, 1000, 10) + assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) + + assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) + assert isinstance(expr, type(new_expr)) + + +def test_vectorize_mapper_index_lambda(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) + + assert expr.shape == (100,) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (100, 10) + assert isinstance(new_expr, pt.IndexLambda) + + scalar_expr = new_expr.expr + + assert len(scalar_expr.children) == len(expr.expr.children) + assert scalar_expr != expr.expr + # We modified it so that we have the new axis. 
+ assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) + + +def test_vectorize_mapper_roll(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) + expr = pt.roll(expr, axis=0, shift=22) + + assert expr.shape == (100,) + assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (100, 10,) + assert isinstance(new_expr, pt.Roll) + assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) + + +def test_vectorize_mapper_axis_permutation(): + name = "my_array" + my_study = ParameterStudyAxisTag(10) + name_to_studies = {name: frozenset((my_study,))} + expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) + assert expr.shape == (5, 15) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(expr) + assert new_expr.shape == (5, 15, 10) + assert isinstance(new_expr, pt.AxisPermutation) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i == 2: + assert tags + else: + assert not tags + + +def test_vectorize_mapper_reshape(): + name_to_studies, studies, names = _set_up_vectorize_mapper_tests() + expr = pt.transpose(pt.make_placeholder(names[0], + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(names[1], + (15, 5), dtype=int)) + + out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) + assert out_expr.shape == (10, 15) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (10, 15, 10, 1000) + assert isinstance(new_expr, pt.Reshape) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 1: + assert tags + else: + assert not tags + + assert not new_expr.axes[2].tags_of_type(studies[1]) + 
assert not new_expr.axes[3].tags_of_type(studies[0]) + + +def test_vectorize_mapper_stack(): + name_to_studies, studies, names = _set_up_vectorize_mapper_tests() + + expr = pt.transpose(pt.make_placeholder(names[0], + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(names[1], + (15, 5), dtype=int)) + + out_expr = pt.stack([expr, expr2], axis=0) + assert out_expr.shape == (2, 5, 15) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (2, 5, 15, 10, 1000) + assert isinstance(new_expr, pt.Stack) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 2: + assert tags + else: + assert not tags + + assert not new_expr.axes[3].tags_of_type(studies[1]) + assert not new_expr.axes[4].tags_of_type(studies[0]) + + +def _set_up_vectorize_mapper_tests() -> tuple[Mapping[str, + frozenset[ParameterStudyAxisTag]], + tuple[ParameterStudyAxisTag, ...], + tuple[str, ...]]: + + class Study2(ParameterStudyAxisTag): + pass + + class Study1(ParameterStudyAxisTag): + pass + name = "a" + study1 = Study1(10) + arr2 = "b" + study2 = Study2(1000) + name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} + + return name_to_studies, (Study1, Study2,), (name, arr2,) + + +def test_vectorize_mapper_concatenate(): + name_to_studies, studies, names = _set_up_vectorize_mapper_tests() + + expr = pt.transpose(pt.make_placeholder(names[0], + (15, 5), dtype=int)) + expr2 = pt.transpose(pt.make_placeholder(names[1], + (15, 5), dtype=int)) + + out_expr = pt.concatenate([expr, expr2], axis=0) + assert out_expr.shape == (10, 15) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_expr = my_mapper(out_expr) + assert new_expr.shape == (10, 15, 10, 1000) + assert isinstance(new_expr, pt.Concatenate) + + for i, axis in enumerate(new_expr.axes): + tags = axis.tags_of_type(ParameterStudyAxisTag) + if i > 1: + assert tags + else: + assert not tags + + assert 
not new_expr.axes[2].tags_of_type(studies[1]) + assert not new_expr.axes[3].tags_of_type(studies[0]) + + +def test_vectorize_mapper_einsum_matmul(): + + name_to_studies, _, names = _set_up_vectorize_mapper_tests() + + # Matmul gets expanded correctly. + a = pt.make_placeholder(names[0], + (47, 42), dtype=int) + b = pt.make_placeholder(names[1], + (42, 5), dtype=int) + + c = pt.matmul(a, b) + assert isinstance(c, pt.Einsum) + assert c.shape == (47, 5) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + updated_c = my_mapper(c) + + assert updated_c.shape == (47, 5, 10, 1000) + + +def test_vectorize_mapper_does_nothing_if_tags_not_there(): + name_to_studies, _, _ = _set_up_vectorize_mapper_tests() + + from testlib import RandomDAGContext, make_random_dag + + from pytools import UniqueNameGenerator + axis_len = 5 + + for i in range(50): + print(i) # progress indicator + + seed = 120 + i + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + dag = pt.make_dict_of_named_arrays({"out": make_random_dag(rdagc_pt)}) + + # {{{ convert data-wrappers to placeholders + + vng = UniqueNameGenerator() + + def make_dws_placeholder(expr): + if isinstance(expr, pt.DataWrapper): + return pt.make_placeholder(vng("_pt_ph"), # noqa: B023 + expr.shape, expr.dtype) + else: + return expr + + dag = pt.transform.map_and_copy(dag, make_dws_placeholder) + + my_mapper = ParameterStudyVectorizer(name_to_studies) + new_dag = my_mapper(dag) + + assert new_dag == dag + + # }}} +# }}} From d0e3f865fb033a24931e0e9844c6ef799fb8e5ac Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 3 Sep 2024 16:58:59 -0500 Subject: [PATCH 24/27] Complete the move to test_vectorizer.py --- test/test_pytato.py | 331 ---------------------------------------- test/test_vectorizer.py | 170 +++++++++------------ 2 files changed, 73 insertions(+), 428 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 7b3f7235d..0e9661d53 100644 --- 
a/test/test_pytato.py +++ b/test/test_pytato.py @@ -43,7 +43,6 @@ import pytato as pt from pytato.array import _SuppliedAxesAndTagsMixin -from pytato.transform.parameter_study import ParameterStudyAxisTag def test_matmul_input_validation(): @@ -1217,336 +1216,6 @@ def test_lower_to_index_lambda(): assert isinstance(binding, Reshape) -# {{{ Expansion Mapper tests. -def test_expansion_mapper_placeholder(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int) - assert expr.shape == (15, 5) - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(expr) - assert new_expr.shape == (15, 5, 10) - - for i, axis in enumerate(new_expr.axes): - tags = axis.tags_of_type(ParameterStudyAxisTag) - if i == 2: - assert tags - else: - assert not tags - - -def test_expansion_mapper_basic_index(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] - - assert expr.shape == () - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(expr) - assert new_expr.shape == (10,) - assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) - - -def test_expansion_mapper_advanced_index_contiguous_axes(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[pt.arange(10, dtype=int)] - - assert expr.shape == (10, 5) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = 
my_mapper(expr) - assert new_expr.shape == (10, 5, 10) - assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) - - assert isinstance(new_expr, pt.AdvancedIndexInContiguousAxes) - assert isinstance(expr, type(new_expr)) - - -def test_expansion_mapper_advanced_index_non_contiguous_axes(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - ind0 = pt.arange(10, dtype=int).reshape(10, 1) - ind1 = pt.arange(2, dtype=int).reshape(1, 2) - expr = pt.make_placeholder(name, (15, 1000, 5), dtype=int)[ind0, :, ind1] - - assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) - assert expr.shape == (10, 2, 1000) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(expr) - assert new_expr.shape == (10, 2, 1000, 10) - assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) - - assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) - assert isinstance(expr, type(new_expr)) - - -def test_expansion_mapper_index_lambda(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) - - assert expr.shape == (100,) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(expr) - assert new_expr.shape == (100, 10) - assert isinstance(new_expr, pt.IndexLambda) - - scalar_expr = new_expr.expr - - assert len(scalar_expr.children) == len(expr.expr.children) - assert scalar_expr != expr.expr - # We modified it so that we have the new axis. 
- assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) - - -def test_expansion_mapper_roll(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) - expr = pt.roll(expr, axis=0, shift=22) - - assert expr.shape == (100,) - assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(expr) - assert new_expr.shape == (100, 10,) - assert isinstance(new_expr, pt.Roll) - assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) - - -def test_expansion_mapper_axis_permutation(): - from pytato.transform.parameter_study import ( - ParameterStudyAxisTag, - ParameterStudyVectorizer, - ) - - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) - assert expr.shape == (5, 15) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(expr) - assert new_expr.shape == (5, 15, 10) - assert isinstance(new_expr, pt.AxisPermutation) - - for i, axis in enumerate(new_expr.axes): - tags = axis.tags_of_type(ParameterStudyAxisTag) - if i == 2: - assert tags - else: - assert not tags - - -def test_expansion_mapper_reshape(): - from pytato.transform.parameter_study import ParameterStudyVectorizer - - name_to_studies, studies, names = _set_up_expansion_mapper_tests() - expr = pt.transpose(pt.make_placeholder(names[0], - (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(names[1], - (15, 5), dtype=int)) - - out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) - assert out_expr.shape == (10, 15) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(out_expr) - 
assert new_expr.shape == (10, 15, 10, 1000) - assert isinstance(new_expr, pt.Reshape) - - for i, axis in enumerate(new_expr.axes): - tags = axis.tags_of_type(ParameterStudyAxisTag) - if i > 1: - assert tags - else: - assert not tags - - assert not new_expr.axes[2].tags_of_type(studies[1]) - assert not new_expr.axes[3].tags_of_type(studies[0]) - - -def test_expansion_mapper_stack(): - from pytato.transform.parameter_study import ParameterStudyVectorizer - - name_to_studies, studies, names = _set_up_expansion_mapper_tests() - - expr = pt.transpose(pt.make_placeholder(names[0], - (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(names[1], - (15, 5), dtype=int)) - - out_expr = pt.stack([expr, expr2], axis=0) - assert out_expr.shape == (2, 5, 15) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(out_expr) - assert new_expr.shape == (2, 5, 15, 10, 1000) - assert isinstance(new_expr, pt.Stack) - - for i, axis in enumerate(new_expr.axes): - tags = axis.tags_of_type(ParameterStudyAxisTag) - if i > 2: - assert tags - else: - assert not tags - - assert not new_expr.axes[3].tags_of_type(studies[1]) - assert not new_expr.axes[4].tags_of_type(studies[0]) - - -def _set_up_expansion_mapper_tests() -> tuple[Mapping[str, - frozenset[ParameterStudyAxisTag]], - tuple[ParameterStudyAxisTag, ...], - tuple[str, ...]]: - - class Study2(ParameterStudyAxisTag): - pass - - class Study1(ParameterStudyAxisTag): - pass - name = "a" - study1 = Study1(10) - arr2 = "b" - study2 = Study2(1000) - name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} - - return name_to_studies, (Study1, Study2,), (name, arr2,) - - -def test_expansion_mapper_concatenate(): - from pytato.transform.parameter_study import ParameterStudyVectorizer - - name_to_studies, studies, names = _set_up_expansion_mapper_tests() - - expr = pt.transpose(pt.make_placeholder(names[0], - (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(names[1], - (15, 5), 
dtype=int)) - - out_expr = pt.concatenate([expr, expr2], axis=0) - assert out_expr.shape == (10, 15) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - new_expr = my_mapper(out_expr) - assert new_expr.shape == (10, 15, 10, 1000) - assert isinstance(new_expr, pt.Concatenate) - - for i, axis in enumerate(new_expr.axes): - tags = axis.tags_of_type(ParameterStudyAxisTag) - if i > 1: - assert tags - else: - assert not tags - - assert not new_expr.axes[2].tags_of_type(studies[1]) - assert not new_expr.axes[3].tags_of_type(studies[0]) - - -def test_expansion_mapper_einsum_matmul(): - from pytato.transform.parameter_study import ParameterStudyVectorizer - - name_to_studies, _, names = _set_up_expansion_mapper_tests() - - # Matmul gets expanded correctly. - a = pt.make_placeholder(names[0], - (47, 42), dtype=int) - b = pt.make_placeholder(names[1], - (42, 5), dtype=int) - - c = pt.matmul(a, b) - assert isinstance(c, pt.Einsum) - assert c.shape == (47, 5) - - my_mapper = ParameterStudyVectorizer(name_to_studies) - updated_c = my_mapper(c) - - assert updated_c.shape == (47, 5, 10, 1000) - - -def test_expansion_mapper_does_nothing_if_tags_not_there(): - from pytato.transform.parameter_study import ParameterStudyVectorizer - - name_to_studies, _, _ = _set_up_expansion_mapper_tests() - - from testlib import RandomDAGContext, make_random_dag - - from pytools import UniqueNameGenerator - axis_len = 5 - - for i in range(50): - print(i) # progress indicator - - seed = 120 + i - rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), - axis_len=axis_len, use_numpy=False) - - dag = pt.make_dict_of_named_arrays({"out": make_random_dag(rdagc_pt)}) - - # {{{ convert data-wrappers to placeholders - - vng = UniqueNameGenerator() - - def make_dws_placeholder(expr): - if isinstance(expr, pt.DataWrapper): - return pt.make_placeholder(vng("_pt_ph"), # noqa: B023 - expr.shape, expr.dtype) - else: - return expr - - dag = pt.transform.map_and_copy(dag, make_dws_placeholder) - - 
my_mapper = ParameterStudyVectorizer(name_to_studies) - new_dag = my_mapper(dag) - - assert new_dag == dag - - # }}} -# }}} - - def test_cached_walk_mapper_with_extra_args(): from testlib import RandomDAGContext, make_random_dag diff --git a/test/test_vectorizer.py b/test/test_vectorizer.py index 55652c818..6f6060914 100644 --- a/test/test_vectorizer.py +++ b/test/test_vectorizer.py @@ -27,24 +27,42 @@ THE SOFTWARE. """ import numpy as np -import pytato as pt - from testlib import RandomDAGContext, make_random_dag + +import pytato as pt from pytato.transform.parameter_study import ( ParameterStudyAxisTag, ParameterStudyVectorizer, ) + +class Study1(ParameterStudyAxisTag): + """First parameter study.""" + pass + + +class Study2(ParameterStudyAxisTag): + """Second parameter study.""" + pass + + +global_array_name1 = "foo" +global_array_name2 = "bar" +global_shape1: int = 1234 +global_shape2: int = 5678 +global_study1 = Study1(global_shape1) +global_study2 = Study2(global_shape2) +global_name_to_studies = {global_array_name1: frozenset((global_study1,)), + global_array_name2: frozenset((global_study2,)), } + + # {{{ Expansion Mapper tests. 
def test_vectorize_mapper_placeholder(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int) + expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int) assert expr.shape == (15, 5) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (15, 5, 10) + assert new_expr.shape == (15, 5, global_shape1) for i, axis in enumerate(new_expr.axes): tags = axis.tags_of_type(ParameterStudyAxisTag) @@ -55,30 +73,25 @@ def test_vectorize_mapper_placeholder(): def test_vectorize_mapper_basic_index(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int)[14, 0] assert expr.shape == () - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (10,) + assert new_expr.shape == (global_shape1,) assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) def test_vectorize_mapper_advanced_index_contiguous_axes(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[pt.arange(10, dtype=int)] + expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int) + expr = expr[pt.arange(10, dtype=int)] assert expr.shape == (10, 5) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (10, 5, 10) + assert new_expr.shape == (10, 5, global_shape1) assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) assert isinstance(new_expr, 
pt.AdvancedIndexInContiguousAxes) @@ -86,19 +99,18 @@ def test_vectorize_mapper_advanced_index_contiguous_axes(): def test_vectorize_mapper_advanced_index_non_contiguous_axes(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} ind0 = pt.arange(10, dtype=int).reshape(10, 1) ind1 = pt.arange(2, dtype=int).reshape(1, 2) - expr = pt.make_placeholder(name, (15, 1000, 5), dtype=int)[ind0, :, ind1] + + expr = pt.make_placeholder(global_array_name1, (15, 1000, 5), dtype=int) + expr = expr[ind0, :, ind1] assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) assert expr.shape == (10, 2, 1000) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (10, 2, 1000, 10) + assert new_expr.shape == (10, 2, 1000, global_shape1) assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) @@ -106,16 +118,14 @@ def test_vectorize_mapper_advanced_index_non_contiguous_axes(): def test_vectorize_mapper_index_lambda(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) + expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int)[14, 0] \ + + pt.ones(100) assert expr.shape == (100,) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (100, 10) + assert new_expr.shape == (100, global_shape1) assert isinstance(new_expr, pt.IndexLambda) scalar_expr = new_expr.expr @@ -127,32 +137,28 @@ def test_vectorize_mapper_index_lambda(): def test_vectorize_mapper_roll(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = 
pt.make_placeholder(name, (15, 5), dtype=int)[14, 0] + pt.ones(100) + expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int)[14, 0] \ + + pt.ones(100) + expr = pt.roll(expr, axis=0, shift=22) assert expr.shape == (100,) assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (100, 10,) + assert new_expr.shape == (100, global_shape1,) assert isinstance(new_expr, pt.Roll) assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) def test_vectorize_mapper_axis_permutation(): - name = "my_array" - my_study = ParameterStudyAxisTag(10) - name_to_studies = {name: frozenset((my_study,))} - expr = pt.transpose(pt.make_placeholder(name, (15, 5), dtype=int)) + expr = pt.transpose(pt.make_placeholder(global_array_name1, (15, 5), dtype=int)) assert expr.shape == (5, 15) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(expr) - assert new_expr.shape == (5, 15, 10) + assert new_expr.shape == (5, 15, global_shape1) assert isinstance(new_expr, pt.AxisPermutation) for i, axis in enumerate(new_expr.axes): @@ -164,18 +170,17 @@ def test_vectorize_mapper_axis_permutation(): def test_vectorize_mapper_reshape(): - name_to_studies, studies, names = _set_up_vectorize_mapper_tests() - expr = pt.transpose(pt.make_placeholder(names[0], + expr = pt.transpose(pt.make_placeholder(global_array_name1, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(names[1], + expr2 = pt.transpose(pt.make_placeholder(global_array_name2, (15, 5), dtype=int)) out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) assert out_expr.shape == (10, 15) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(out_expr) - assert 
new_expr.shape == (10, 15, 10, 1000) + assert new_expr.shape == (10, 15, global_shape1, global_shape2) assert isinstance(new_expr, pt.Reshape) for i, axis in enumerate(new_expr.axes): @@ -185,24 +190,22 @@ def test_vectorize_mapper_reshape(): else: assert not tags - assert not new_expr.axes[2].tags_of_type(studies[1]) - assert not new_expr.axes[3].tags_of_type(studies[0]) + assert not new_expr.axes[2].tags_of_type(Study2) + assert not new_expr.axes[3].tags_of_type(Study1) def test_vectorize_mapper_stack(): - name_to_studies, studies, names = _set_up_vectorize_mapper_tests() - - expr = pt.transpose(pt.make_placeholder(names[0], + expr = pt.transpose(pt.make_placeholder(global_array_name1, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(names[1], + expr2 = pt.transpose(pt.make_placeholder(global_array_name2, (15, 5), dtype=int)) out_expr = pt.stack([expr, expr2], axis=0) assert out_expr.shape == (2, 5, 15) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(out_expr) - assert new_expr.shape == (2, 5, 15, 10, 1000) + assert new_expr.shape == (2, 5, 15, global_shape1, global_shape2) assert isinstance(new_expr, pt.Stack) for i, axis in enumerate(new_expr.axes): @@ -212,43 +215,22 @@ def test_vectorize_mapper_stack(): else: assert not tags - assert not new_expr.axes[3].tags_of_type(studies[1]) - assert not new_expr.axes[4].tags_of_type(studies[0]) - - -def _set_up_vectorize_mapper_tests() -> tuple[Mapping[str, - frozenset[ParameterStudyAxisTag]], - tuple[ParameterStudyAxisTag, ...], - tuple[str, ...]]: - - class Study2(ParameterStudyAxisTag): - pass - - class Study1(ParameterStudyAxisTag): - pass - name = "a" - study1 = Study1(10) - arr2 = "b" - study2 = Study2(1000) - name_to_studies = {name: frozenset((study1,)), arr2: frozenset((study2,))} - - return name_to_studies, (Study1, Study2,), (name, arr2,) + assert not new_expr.axes[3].tags_of_type(Study2) + assert not 
new_expr.axes[4].tags_of_type(Study1) def test_vectorize_mapper_concatenate(): - name_to_studies, studies, names = _set_up_vectorize_mapper_tests() - - expr = pt.transpose(pt.make_placeholder(names[0], + expr = pt.transpose(pt.make_placeholder(global_array_name1, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(names[1], + expr2 = pt.transpose(pt.make_placeholder(global_array_name2, (15, 5), dtype=int)) out_expr = pt.concatenate([expr, expr2], axis=0) assert out_expr.shape == (10, 15) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_expr = my_mapper(out_expr) - assert new_expr.shape == (10, 15, 10, 1000) + assert new_expr.shape == (10, 15, global_shape1, global_shape2) assert isinstance(new_expr, pt.Concatenate) for i, axis in enumerate(new_expr.axes): @@ -258,34 +240,28 @@ def test_vectorize_mapper_concatenate(): else: assert not tags - assert not new_expr.axes[2].tags_of_type(studies[1]) - assert not new_expr.axes[3].tags_of_type(studies[0]) + assert not new_expr.axes[2].tags_of_type(Study2) + assert not new_expr.axes[3].tags_of_type(Study1) def test_vectorize_mapper_einsum_matmul(): - - name_to_studies, _, names = _set_up_vectorize_mapper_tests() - # Matmul gets expanded correctly. 
- a = pt.make_placeholder(names[0], + a = pt.make_placeholder(global_array_name1, (47, 42), dtype=int) - b = pt.make_placeholder(names[1], + b = pt.make_placeholder(global_array_name2, (42, 5), dtype=int) c = pt.matmul(a, b) assert isinstance(c, pt.Einsum) assert c.shape == (47, 5) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) updated_c = my_mapper(c) - assert updated_c.shape == (47, 5, 10, 1000) + assert updated_c.shape == (47, 5, global_shape1, global_shape2) def test_vectorize_mapper_does_nothing_if_tags_not_there(): - name_to_studies, _, _ = _set_up_vectorize_mapper_tests() - - from testlib import RandomDAGContext, make_random_dag from pytools import UniqueNameGenerator axis_len = 5 @@ -312,7 +288,7 @@ def make_dws_placeholder(expr): dag = pt.transform.map_and_copy(dag, make_dws_placeholder) - my_mapper = ParameterStudyVectorizer(name_to_studies) + my_mapper = ParameterStudyVectorizer(global_name_to_studies) new_dag = my_mapper(dag) assert new_dag == dag From fc73e1c6ba3e2be520dadc62661b73546315351f Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 4 Sep 2024 10:07:46 -0500 Subject: [PATCH 25/27] Update the test file to be the same as the original. --- test/test_pytato.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 0e9661d53..f67e7e5f1 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -28,9 +28,6 @@ """ import sys -from typing import ( - Mapping, -) import attrs import numpy as np From 01549d588e8be9c7901f8f522d814ec4cf998807 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 6 Sep 2024 15:03:54 -0500 Subject: [PATCH 26/27] Respond to Andreas comments. 
--- pytato/transform/parameter_study.py | 317 ++++++++++++---------------- test/test_vectorizer.py | 94 ++++----- 2 files changed, 181 insertions(+), 230 deletions(-) diff --git a/pytato/transform/parameter_study.py b/pytato/transform/parameter_study.py index 851d6e16f..f88ccccbe 100644 --- a/pytato/transform/parameter_study.py +++ b/pytato/transform/parameter_study.py @@ -33,8 +33,8 @@ from dataclasses import dataclass from typing import ( - Iterable, Mapping, + Sequence, ) from immutabledict import immutabledict @@ -73,6 +73,9 @@ from pytato.transform import CopyMapper +KnownShapeType = tuple[IntegralT, ...] + + @dataclass(frozen=True) class ParameterStudyAxisTag(UniqueTag): """ @@ -85,11 +88,6 @@ class ParameterStudyAxisTag(UniqueTag): size: int -StudiesT = tuple[ParameterStudyAxisTag, ...] -ArraysT = tuple[Array, ...] -KnownShapeType = tuple[IntegralT, ...] - - class IndexLambdaScalarExpressionVectorizer(IdentityMapper): """ The goal of this mapper is to convert a single instance scalar expression @@ -124,19 +122,11 @@ def map_subscript(self, expr: prim.Subscript) -> prim.Subscript: index = self.rec(expr.index) - new_vars: tuple[prim.Variable, ...] = () - my_studies: tuple[int, ...] 
= self.varname_to_studies_num[name] - - for num in my_studies: - new_vars = (*new_vars, - prim.Variable(f"_{self.num_orig_elem_inds + num}"),) + additional_inds = (prim.Variable(f"_{self.num_orig_elem_inds + num}") for + num in self.varname_to_studies_num[name]) - if isinstance(index, tuple): - index = index + new_vars - else: - index = tuple(index) + new_vars - - return type(expr)(aggregate=expr.aggregate, index=index) + return type(expr)(aggregate=expr.aggregate, + index=(*index, *additional_inds,)) return super().map_subscript(expr) @@ -155,7 +145,7 @@ def map_variable(self, expr: prim.Variable) -> prim.Expression: assert my_studies assert len(my_studies) > 0 - new_vars = (prim.Variable(f"_{self.num_orig_elem_inds + num}") # noqa + new_vars = (prim.Variable(f"_{self.num_orig_elem_inds + num}") # noqa E501 for num in my_studies) return prim.Subscript(aggregate=expr, index=tuple(new_vars)) @@ -164,150 +154,84 @@ def map_variable(self, expr: prim.Variable) -> prim.Expression: return super().map_variable(expr) +def _param_study_to_index(tag: ParameterStudyAxisTag) -> str: + """ + Get the canonical index string associated with the input tag. + """ + return str(tag.__class__) # Update to use the qualname or name. + + class ParameterStudyVectorizer(CopyMapper): r""" - This mapper will expand a single instance DAG into a DAG for parameter studies. - The DAG for parameter studies will be equivalent to running the single instance - DAG for each input in the parameter study space. You must specify which - input :class:`~pytato.array.Placeholder` are part of what parameter study. + This mapper will expand a DAG into a DAG for parameter studies. An array is part + of a parameter study if one of its axes is tagged with a tag from + :class:`ParameterStudyAxisTag` and all of the axes after that axis are also tagged + with a distinct :class:`ParameterStudyAxisTag` tag. An array may be a member of + multiple parameter studies. 
The new DAG for parameter studies will be + equivalent to running your original DAG once for each input in the parameter study + space. The parameter study space is defined as the Cartesian product of all + the input parameter studies. When calling this mapper you must specify which input + :class:`~pytato.array.Placeholder` arrays are part of what parameter study. To maintain the equivalence with repeated calling the single instance DAG, the DAG for parameter studies will not create any expressions which depend on the specific instance of a parameter study. - We do allow an input parameter which is specified by a - :class:`~pytato.array.Placeholder` to be a part of multiple distinct parameter - studies. Each distinct parameter study will be a new axis at the end of the array. - - We do NOT require that each input be part of each parameter study. We will broadcast - the input as necessary. + It is not required that each input be part of a parameter study as we will + broadcast the input to the appropriate size. - Ex: + The mapper does not support distributed programming or function definitions. - .. math:: + .. note:: - \mathbf{Z} = \mathbf{X} + \mathbf{Y}, - - where :math:`\mathbf{X}` is a part of parameter study :math:`\mathbf{S1}` and - :math:`\mathbf{Y}` is a part of parameter study :math:`\mathbf{S2}` both with single - instance shapes of :math:`\mathbf{orig\_shape}`. Then, :math:`\mathbf{Z}` will be a - part of both parameter studies :math:`\mathbf{S1}` and :math:`\mathbf{S2}`. - - .. math:: - - \mathbf{Z}_{i,j} = \mathbf{X}_{i} + \mathbf{Y}_{j}, - - and so the shape of :math:`\mathbf{Z}` will be - (:math:`\mathbf{orig\_shape}`, :math:`\mathbf{S1}.size`, :math:`\mathbf{S2}.size`). - - A parameter study is specified in an array by tagging the corresponding axis - with a tag that is a :class:`ParameterStudyAxisTag` or a class which - inherits from it. - - Currently, this only supports DAGs which are made for a single processing unit. 
- That is we do not support distributed programming right now. We also do not support - function definitions within the single instance DAG. - - NOTE any new axes used for parameter studies will be added to the end of the arrays. - If only a portion of the program is expanded, broadcasting may break as the single - instance axes will be left aligned instead of right aligned. This decision was - made under the assumption that the generated code would execute faster - if the parameter study axes were the ones with the shortest strides. + Any new axes used for parameter studies will be added to the end of the arrays. """ - def __init__(self, placeholder_name_to_parameter_studies: Mapping[str, StudiesT]): + def __init__(self, + place_name_to_parameter_studies: Mapping[str, + tuple[ParameterStudyAxisTag, ...]], + study_to_size: Mapping[ParameterStudyAxisTag, int]): super().__init__() - self.placeholder_name_to_parameter_studies = placeholder_name_to_parameter_studies # noqa - - def _shapes_and_axes_from_predecessors(self, curr_expr: Array, - mapped_preds: ArraysT) -> \ - tuple[KnownShapeType, - AxesT, - dict[int, tuple[int, ...]]]: - - assert not any(axis.tags_of_type(ParameterStudyAxisTag) for - axis in curr_expr.axes) - - # We are post pending the axes we are using for parameter studies. - - study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT] = {} + self.place_name_to_parameter_studies = place_name_to_parameter_studies # E501 + self.study_to_size = study_to_size + def _get_canonical_ordered_studies( + self, mapped_preds: tuple[Array, ...]) -> Sequence[ParameterStudyAxisTag]: active_studies: set[ParameterStudyAxisTag] = set() - for arr in mapped_preds: for axis in arr.axes: tags = axis.tags_of_type(ParameterStudyAxisTag) if tags: assert len(tags) == 1 # only one study per axis. 
active_studies = active_studies.union(tags) - if tags in study_to_arrays.keys(): - study_to_arrays[tags] = (*study_to_arrays[tags], arr) - else: - study_to_arrays[tags] = (arr,) - - ps, na, arr_num_to_study_nums = self._studies_to_shape_and_axes_and_arrays_in_canonical_order(active_studies, # noqa - study_to_arrays, mapped_preds) - - # Add in the arrays that are not a part of a parameter study. - # This is done to avoid any KeyErrors later. - - for arr_num in range(len(mapped_preds)): - if arr_num not in arr_num_to_study_nums.keys(): - arr_num_to_study_nums[arr_num] = () - else: - assert len(arr_num_to_study_nums[arr_num]) > 0 - - assert len(arr_num_to_study_nums) == len(mapped_preds) - - for axis in na: - assert axis.tags_of_type(ParameterStudyAxisTag) - - return ps, na, arr_num_to_study_nums - - def _studies_to_shape_and_axes_and_arrays_in_canonical_order(self, - studies: Iterable[ParameterStudyAxisTag], - study_to_arrays: dict[frozenset[ParameterStudyAxisTag], ArraysT], - mapped_preds: ArraysT) -> tuple[KnownShapeType, AxesT, - dict[int, tuple[int, ...]]]: - - # This is where we specify the canonical ordering of the studies. 
- array_num_to_study_nums: dict[int, tuple[int, ...]] = {} - new_shape: KnownShapeType = () - studies_axes: AxesT = () - - num_studies: int = 0 - for ind, study in enumerate(sorted(studies, - key=lambda x: str(x.__class__))): - new_shape = (*new_shape, study.size) - studies_axes = (*studies_axes, Axis(tags=frozenset((study,)))) - for arr_num, arr in enumerate(mapped_preds): - if arr in study_to_arrays[frozenset((study,))]: - if arr_num in array_num_to_study_nums.keys(): - array_num_to_study_nums[arr_num] = (*array_num_to_study_nums[arr_num], ind) # noqa - else: - array_num_to_study_nums[arr_num] = (ind,) - num_studies += 1 - - assert len(new_shape) == num_studies - assert len(new_shape) == len(studies_axes) - - return new_shape, studies_axes, array_num_to_study_nums + + return sorted(active_studies, key=_param_study_to_index) + + def _canonical_ordered_studies_to_shapes_and_axes( + self, studies: Sequence[ParameterStudyAxisTag]) -> tuple[list[int], + list[Axis]]: + """ + Get the shapes and axes in the canonical ordering. + """ + + return [self.study_to_size[study] for study in studies], \ + [Axis(tags=frozenset((study,))) for study in studies] def map_placeholder(self, expr: Placeholder) -> Array: - # This is where we could introduce extra axes. - if expr.name in self.placeholder_name_to_parameter_studies.keys(): - studies = self.placeholder_name_to_parameter_studies[expr.name] + + if expr.name in self.place_name_to_parameter_studies.keys(): + canonical_studies = sorted(self.place_name_to_parameter_studies[expr.name], + key=_param_study_to_index) # noqa E501 # We know that there are no predecessors and we know the studies to add. # We need to get them in the right order. 
- new_shape, new_axes, _ = self._studies_to_shape_and_axes_and_arrays_in_canonical_order( # noqa - studies, {}, ()) + end_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 return Placeholder(name=expr.name, shape=self.rec_idx_or_size_tuple((*expr.shape, - *new_shape,)), + *end_shape,)), dtype=expr.dtype, - axes=(*expr.axes, *new_axes,), + axes=(*expr.axes, *end_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) @@ -315,62 +239,68 @@ def map_placeholder(self, expr: Placeholder) -> Array: def map_roll(self, expr: Roll) -> Array: new_predecessor = self.rec(expr.array) - _, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, - (new_predecessor,)) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + _, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 return Roll(array=new_predecessor, shift=expr.shift, axis=expr.axis, - axes=(*expr.axes, *new_axes,), + axes=(*expr.axes, *end_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: new_predecessor = self.rec(expr.array) - postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, - (new_predecessor,)) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + end_shapes, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + # Include the axes we are adding to the system. 
n_single_inst_axes: int = len(expr.axis_permutation) - axis_permute_gen = (i + n_single_inst_axes for i in range(len(postpend_shape))) + axis_permute_gen = (i + n_single_inst_axes for i in range(len(end_shapes))) return AxisPermutation(array=new_predecessor, axis_permutation=(*expr.axis_permutation, *axis_permute_gen,), - axes=(*expr.axes, *new_axes,), + axes=(*expr.axes, *end_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: new_predecessor = self.rec(expr.array) - postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, - (new_predecessor,)) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + end_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + # Update the indices. - new_indices = expr.indices - for shape in postpend_shape: - new_indices = (*new_indices, NormalizedSlice(0, shape, 1)) + end_indices = (NormalizedSlice(0, shape, 1) for shape in end_shape) return type(expr)(new_predecessor, - indices=self.rec_idx_or_size_tuple(new_indices), - axes=(*expr.axes, *new_axes,), + indices=self.rec_idx_or_size_tuple((*expr.indices, + *end_indices)), + axes=(*expr.axes, *end_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: new_predecessor = self.rec(expr.array) - postpend_shape, new_axes, _ = self._shapes_and_axes_from_predecessors(expr, - (new_predecessor,)) + + canonical_studies = self._get_canonical_ordered_studies((new_predecessor,)) + end_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + return Reshape(new_predecessor, - newshape=self.rec_idx_or_size_tuple(expr.newshape + - postpend_shape), + newshape=self.rec_idx_or_size_tuple((*expr.shape, + *end_shape,)), order=expr.order, - axes=(*expr.axes, *new_axes,), + axes=(*expr.axes, *end_axes,), tags=expr.tags, 
non_equality_tags=expr.non_equality_tags) # {{{ Operations with multiple predecessors. def _broadcast_predecessors_to_same_shape(self, expr: Stack | Concatenate) \ - -> tuple[ArraysT, AxesT]: + -> tuple[tuple[Array, ...], AxesT]: """ This method will convert predecessors who were originally the same @@ -380,24 +310,25 @@ def _broadcast_predecessors_to_same_shape(self, expr: Stack | Concatenate) \ new_predecessors = tuple(self.rec(arr) for arr in expr.arrays) - studies_shape, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa + canonical_studies = self._get_canonical_ordered_studies(new_predecessors) + studies_shape, new_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 - if not arr_num_to_study_nums: + if not studies_shape: # We do not need to do anything as the expression we have is unmodified. - return new_predecessors, new_axes - - # This is going to be expensive. - correct_shape_preds: ArraysT = () + return new_predecessors, tuple(new_axes) + correct_shape_preds: tuple[Array, ...] = () for iarr, array in enumerate(new_predecessors): # Broadcast out to the right shape. n_single_inst_axes = len(expr.arrays[iarr].shape) - index = tuple(prim.Variable(f"_{ind}") for - ind in range(n_single_inst_axes)) - # Add in the axes from the studies we have in the predecessor. - for study_num in arr_num_to_study_nums[iarr]: - index = (*index, prim.Variable(f"_{n_single_inst_axes + study_num}")) + # We assume there is at most one axis + # tag of type ParameterStudyAxisTag per axis. 
+ + index = tuple(prim.Variable(f"_{ind}") if not \ + array.axes[ind].tags_of_type(ParameterStudyAxisTag) \ + else prim.Variable(f"_{n_single_inst_axes + canonical_studies.index(tuple(array.axes[ind].tags_of_type(ParameterStudyAxisTag))[0])}") # noqa E501 + for ind in range(len(array.shape))) assert len(index) == len(array.axes) @@ -416,7 +347,7 @@ def _broadcast_predecessors_to_same_shape(self, expr: Stack | Concatenate) \ for arr in correct_shape_preds: assert arr.shape == correct_shape_preds[0].shape - return correct_shape_preds, new_axes + return correct_shape_preds, tuple(new_axes) def map_stack(self, expr: Stack) -> Array: new_arrays, new_axes = self._broadcast_predecessors_to_same_shape(expr) @@ -452,46 +383,66 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: # However, the index will be unique. array_num_to_bnd_name: dict[int, str] = {ind: name for ind, (name, _) - in enumerate(sorted(new_binds.items()))} # noqa + in enumerate(sorted(new_binds.items()))} # noqa E501 # Determine the new parameter studies that are being conducted. 
- postpend_shape, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, # noqa - new_arrays) + canonical_studies = self._get_canonical_ordered_studies(new_arrays) + postpend_shapes, post_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 + + varname_to_studies_nums: dict[str, tuple[int, ...]] = {bnd_name: () for + _, bnd_name + in array_num_to_bnd_name.items()} + + for iarr, array in enumerate(new_arrays): + for axis in array.axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + assert len(tags) == 1 + study: ParameterStudyAxisTag = next(iter(tags)) + name: str = array_num_to_bnd_name[iarr] + varname_to_studies_nums[name] = (*varname_to_studies_nums[name], + canonical_studies.index(study),) - varname_to_studies_nums = {array_num_to_bnd_name[iarr]: studies for iarr, - studies in arr_num_to_study_nums.items()} + #varname_to_studies_nums = {array_num_to_bnd_name[iarr]: studies for iarr, + # studies in arr_num_to_study_nums.items()} - for vn_key in varname_to_studies_nums.keys(): - assert vn_key in new_binds.keys() + assert all(vn_key in new_binds.keys() for + vn_key in varname_to_studies_nums.keys()) - for vn_key in new_binds.keys(): - assert vn_key in varname_to_studies_nums.keys() + assert all(vn_key in varname_to_studies_nums.keys() for + vn_key in new_binds.keys()) # Now we need to update the expressions. 
- scalar_expr_mapper = IndexLambdaScalarExpressionVectorizer(varname_to_studies_nums, len(expr.shape)) # noqa + scalar_expr_mapper = IndexLambdaScalarExpressionVectorizer(varname_to_studies_nums, len(expr.shape)) # noqa E501 return IndexLambda(expr=scalar_expr_mapper(expr.expr), bindings=immutabledict(new_binds), - shape=(*expr.shape, *postpend_shape,), + shape=(*expr.shape, *postpend_shapes,), var_to_reduction_descr=expr.var_to_reduction_descr, dtype=expr.dtype, - axes=(*expr.axes, *new_axes,), + axes=(*expr.axes, *post_axes,), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_einsum(self, expr: Einsum) -> Array: new_predecessors = tuple(self.rec(arg) for arg in expr.args) - _, new_axes, arr_num_to_study_nums = self._shapes_and_axes_from_predecessors(expr, new_predecessors) # noqa + canonical_studies = self._get_canonical_ordered_studies(new_predecessors) + studies_shape, end_axes = self._canonical_ordered_studies_to_shapes_and_axes(canonical_studies) # noqa E501 access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...] = () for ival, array in enumerate(new_predecessors): one_descr = expr.access_descriptors[ival] - if arr_num_to_study_nums: - for ind in arr_num_to_study_nums[ival]: + for axis in array.axes: + tags = axis.tags_of_type(ParameterStudyAxisTag) + if tags: + # Need to append a descriptor + assert len(tags) == 1 + study: ParameterStudyAxisTag = next(iter(tags)) one_descr = (*one_descr, - # Adding in new element axes to the end of the arrays. - EinsumElementwiseAxis(dim=len(expr.shape) + ind)) + EinsumElementwiseAxis(dim=len(expr.shape) + + canonical_studies.index(study))) + access_descriptors = (*access_descriptors, one_descr) # One descriptor per axis. 
@@ -499,7 +450,7 @@ def map_einsum(self, expr: Einsum) -> Array: return Einsum(access_descriptors, new_predecessors, - axes=(*expr.axes, *new_axes,), + axes=(*expr.axes, *end_axes,), redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) diff --git a/test/test_vectorizer.py b/test/test_vectorizer.py index 6f6060914..8ab15a589 100644 --- a/test/test_vectorizer.py +++ b/test/test_vectorizer.py @@ -46,23 +46,23 @@ class Study2(ParameterStudyAxisTag): pass -global_array_name1 = "foo" -global_array_name2 = "bar" -global_shape1: int = 1234 -global_shape2: int = 5678 -global_study1 = Study1(global_shape1) -global_study2 = Study2(global_shape2) -global_name_to_studies = {global_array_name1: frozenset((global_study1,)), - global_array_name2: frozenset((global_study2,)), } - +GLOBAL_ARRAY_NAME1 = "foo" +GLOBAL_ARRAY_NAME2 = "bar" +GLOBAL_SHAPE1: int = 1234 +GLOBAL_SHAPE2: int = 5678 +GLOBAL_STUDY1 = Study1(GLOBAL_SHAPE1) +GLOBAL_STUDY2 = Study2(GLOBAL_SHAPE2) +GLOBAL_NAME_TO_STUDIES = {GLOBAL_ARRAY_NAME1: frozenset((GLOBAL_STUDY1,)), + GLOBAL_ARRAY_NAME2: frozenset((GLOBAL_STUDY2,)), } +GLOBAL_STUDY_TO_SHAPES = {GLOBAL_STUDY1: GLOBAL_SHAPE1, GLOBAL_STUDY2: GLOBAL_SHAPE2} # {{{ Expansion Mapper tests. 
def test_vectorize_mapper_placeholder(): - expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int) + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int) assert expr.shape == (15, 5) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (15, 5, global_shape1) + assert new_expr.shape == (15, 5, GLOBAL_SHAPE1) for i, axis in enumerate(new_expr.axes): tags = axis.tags_of_type(ParameterStudyAxisTag) @@ -73,25 +73,25 @@ def test_vectorize_mapper_placeholder(): def test_vectorize_mapper_basic_index(): - expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int)[14, 0] + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)[14, 0] assert expr.shape == () - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (global_shape1,) + assert new_expr.shape == (GLOBAL_SHAPE1,) assert new_expr.axes[0].tags_of_type(ParameterStudyAxisTag) def test_vectorize_mapper_advanced_index_contiguous_axes(): - expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int) + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int) expr = expr[pt.arange(10, dtype=int)] assert expr.shape == (10, 5) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (10, 5, global_shape1) + assert new_expr.shape == (10, 5, GLOBAL_SHAPE1) assert new_expr.axes[2].tags_of_type(ParameterStudyAxisTag) assert isinstance(new_expr, pt.AdvancedIndexInContiguousAxes) @@ -102,15 +102,15 @@ def test_vectorize_mapper_advanced_index_non_contiguous_axes(): ind0 = pt.arange(10, dtype=int).reshape(10, 1) ind1 = pt.arange(2, 
dtype=int).reshape(1, 2) - expr = pt.make_placeholder(global_array_name1, (15, 1000, 5), dtype=int) + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 1000, 5), dtype=int) expr = expr[ind0, :, ind1] assert isinstance(expr, pt.AdvancedIndexInNoncontiguousAxes) assert expr.shape == (10, 2, 1000) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (10, 2, 1000, global_shape1) + assert new_expr.shape == (10, 2, 1000, GLOBAL_SHAPE1) assert new_expr.axes[3].tags_of_type(ParameterStudyAxisTag) assert isinstance(new_expr, pt.AdvancedIndexInNoncontiguousAxes) @@ -118,14 +118,14 @@ def test_vectorize_mapper_advanced_index_non_contiguous_axes(): def test_vectorize_mapper_index_lambda(): - expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int)[14, 0] \ + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)[14, 0] \ + pt.ones(100) assert expr.shape == (100,) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (100, global_shape1) + assert new_expr.shape == (100, GLOBAL_SHAPE1) assert isinstance(new_expr, pt.IndexLambda) scalar_expr = new_expr.expr @@ -137,7 +137,7 @@ def test_vectorize_mapper_index_lambda(): def test_vectorize_mapper_roll(): - expr = pt.make_placeholder(global_array_name1, (15, 5), dtype=int)[14, 0] \ + expr = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)[14, 0] \ + pt.ones(100) expr = pt.roll(expr, axis=0, shift=22) @@ -145,20 +145,20 @@ def test_vectorize_mapper_roll(): assert expr.shape == (100,) assert not any(axis.tags_of_type(ParameterStudyAxisTag) for axis in expr.axes) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, 
GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (100, global_shape1,) + assert new_expr.shape == (100, GLOBAL_SHAPE1,) assert isinstance(new_expr, pt.Roll) assert new_expr.axes[1].tags_of_type(ParameterStudyAxisTag) def test_vectorize_mapper_axis_permutation(): - expr = pt.transpose(pt.make_placeholder(global_array_name1, (15, 5), dtype=int)) + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)) assert expr.shape == (5, 15) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(expr) - assert new_expr.shape == (5, 15, global_shape1) + assert new_expr.shape == (5, 15, GLOBAL_SHAPE1) assert isinstance(new_expr, pt.AxisPermutation) for i, axis in enumerate(new_expr.axes): @@ -170,17 +170,17 @@ def test_vectorize_mapper_axis_permutation(): def test_vectorize_mapper_reshape(): - expr = pt.transpose(pt.make_placeholder(global_array_name1, + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(global_array_name2, + expr2 = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME2, (15, 5), dtype=int)) out_expr = pt.stack([expr, expr2], axis=0).reshape(10, 15) assert out_expr.shape == (10, 15) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(out_expr) - assert new_expr.shape == (10, 15, global_shape1, global_shape2) + assert new_expr.shape == (10, 15, GLOBAL_SHAPE1, GLOBAL_SHAPE2) assert isinstance(new_expr, pt.Reshape) for i, axis in enumerate(new_expr.axes): @@ -195,17 +195,17 @@ def test_vectorize_mapper_reshape(): def test_vectorize_mapper_stack(): - expr = pt.transpose(pt.make_placeholder(global_array_name1, + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)) - expr2 = 
pt.transpose(pt.make_placeholder(global_array_name2, + expr2 = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME2, (15, 5), dtype=int)) out_expr = pt.stack([expr, expr2], axis=0) assert out_expr.shape == (2, 5, 15) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(out_expr) - assert new_expr.shape == (2, 5, 15, global_shape1, global_shape2) + assert new_expr.shape == (2, 5, 15, GLOBAL_SHAPE1, GLOBAL_SHAPE2) assert isinstance(new_expr, pt.Stack) for i, axis in enumerate(new_expr.axes): @@ -220,17 +220,17 @@ def test_vectorize_mapper_stack(): def test_vectorize_mapper_concatenate(): - expr = pt.transpose(pt.make_placeholder(global_array_name1, + expr = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME1, (15, 5), dtype=int)) - expr2 = pt.transpose(pt.make_placeholder(global_array_name2, + expr2 = pt.transpose(pt.make_placeholder(GLOBAL_ARRAY_NAME2, (15, 5), dtype=int)) out_expr = pt.concatenate([expr, expr2], axis=0) assert out_expr.shape == (10, 15) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_expr = my_mapper(out_expr) - assert new_expr.shape == (10, 15, global_shape1, global_shape2) + assert new_expr.shape == (10, 15, GLOBAL_SHAPE1, GLOBAL_SHAPE2) assert isinstance(new_expr, pt.Concatenate) for i, axis in enumerate(new_expr.axes): @@ -246,19 +246,19 @@ def test_vectorize_mapper_concatenate(): def test_vectorize_mapper_einsum_matmul(): # Matmul gets expanded correctly. 
- a = pt.make_placeholder(global_array_name1, + a = pt.make_placeholder(GLOBAL_ARRAY_NAME1, (47, 42), dtype=int) - b = pt.make_placeholder(global_array_name2, + b = pt.make_placeholder(GLOBAL_ARRAY_NAME2, (42, 5), dtype=int) c = pt.matmul(a, b) assert isinstance(c, pt.Einsum) assert c.shape == (47, 5) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) updated_c = my_mapper(c) - assert updated_c.shape == (47, 5, global_shape1, global_shape2) + assert updated_c.shape == (47, 5, GLOBAL_SHAPE1, GLOBAL_SHAPE2) def test_vectorize_mapper_does_nothing_if_tags_not_there(): @@ -288,7 +288,7 @@ def make_dws_placeholder(expr): dag = pt.transform.map_and_copy(dag, make_dws_placeholder) - my_mapper = ParameterStudyVectorizer(global_name_to_studies) + my_mapper = ParameterStudyVectorizer(GLOBAL_NAME_TO_STUDIES, GLOBAL_STUDY_TO_SHAPES) new_dag = my_mapper(dag) assert new_dag == dag From be786dde37b494421d0775eed1b5b2c15661ffee Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 7 Nov 2024 13:50:12 -0600 Subject: [PATCH 27/27] Add in the removal of the raising operator. 
--- pytato/scalar_expr.py | 3 +- pytato/transform/metadata.py | 107 +++++++++++++++++------------------ 2 files changed, 55 insertions(+), 55 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 24f2c684b..982737fa5 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -142,7 +142,8 @@ def map_reduce(self, expr: Reduce) -> ScalarExpression: IDX_LAMBDA_RE = re.compile("_r?(0|([1-9][0-9]*))") - +IDX_LAMBDA_INAME = re.compile("^(_(0|([1-9][0-9]*)))$") +IDX_LAMBDA_JUST_REDUCTIONS = re.compile("^(_r(0|([1-9][0-9]*)))$") class DependencyMapper(DependencyMapperBase): def __init__(self, *, diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 319611107..2afcafb8e 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -99,6 +99,34 @@ GraphNodeT = TypeVar("GraphNodeT") +from pytato.scalar_expr import ( + IDX_LAMBDA_INAME, + IDX_LAMBDA_JUST_REDUCTIONS, + IdentityMapper as ScalarMapper, +) + +class AxesUsedMapper(ScalarMapper): + """ + Determine which axes are used in the scalar expression and which ones just flow + through the expression. 
+ """ + + def __init__(self, var_names_in_use: list[str]): + self.var_names_in_use: list[str] = var_names_in_use + + self.usage_dict: Mapping[str, list[tuple[prim.Expression, ...]]] = {vname: [] + for vname in + self.var_names_in_use} + + def map_subscript(self, expr: prim.Subscript) -> None: + + name = expr.aggregate.name + if name in self.var_names_in_use: + self.usage_dict[name].append(expr.index_tuple) + + self.rec(expr.index) + + # {{{ AxesTagsEquationCollector @@ -237,65 +265,36 @@ def map_index_lambda(self, expr: IndexLambda) -> None: for bnd in expr.bindings.values(): self.rec(bnd) - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - from warnings import warn - warn(f"'{expr}' is an unknown index lambda type" - " no tags were propagated across it.", stacklevel=1) - # no propagation semantics implemented for such cases - return - - if isinstance(hlo, BinaryOp): - subexprs: tuple[ArrayOrScalar, ...] = (hlo.x1, hlo.x2) - elif isinstance(hlo, WhereOp): - subexprs = (hlo.condition, hlo.then, hlo.else_) - elif isinstance(hlo, FullOp): - # A full-op does not impose any equations - subexprs = () - elif isinstance(hlo, BroadcastOp): - subexprs = (hlo.x,) - elif isinstance(hlo, C99CallOp): - subexprs = hlo.args - elif isinstance(hlo, ReduceOp): - - # {{{ ReduceOp doesn't quite involve broadcasting - - i_out_axis = 0 - for i_in_axis in range(hlo.x.ndim): - if i_in_axis not in hlo.axes: - self.record_equation( - self.get_var_for_axis(hlo.x, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - i_out_axis += 1 - assert i_out_axis == expr.ndim + keys = list(expr.bindings.keys()) - # }}} + mymapper = AxesUsedMapper(keys) - return + mymapper(expr.expr) - else: - raise NotImplementedError(type(hlo)) + out_shape = expr.shape + assert len(out_shape) == expr.ndim - for subexpr in subexprs: - if isinstance(subexpr, Array): - for i_in_axis, i_out_axis in zip( - range(subexpr.ndim), - range(expr.ndim-subexpr.ndim, expr.ndim)): - in_dim = 
subexpr.shape[i_in_axis] - out_dim = expr.shape[i_out_axis] - if are_shape_components_equal(in_dim, out_dim): - self.record_equation( - self.get_var_for_axis(subexpr, i_in_axis), - self.get_var_for_axis(expr, i_out_axis) - ) - else: - # i_in_axis is broadcasted => do not propagate - assert are_shape_components_equal(in_dim, 1) - else: - assert isinstance(subexpr, SCALAR_CLASSES) + for key in keys: + if len(mymapper.usage_dict[key]) > 0: + for tup_ind in range(len(mymapper.usage_dict[key][0])): + vname = mymapper.usage_dict[key][0][tup_ind] + if isinstance(vname, prim.Variable): + if IDX_LAMBDA_JUST_REDUCTIONS.fullmatch(vname.name): + # Reduction axis. Pass for now. + pass + elif vname.name[:3] == "_in": + # Array used as part of the index. Pass for now. + pass + elif IDX_LAMBDA_INAME.fullmatch(vname.name): + # Matched with an output axis. + inum = int(vname.name[1:]) + val = (self.get_var_for_axis(expr.bindings[key], tup_ind), + self.get_var_for_axis(expr, inum)) + self.equations.append(val) + else: + raise ValueError(f"Unknown index name used in {vname}") + return def map_stack(self, expr: Stack) -> None: """