Skip to content

Commit 97c3208

Browse files
committed
gguf-py : handle lazy tensors in kwargs
1 parent 5d3a4a7 commit 97c3208

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

gguf-py/gguf/lazy.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ def __init_subclass__(cls) -> None:
9898

9999
@staticmethod
100100
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
101-
# TODO: dict and set
102101
if isinstance(o, (list, tuple)):
103102
L = []
104103
for item in o:
105104
L.append(LazyBase._recurse_apply(item, fn))
106105
if isinstance(o, tuple):
107106
L = tuple(L)
108107
return L
108+
elif isinstance(o, dict):
109+
return {k: LazyBase._recurse_apply(v, fn) for k, v in o.items()}
109110
elif isinstance(o, LazyBase):
110111
return fn(o)
111112
else:
@@ -119,11 +120,11 @@ def wrapped_fn(*args, **kwargs):
119120
args = ((use_self,) if use_self is not None else ()) + args
120121

121122
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
122-
# TODO: maybe handle tensors in kwargs too
123+
meta_kwargs = LazyBase._recurse_apply(kwargs, lambda t: t._meta)
123124

124125
if isinstance(meta_noop, bool) and not meta_noop:
125126
try:
126-
res = fn(*meta_args, **kwargs)
127+
res = fn(*meta_args, **meta_kwargs)
127128
except NotImplementedError:
128129
# running some operations on PyTorch's Meta tensors can cause this exception
129130
res = None
@@ -159,7 +160,8 @@ def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
159160
# non-tensor return likely relies on the contents of the args
160161
# (e.g. the result of torch.equal)
161162
eager_args = cls.to_eager(args)
162-
return fn(*eager_args, **kwargs)
163+
eager_kwargs = cls.to_eager(kwargs)
164+
return fn(*eager_args, **eager_kwargs)
163165
return wrapped_fn
164166

165167
@classmethod
@@ -172,6 +174,7 @@ def simple_to_eager(_t: LazyBase) -> Any:
172174

173175
assert _t._func is not None
174176
_t._args = cls._recurse_apply(_t._args, simple_to_eager)
177+
_t._kwargs = cls._recurse_apply(_t._kwargs, simple_to_eager)
175178
_t._data = _t._func(*_t._args, **_t._kwargs)
176179
# sanity check
177180
assert _t._data is not None

0 commit comments

Comments
 (0)