Skip to content

Commit 0655a0c

Browse files
committed
add ArrayOrNamesOrFunctionDef
1 parent 666ca2a commit 0655a0c

6 files changed

Lines changed: 44 additions & 32 deletions

File tree

pytato/analysis/__init__.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from pytato.tags import ImplStored
6262
from pytato.transform import (
6363
ArrayOrNames,
64+
ArrayOrNamesOrFunctionDef,
6465
ArrayOrNamesTc,
6566
CachedWalkMapper,
6667
CombineMapper,
@@ -362,7 +363,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
362363

363364
class ListOfDirectPredecessorsGetter(
364365
Mapper[
365-
list[ArrayOrNames | FunctionDefinition],
366+
list[ArrayOrNamesOrFunctionDef],
366367
list[ArrayOrNames],
367368
[]]):
368369
"""
@@ -445,8 +446,8 @@ def map_distributed_send_ref_holder(self,
445446
return [expr.send.data, expr.passthrough_data]
446447

447448
def map_call(
448-
self, expr: Call) -> list[ArrayOrNames | FunctionDefinition]:
449-
result: list[ArrayOrNames | FunctionDefinition] = []
449+
self, expr: Call) -> list[ArrayOrNamesOrFunctionDef]:
450+
result: list[ArrayOrNamesOrFunctionDef] = []
450451
if self.include_functions:
451452
result.append(expr.function)
452453
result += list(expr.bindings.values())
@@ -483,7 +484,7 @@ def __init__(self, *, include_functions: bool = False) -> None:
483484
@overload
484485
def __call__(
485486
self, expr: ArrayOrNames
486-
) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]:
487+
) -> FrozenOrderedSet[ArrayOrNamesOrFunctionDef]:
487488
...
488489

489490
@overload
@@ -492,9 +493,9 @@ def __call__(self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]:
492493

493494
def __call__(
494495
self,
495-
expr: ArrayOrNames | FunctionDefinition,
496+
expr: ArrayOrNamesOrFunctionDef,
496497
) -> (
497-
FrozenOrderedSet[ArrayOrNames | FunctionDefinition]
498+
FrozenOrderedSet[ArrayOrNamesOrFunctionDef]
498499
| FrozenOrderedSet[ArrayOrNames]):
499500
"""Get the direct predecessors of *expr*."""
500501
return FrozenOrderedSet(self._pred_getter(expr))
@@ -543,7 +544,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self:
543544
_visited_functions=self._visited_functions)
544545

545546
@override
546-
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
547+
def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None:
547548
if not isinstance(expr, DictOfNamedArrays):
548549
self.expr_type_counts[type(expr)] += 1
549550

@@ -606,7 +607,7 @@ def __init__(self, _visited_functions: set[Any] | None = None) -> None:
606607
super().__init__(_visited_functions=_visited_functions)
607608

608609
self.expr_multiplicity_counts: \
609-
dict[ArrayOrNames | FunctionDefinition, int] = defaultdict(int)
610+
dict[ArrayOrNamesOrFunctionDef, int] = defaultdict(int)
610611

611612
@override
612613
def get_cache_key(self, expr: ArrayOrNames) -> int:
@@ -619,13 +620,13 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
619620
return id(expr)
620621

621622
@override
622-
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
623+
def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None:
623624
if not isinstance(expr, DictOfNamedArrays):
624625
self.expr_multiplicity_counts[expr] += 1
625626

626627

627628
def get_node_multiplicities(
628-
outputs: ArrayOrNames) -> dict[ArrayOrNames | FunctionDefinition, int]:
629+
outputs: ArrayOrNames) -> dict[ArrayOrNamesOrFunctionDef, int]:
629630
"""
630631
Returns the multiplicity per `expr`.
631632
"""
@@ -662,7 +663,7 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int:
662663
return id(expr)
663664

664665
@override
665-
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
666+
def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None:
666667
if isinstance(expr, Call):
667668
self.count += 1
668669

@@ -884,7 +885,7 @@ def map_call(self, expr: Call) -> None:
884885
f"{type(self).__name__} does not support functions.")
885886

886887
@override
887-
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
888+
def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None:
888889
if not is_materialized(expr):
889890
return
890891
assert isinstance(expr, Array)
@@ -978,7 +979,7 @@ def map_call(self, expr: Call) -> None:
978979
f"{type(self).__name__} does not support functions.")
979980

980981
@override
981-
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
982+
def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None:
982983
if not is_materialized(expr) or not has_taggable_materialization(expr):
983984
return
984985
assert isinstance(expr, Array)

pytato/distributed/verify.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@
5555
DistributedGraphPartition,
5656
PartId,
5757
)
58-
from pytato.transform import ArrayOrNames, CachedWalkMapper
58+
from pytato.transform import (
59+
ArrayOrNames,
60+
ArrayOrNamesOrFunctionDef,
61+
CachedWalkMapper,
62+
)
5963

6064

6165
logger = logging.getLogger(__name__)
@@ -68,7 +72,6 @@
6872
import numpy as np
6973

7074
from pytato.distributed.nodes import CommTagType, DistributedRecv
71-
from pytato.function import FunctionDefinition
7275

7376

7477
# {{{ data structures
@@ -156,7 +159,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> int:
156159
return id(expr)
157160

158161
@override
159-
def visit(self, expr: ArrayOrNames | FunctionDefinition) -> bool:
162+
def visit(self, expr: ArrayOrNamesOrFunctionDef) -> bool:
160163
super().visit(expr)
161164
if isinstance(expr, ArrayOrNames):
162165
self.seen_nodes.add(expr)

pytato/equality.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363

6464

6565
ArrayOrNames = Array | AbstractResultWithNamedArrays
66+
ArrayOrNamesOrFunctionDef = \
67+
Array | AbstractResultWithNamedArrays | FunctionDefinition
6668

6769

6870
# {{{ EqualityComparer
@@ -87,7 +89,7 @@ def __init__(self) -> None:
8789
# Uses the same cache for both arrays and functions
8890
self._cache: dict[tuple[int, int], bool] = {}
8991

90-
def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool:
92+
def rec(self, expr1: ArrayOrNamesOrFunctionDef, expr2: object) -> bool:
9193
# These cases are simple enough that they don't need to be cached
9294
if expr1 is expr2:
9395
return True
@@ -119,7 +121,7 @@ def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool:
119121
self._cache[cache_key] = result
120122
return result
121123

122-
def __call__(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool:
124+
def __call__(self, expr1: ArrayOrNamesOrFunctionDef, expr2: object) -> bool:
123125
return self.rec(expr1, expr2)
124126

125127
def handle_unsupported_array(self, expr1: Array,

pytato/transform/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@
8989

9090

9191
ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays
92+
ArrayOrNamesOrFunctionDef: TypeAlias = \
93+
Array | AbstractResultWithNamedArrays | FunctionDefinition
9294
ArrayOrNamesTc = TypeVar("ArrayOrNamesTc",
9395
Array, AbstractResultWithNamedArrays, DictOfNamedArrays)
9496
ArrayOrNamesOrFunctionDefTc = TypeVar("ArrayOrNamesOrFunctionDefTc",
@@ -150,6 +152,7 @@
150152
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
151153
152154
.. class:: ArrayOrNames
155+
.. class:: ArrayOrNamesOrFunctionDef
153156
154157
.. class:: ArrayOrNamesTc
155158
@@ -307,7 +310,7 @@ def __call__(
307310

308311
def __call__(
309312
self,
310-
expr: ArrayOrNames | FunctionDefinition,
313+
expr: ArrayOrNamesOrFunctionDef,
311314
*args: P.args,
312315
**kwargs: P.kwargs) -> ResultT | FunctionResultT:
313316
"""Handle the mapping of *expr*."""
@@ -1569,7 +1572,7 @@ def clone_for_callee(
15691572
return type(self)()
15701573

15711574
def visit(
1572-
self, expr: ArrayOrNames | FunctionDefinition,
1575+
self, expr: ArrayOrNamesOrFunctionDef,
15731576
*args: P.args, **kwargs: P.kwargs) -> bool:
15741577
"""
15751578
If this method returns *True*, *expr* is traversed during the walk.
@@ -1579,7 +1582,7 @@ def visit(
15791582
return True
15801583

15811584
def post_visit(
1582-
self, expr: ArrayOrNames | FunctionDefinition,
1585+
self, expr: ArrayOrNamesOrFunctionDef,
15831586
*args: P.args, **kwargs: P.kwargs) -> None:
15841587
"""
15851588
Callback after *expr* has been traversed.
@@ -1841,7 +1844,7 @@ def __init__(
18411844
def get_cache_key(self, expr: ArrayOrNames) -> int:
18421845
return id(expr)
18431846

1844-
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
1847+
def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None:
18451848
if isinstance(expr, Array):
18461849
self.topological_order.append(expr)
18471850

pytato/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,14 @@
6161
ScalarExpression,
6262
TypeCast,
6363
)
64-
from pytato.transform import ArrayOrNames, CachedMapper
64+
from pytato.transform import ArrayOrNamesOrFunctionDef, CachedMapper
6565

6666

6767
if TYPE_CHECKING:
6868
from collections.abc import Callable, Iterable, Sequence
6969

7070
from pytools.tag import Tag
7171

72-
from pytato.function import FunctionDefinition
73-
7472

7573
__doc__ = """
7674
Helper routines
@@ -741,7 +739,7 @@ def get_einsum_specification(expr: Einsum) -> str:
741739
return f"{','.join(input_specs)}->{output_spec}"
742740

743741

744-
def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool:
742+
def is_materialized(expr: ArrayOrNamesOrFunctionDef) -> bool:
745743
"""Returns *True* if *expr* is materialized."""
746744
from pytato.array import InputArgumentBase
747745
from pytato.distributed.nodes import DistributedRecv
@@ -757,7 +755,7 @@ def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool:
757755
DistributedRecv)))
758756

759757

760-
def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool:
758+
def has_taggable_materialization(expr: ArrayOrNamesOrFunctionDef) -> bool:
761759
"""
762760
Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to
763761
determine whether or not it is materialized.

pytato/visualization/dot.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@
5959
)
6060
from pytato.function import Call, FunctionDefinition, NamedCallResult
6161
from pytato.tags import FunctionIdentifier
62-
from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer
62+
from pytato.transform import (
63+
ArrayOrNames,
64+
ArrayOrNamesOrFunctionDef,
65+
CachedMapper,
66+
InputGatherer,
67+
)
6368

6469

6570
if TYPE_CHECKING:
@@ -160,7 +165,7 @@ def emit_subgraph(sg: _SubgraphTree) -> None:
160165
class _DotNodeInfo:
161166
title: str
162167
fields: dict[str, Any]
163-
edges: dict[str, ArrayOrNames | FunctionDefinition]
168+
edges: dict[str, ArrayOrNamesOrFunctionDef]
164169

165170

166171
def stringify_tags(tags: frozenset[Tag | None]) -> str:
@@ -193,7 +198,7 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
193198
"non_equality_tags": expr.non_equality_tags,
194199
}
195200

196-
edges: dict[str, ArrayOrNames | FunctionDefinition] = {}
201+
edges: dict[str, ArrayOrNamesOrFunctionDef] = {}
197202
return _DotNodeInfo(title, fields, edges)
198203

199204
# type-ignore-reason: incompatible with supertype
@@ -297,7 +302,7 @@ def map_einsum(self, expr: Einsum) -> None:
297302
self.node_to_dot[expr] = info
298303

299304
def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
300-
edges: dict[str, ArrayOrNames | FunctionDefinition] = {}
305+
edges: dict[str, ArrayOrNamesOrFunctionDef] = {}
301306
for name, val in expr._data.items():
302307
edges[name] = val
303308
self.rec(val)
@@ -308,7 +313,7 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
308313
edges=edges)
309314

310315
def map_loopy_call(self, expr: LoopyCall) -> None:
311-
edges: dict[str, ArrayOrNames | FunctionDefinition] = {}
316+
edges: dict[str, ArrayOrNamesOrFunctionDef] = {}
312317
for name, arg in expr.bindings.items():
313318
if isinstance(arg, Array):
314319
edges[name] = arg

0 commit comments

Comments
 (0)