Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions alto/kernels/dispatch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
torch.ops.aten.clone.default,
torch.ops.aten.transpose.int,
torch.ops.aten.t.default,
# Required for tensor parallelism (TP): scatter_ is used to distribute weights across ranks.
torch.ops.c10d.scatter_.default,
Comment on lines +42 to +43
}

gemm_ops = ("linear", "mm.default", "matmul.default", "addmm.default", "matmul")
Expand Down Expand Up @@ -111,6 +113,10 @@ def unwrap(t):
# detach is special case
if func == torch.ops.aten.detach.default:
return cls(args_unwrapped[0], config)
elif func.__name__ in gemm_ops or func.__name__ == "_grouped_mm":
# Delegate to the subclass's GEMM / grouped_mm overrides without
# unwrapping the wrapper tensor, which avoids __torch_dispatch__ recursion.
return cls.__torch_function__(func, types, args, kwargs or {})
Comment on lines +116 to +119
Comment on lines +116 to +119

# perform op
out = func(*args_unwrapped, **kwargs_unwrapped)
Expand Down