1010^^^^^^^^^^^^^^^^^^^^^^^^^^^
1111
1212.. autofunction:: tabulate_profiling_data
13+
14+ References
15+ ^^^^^^^^^^
16+
17+ .. autoclass:: ArrayOrNamesTc
18+
19+ A constrained type variable binding to either
20+ :class:`pytato.Array` or :class:`pytato.AbstractResultWithNames`.
1321"""
1422
1523
4048
4149from typing import TYPE_CHECKING , Any , cast
4250
51+ from typing_extensions import override
52+
4353import pytools
4454from pytato .analysis import get_num_call_sites
4555from pytato .array import (
46- AbstractResultWithNamedArrays ,
4756 Array ,
4857 Axis as PtAxis ,
4958 DataWrapper ,
50- DictOfNamedArrays ,
5159 Placeholder ,
5260 SizeParam ,
5361 make_placeholder ,
5462)
55- from pytato .function import FunctionDefinition
5663from pytato .target .loopy import LoopyPyOpenCLTarget
5764from pytato .transform import (
5865 ArrayOrNames ,
66+ ArrayOrNamesTc ,
5967 CopyMapper ,
6068 TransformMapperCache ,
6169 deduplicate ,
6977 from collections .abc import Mapping
7078
7179 import loopy as lp
80+ from pytato import AbstractResultWithNamedArrays
81+ from pytato .function import FunctionDefinition
7282
7383 from arraycontext import ArrayContext
7484 from arraycontext .container import SerializationKey
@@ -98,6 +108,7 @@ def __init__(
98108 self .vng = UniqueNameGenerator ()
99109 self .seen_inputs : set [str ] = set ()
100110
111+ @override
101112 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
102113 if expr .name is not None :
103114 if expr .name in self .seen_inputs :
@@ -119,13 +130,16 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
119130 axes = expr .axes ,
120131 tags = expr .tags )
121132
133+ @override
122134 def map_size_param (self , expr : SizeParam ) -> Array :
123135 raise NotImplementedError
124136
137+ @override
125138 def map_placeholder (self , expr : Placeholder ) -> Array :
126139 raise ValueError ("Placeholders cannot appear in"
127140 " DatawrapperToBoundPlaceholderMapper." )
128141
142+ @override
129143 def map_function_definition (
130144 self , expr : FunctionDefinition ) -> FunctionDefinition :
131145 raise ValueError ("Function definitions cannot appear in"
@@ -135,8 +149,8 @@ def map_function_definition(
135149# FIXME: This strategy doesn't work if the DAG has functions, since function
136150# definitions can't contain non-argument placeholders
137151def _normalize_pt_expr (
138- expr : DictOfNamedArrays
139- ) -> tuple [Array | AbstractResultWithNamedArrays , Mapping [str , Any ]]:
152+ expr : AbstractResultWithNamedArrays
153+ ) -> tuple [AbstractResultWithNamedArrays , Mapping [str , Any ]]:
140154 """
141155 Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
142156 normalized form of *expr*, with all instances of
@@ -155,7 +169,6 @@ def _normalize_pt_expr(
155169
156170 normalize_mapper = _DatawrapperToBoundPlaceholderMapper ()
157171 normalized_expr = normalize_mapper (expr )
158- assert isinstance (normalized_expr , AbstractResultWithNamedArrays )
159172 return normalized_expr , normalize_mapper .bound_arguments
160173
161174
@@ -193,6 +206,7 @@ def __init__(self, actx: ArrayContext) -> None:
193206 super ().__init__ ()
194207 self .actx = actx
195208
209+ @override
196210 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
197211 import numpy as np
198212
@@ -225,6 +239,7 @@ def __init__(self, actx: ArrayContext) -> None:
225239 super ().__init__ ()
226240 self .actx = actx
227241
242+ @override
228243 def map_data_wrapper (self , expr : DataWrapper ) -> Array :
229244 import numpy as np
230245
@@ -244,15 +259,15 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
244259 non_equality_tags = expr .non_equality_tags )
245260
246261
247- def transfer_from_numpy (expr : ArrayOrNames , actx : ArrayContext ) -> ArrayOrNames :
262+ def transfer_from_numpy (expr : ArrayOrNamesTc , actx : ArrayContext ) -> ArrayOrNamesTc :
248263 """Transfer arrays contained in :class:`~pytato.array.DataWrapper`
249264 instances to be device arrays, using
250265 :meth:`~arraycontext.ArrayContext.from_numpy`.
251266 """
252267 return TransferFromNumpyMapper (actx )(expr )
253268
254269
255- def transfer_to_numpy (expr : ArrayOrNames , actx : ArrayContext ) -> ArrayOrNames :
270+ def transfer_to_numpy (expr : ArrayOrNamesTc , actx : ArrayContext ) -> ArrayOrNamesTc :
256271 """Transfer arrays contained in :class:`~pytato.array.DataWrapper`
257272 instances to be :class:`numpy.ndarray` instances, using
258273 :meth:`~arraycontext.ArrayContext.to_numpy`.
0 commit comments