Skip to content

Commit 8382ac8

Browse files
committed
add transform_dag implementation to inline functions for pytato JAX array context
1 parent 83c44de commit 8382ac8

1 file changed

Lines changed: 14 additions & 0 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
817817
An arraycontext that uses :mod:`pytato` to represent the thawed state of
818818
the arrays and compiles the expressions using
819819
:class:`pytato.target.python.JAXPythonTarget`.
820+
821+
822+
.. automethod:: transform_dag
820823
"""
821824

822825
def __init__(self,
@@ -967,6 +970,17 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
967970
from .compile import LazilyJAXCompilingFunctionCaller
968971
return LazilyJAXCompilingFunctionCaller(self, f)
969972

973+
def transform_dag(self, dag: pytato.DictOfNamedArrays
974+
) -> pytato.DictOfNamedArrays:
975+
import pytato as pt
976+
977+
# FIXME: Having to use _verify_is_dag seems clunky, but I'm not sure how to
978+
# avoid it
979+
from .utils import _verify_is_dag
980+
dag = _verify_is_dag(pt.tag_all_calls_to_be_inlined(dag))
981+
dag = _verify_is_dag(pt.inline_calls(dag))
982+
return dag
983+
970984
def tag(self, tags: ToTagSetConvertible, array):
971985
def _tag(ary):
972986
import jax.numpy as jnp

0 commit comments

Comments
 (0)