diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index f94adf8a5..49915bc9f 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -6201,14 +6201,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 44, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -6929,22 +6921,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 21, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 41, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -6961,22 +6937,6 @@ "lineCount": 1 } }, - { - "code": "reportPrivateUsage", - "range": { - "startColumn": 29, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 38, - "endColumn": 42, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -6985,14 +6945,6 @@ "lineCount": 1 } }, - { - "code": "reportPrivateUsage", - "range": { - "startColumn": 29, - "endColumn": 39, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -10523,22 +10475,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 62, - "endColumn": 79, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 62, - "endColumn": 79, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -10554,14 +10490,6 @@ "endColumn": 51, "lineCount": 1 } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 62, - "endColumn": 79, - "lineCount": 1 - } } ], "./test/test_pytato.py": [ diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 2626856ba..6de9e5de4 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -694,7 +694,7 @@ def get_num_call_sites(outputs: ArrayOrNames) -> int: # {{{ TagCountMapper -class TagCountMapper(CombineMapper[int, Never]): +class TagCountMapper(CombineMapper[int, Never, []]): """ Returns the number of nodes in a DAG that are tagged with all the tag types in *tag_types*. diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index bf990a557..7e6203d2d 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -414,7 +414,7 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[FrozenOrderedSet[CommunicationOpIdentifier], Never]): + CombineMapper[FrozenOrderedSet[CommunicationOpIdentifier], Never, []]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 22edf06e3..f34f4c647 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -815,7 +815,8 @@ def map_call(self, expr: Call, state: CodeGenState) -> None: } -class ReductionCollector(scalar_expr.CombineMapper[frozenset[scalar_expr.Reduce], []]): +class ReductionCollector( + scalar_expr.CombineMapper[frozenset[scalar_expr.Reduce], []]): """ Constructs a :class:`frozenset` containing all instances of :class:`pytato.scalar_expr.Reduce` found in a scalar expression. diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 7dc67e0cc..604ce4826 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1237,129 +1237,162 @@ def deduplicate( # {{{ CombineMapper -class CombineMapper(CachedMapper[ResultT, FunctionResultT, []]): +class CombineMapper(CachedMapper[ResultT, FunctionResultT, P]): """ Abstract mapper that recursively combines the results of user nodes of a given expression. .. automethod:: combine """ - def get_cache_key(self, expr: ArrayOrNames) -> CacheKeyT: - return expr - - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> CacheKeyT: - return expr - - def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] - ) -> tuple[ResultT, ...]: - return tuple(self.rec(s) for s in situp if isinstance(s, Array)) + def rec_idx_or_size_tuple( + self, + situp: tuple[IndexOrShapeExpr, ...], + *args: P.args, **kwargs: P.kwargs) -> tuple[ResultT, ...]: + return tuple( + self.rec(s, *args, **kwargs) + for s in situp if isinstance(s, Array)) def combine(self, *args: ResultT) -> ResultT: """Combine the arguments.""" raise NotImplementedError - def map_index_lambda(self, expr: IndexLambda) -> ResultT: - return self.combine(*(self.rec(bnd) - for _, bnd in sorted(expr.bindings.items())), - *self.rec_idx_or_size_tuple(expr.shape)) + def map_index_lambda( + self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + *( + self.rec(bnd, *args, **kwargs) + for _, bnd in sorted(expr.bindings.items())), + *self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs)) - def map_placeholder(self, expr: Placeholder) -> ResultT: - return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_placeholder( + self, expr: Placeholder, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + *self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs)) - def map_data_wrapper(self, expr: DataWrapper) -> ResultT: - return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_data_wrapper( + self, expr: DataWrapper, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + *self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs)) - def map_stack(self, expr: Stack) -> ResultT: - return self.combine(*(self.rec(ary) - for ary in expr.arrays)) + def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + *( + self.rec(ary, *args, **kwargs) + for ary in expr.arrays)) - def map_roll(self, expr: Roll) -> ResultT: - return self.combine(self.rec(expr.array)) + def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(expr.array, *args, **kwargs)) - def map_axis_permutation(self, expr: AxisPermutation) -> ResultT: - return self.combine(self.rec(expr.array)) + def map_axis_permutation( + self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine(self.rec(expr.array, *args, **kwargs)) - def _map_index_base(self, expr: IndexBase) -> ResultT: - return self.combine(self.rec(expr.array), - *self.rec_idx_or_size_tuple(expr.indices)) + def _map_index_base( + self, expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + self.rec(expr.array, *args, **kwargs), + *self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs)) - def map_basic_index(self, expr: BasicIndex) -> ResultT: - return self._map_index_base(expr) + def map_basic_index( + self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self._map_index_base(expr, *args, **kwargs) - def map_contiguous_advanced_index(self, - expr: AdvancedIndexInContiguousAxes - ) -> ResultT: - return self._map_index_base(expr) + def map_contiguous_advanced_index( + self, expr: AdvancedIndexInContiguousAxes, *args: P.args, + **kwargs: P.kwargs) -> ResultT: + return self._map_index_base(expr, *args, **kwargs) - def map_non_contiguous_advanced_index(self, - expr: AdvancedIndexInNoncontiguousAxes - ) -> ResultT: - return self._map_index_base(expr) + def map_non_contiguous_advanced_index( + self, expr: AdvancedIndexInNoncontiguousAxes, *args: P.args, + **kwargs: P.kwargs) -> ResultT: + return self._map_index_base(expr, *args, **kwargs) - def map_reshape(self, expr: Reshape) -> ResultT: + def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine( - self.rec(expr.array), - *self.rec_idx_or_size_tuple(expr.newshape)) - - def map_concatenate(self, expr: Concatenate) -> ResultT: - return self.combine(*(self.rec(ary) - for ary in expr.arrays)) - - def map_einsum(self, expr: Einsum) -> ResultT: - return self.combine(*(self.rec(ary) - for ary in expr.args)) + self.rec(expr.array, *args, **kwargs), + *self.rec_idx_or_size_tuple(expr.newshape, *args, **kwargs)) - def map_csr_matmul(self, expr: CSRMatmul) -> ResultT: + def map_concatenate( + self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine( - self.rec(expr.matrix.elem_values), - self.rec(expr.matrix.elem_col_indices), - self.rec(expr.matrix.row_starts), - self.rec(expr.array)) + *( + self.rec(ary, *args, **kwargs) + for ary in expr.arrays)) - def map_named_array(self, expr: NamedArray) -> ResultT: - return self.combine(self.rec(expr._container)) + def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + *( + self.rec(ary, *args, **kwargs) + for ary in expr.args)) - def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> ResultT: - return self.combine(*(self.rec(ary.expr) - for ary in expr.values())) + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + self.rec(expr.matrix.elem_values, *args, **kwargs), + self.rec(expr.matrix.elem_col_indices, *args, **kwargs), + self.rec(expr.matrix.row_starts, *args, **kwargs), + self.rec(expr.array, *args, **kwargs)) + + def map_named_array( + self, expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(expr._container, *args, **kwargs)) + + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine( + *( + self.rec(ary.expr, *args, **kwargs) + for ary in expr.values())) - def map_loopy_call(self, expr: LoopyCall) -> ResultT: - return self.combine(*(self.rec(ary) - for _, ary in sorted(expr.bindings.items()) - if isinstance(ary, Array))) + def map_loopy_call( + self, expr: LoopyCall, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine( + *( + self.rec(ary, *args, **kwargs) + for _, ary in sorted(expr.bindings.items()) + if isinstance(ary, Array))) - def map_loopy_call_result(self, expr: LoopyCallResult) -> ResultT: - return self.rec(expr._container) + def map_loopy_call_result( + self, expr: LoopyCallResult, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine(self.rec(expr._container, *args, **kwargs)) def map_distributed_send_ref_holder( - self, expr: DistributedSendRefHolder) -> ResultT: + self, expr: DistributedSendRefHolder, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine( - self.rec(expr.send.data), - self.rec(expr.passthrough_data), - ) + self.rec(expr.send.data, *args, **kwargs), + self.rec(expr.passthrough_data, *args, **kwargs)) - def map_distributed_recv(self, expr: DistributedRecv) -> ResultT: - return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_distributed_recv( + self, expr: DistributedRecv, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine(*self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs)) - def map_function_definition(self, expr: FunctionDefinition) -> FunctionResultT: + def map_function_definition( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> FunctionResultT: raise NotImplementedError("Combining results from a callee expression" " is context-dependent. Derived classes" " must override map_function_definition.") - def map_call(self, expr: Call) -> ResultT: + def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError( "Mapping calls is context-dependent. Derived classes must override " "map_call.") - def map_named_call_result(self, expr: NamedCallResult) -> ResultT: - return self.rec(expr._container) + def map_named_call_result( + self, expr: NamedCallResult, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(expr._container, *args, **kwargs)) # }}} # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R, Never]): +class DependencyMapper(CombineMapper[R, Never, []]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -1479,7 +1512,8 @@ def combine(self, *args: frozenset[Array]) -> frozenset[Array]: # {{{ InputGatherer / ListOfInputsGatherer class InputGatherer( - CombineMapper[frozenset[InputArgumentBase], frozenset[InputArgumentBase]]): + CombineMapper[ + frozenset[InputArgumentBase], frozenset[InputArgumentBase], []]): """ Mapper to combine all instances of :class:`pytato.array.InputArgumentBase` that an array expression depends on. @@ -1534,7 +1568,7 @@ def map_call(self, expr: Call) -> frozenset[InputArgumentBase]: class ListOfInputsGatherer( - CombineMapper[list[InputArgumentBase], list[InputArgumentBase]]): + CombineMapper[list[InputArgumentBase], list[InputArgumentBase], []]): """ Mapper to combine all instances of :class:`pytato.array.InputArgumentBase` that an array expression depends on, preserving duplicates. @@ -1601,7 +1635,7 @@ def map_call(self, expr: Call) -> list[InputArgumentBase]: # {{{ SizeParamGatherer class SizeParamGatherer( - CombineMapper[frozenset[SizeParam], frozenset[SizeParam]]): + CombineMapper[frozenset[SizeParam], frozenset[SizeParam], []]): """ Mapper to combine all instances of :class:`pytato.array.SizeParam` that an array expression depends on. diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 4310704ee..36aab2391 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -115,9 +115,8 @@ # {{{ BindingSubscriptsCollector -class BindingSubscriptsCollector(CombineMapper[dict[BindingName, - set[tuple[Expression, ...]]], - []]): +class BindingSubscriptsCollector( + CombineMapper[dict[BindingName, set[tuple[Expression, ...]]], []]): """ Return all the subscript expressions used by a variable specified by BindingName. Ex: