From 524deed5a6ca4f2be6ee01b9b35b50b4a09e2dc6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 2 Jun 2026 04:45:56 +0000 Subject: [PATCH 1/2] fix: tensor dispatch with TP enabled --- alto/kernels/dispatch/tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/alto/kernels/dispatch/tensor.py b/alto/kernels/dispatch/tensor.py index de96b9f..3f7c91e 100644 --- a/alto/kernels/dispatch/tensor.py +++ b/alto/kernels/dispatch/tensor.py @@ -39,6 +39,8 @@ torch.ops.aten.clone.default, torch.ops.aten.transpose.int, torch.ops.aten.t.default, + # required for TP - scatter_ is used to distribute weights + torch.ops.c10d.scatter_.default, } gemm_ops = ("linear", "mm.default", "matmul.default", "addmm.default", "matmul") @@ -111,6 +113,8 @@ def unwrap(t): # detach is special case if func == torch.ops.aten.detach.default: return cls(args_unwrapped[0], config) + elif func.__name__ in gemm_ops or func.__name__ == "_grouped_mm": + return func(*args, **kwargs) # perform op out = func(*args_unwrapped, **kwargs_unwrapped) From 1891b5b104929c321902e8b6f52a05fbf5bf001b Mon Sep 17 00:00:00 2001 From: Hann Wang Date: Tue, 2 Jun 2026 13:00:41 +0800 Subject: [PATCH 2/2] Delegate to the subclass' function Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- alto/kernels/dispatch/tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/alto/kernels/dispatch/tensor.py b/alto/kernels/dispatch/tensor.py index 3f7c91e..a13d8a5 100644 --- a/alto/kernels/dispatch/tensor.py +++ b/alto/kernels/dispatch/tensor.py @@ -114,7 +114,9 @@ def unwrap(t): if func == torch.ops.aten.detach.default: return cls(args_unwrapped[0], config) elif func.__name__ in gemm_ops or func.__name__ == "_grouped_mm": - return func(*args, **kwargs) + # Delegate to the subclass' GEMM / grouped_mm overrides without + # unwrapping the wrapper tensor, avoiding __torch_dispatch__ recursion. + return cls.__torch_function__(func, types, args, kwargs or {}) # perform op out = func(*args_unwrapped, **kwargs_unwrapped)