diff --git a/alto/kernels/dispatch/tensor.py b/alto/kernels/dispatch/tensor.py index de96b9f..a13d8a5 100644 --- a/alto/kernels/dispatch/tensor.py +++ b/alto/kernels/dispatch/tensor.py @@ -39,6 +39,8 @@ torch.ops.aten.clone.default, torch.ops.aten.transpose.int, torch.ops.aten.t.default, + # required for TP - scatter_ is used to distribute weights + torch.ops.c10d.scatter_.default, } gemm_ops = ("linear", "mm.default", "matmul.default", "addmm.default", "matmul") @@ -111,6 +113,10 @@ def unwrap(t): # detach is special case if func == torch.ops.aten.detach.default: return cls(args_unwrapped[0], config) + elif func.__name__ in gemm_ops or func.__name__ == "_grouped_mm": + # Delegate to the subclass' GEMM / grouped_mm overrides without + # unwrapping the wrapper tensor, avoiding __torch_dispatch__ recursion. + return cls.__torch_function__(func, types, args, kwargs or {}) # perform op out = func(*args_unwrapped, **kwargs_unwrapped)