@@ -98,14 +98,15 @@ def __init_subclass__(cls) -> None:
 98   98
 99   99       @staticmethod
100  100       def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
101       -        # TODO: dict and set
102  101           if isinstance(o, (list, tuple)):
103  102               L = []
104  103               for item in o:
105  104                   L.append(LazyBase._recurse_apply(item, fn))
106  105               if isinstance(o, tuple):
107  106                   L = tuple(L)
108  107               return L
     108  +        elif isinstance(o, dict):
     109  +            return {k: LazyBase._recurse_apply(v, fn) for k, v in o.items()}
109  110           elif isinstance(o, LazyBase):
110  111               return fn(o)
111  112           else:
@@ -119,11 +120,11 @@ def wrapped_fn(*args, **kwargs):
119  120               args = ((use_self,) if use_self is not None else ()) + args
120  121
121  122               meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
122       -            # TODO: maybe handle tensors in kwargs too
     123  +            meta_kwargs = LazyBase._recurse_apply(kwargs, lambda t: t._meta)
123  124
124  125               if isinstance(meta_noop, bool) and not meta_noop:
125  126                   try:
126       -                    res = fn(*meta_args, **kwargs)
     127  +                    res = fn(*meta_args, **meta_kwargs)
127  128                   except NotImplementedError:
128  129                       # running some operations on PyTorch's Meta tensors can cause this exception
129  130                       res = None
@@ -159,7 +160,8 @@ def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
159  160                   # non-tensor return likely relies on the contents of the args
160  161                   # (e.g. the result of torch.equal)
161  162                   eager_args = cls.to_eager(args)
162       -                return fn(*eager_args, **kwargs)
     163  +                eager_kwargs = cls.to_eager(kwargs)
     164  +                return fn(*eager_args, **eager_kwargs)
163  165               return wrapped_fn
164  166
165  167       @classmethod
@@ -172,6 +174,7 @@ def simple_to_eager(_t: LazyBase) -> Any:
172  174
173  175               assert _t._func is not None
174  176               _t._args = cls._recurse_apply(_t._args, simple_to_eager)
     177  +            _t._kwargs = cls._recurse_apply(_t._kwargs, simple_to_eager)
175  178               _t._data = _t._func(*_t._args, **_t._kwargs)
176  179               # sanity check
177  180               assert _t._data is not None
0 commit comments