Skip to content

fix: tensor dispatch with TP enabled#23

Merged
hann-wang merged 2 commits into
mainfrom
han/fix-tp-dispatch
Jun 2, 2026
Merged

fix: tensor dispatch with TP enabled#23
hann-wang merged 2 commits into
mainfrom
han/fix-tp-dispatch

Conversation

@hann-wang

@hann-wang hann-wang commented Jun 2, 2026

Copy link
Copy Markdown
Collaborator

When TP enabled, mm/addmm is dispatch via __torch_dispatch__ not __torch_function__. We have to manually call the func to make sure LPT kernels are correctly invoked.

Note: this is a temp fix!!! Autograd backward is not working in __torch_dispatch__.

Copilot AI review requested due to automatic review settings June 2, 2026 04:51

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to make the training weight wrapper tensor dispatch behave correctly when Tensor Parallelism (TP) is enabled, so that distributed weight movement and GEMM/grouped GEMM paths don’t accidentally drop the wrapper semantics needed for low-precision routing.

Changes:

  • Preserve the wrapper subclass across c10d.scatter_ (used by TP to distribute weights).
  • Special-case GEMM-like ops and _grouped_mm in __torch_dispatch__ to avoid the generic unwrap-and-call path.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread alto/kernels/dispatch/tensor.py Outdated
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings June 2, 2026 05:00
@hann-wang hann-wang merged commit 47b4886 into main Jun 2, 2026
@hann-wang hann-wang deleted the han/fix-tp-dispatch branch June 2, 2026 05:06

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 1 out of 1 changed files in this pull request and generated 3 comments.

Comment on lines +116 to +119
elif func.__name__ in gemm_ops or func.__name__ == "_grouped_mm":
# Delegate to the subclass' GEMM / grouped_mm overrides without
# unwrapping the wrapper tensor, avoiding __torch_dispatch__ recursion.
return cls.__torch_function__(func, types, args, kwargs or {})
Comment on lines +116 to +119
elif func.__name__ in gemm_ops or func.__name__ == "_grouped_mm":
# Delegate to the subclass' GEMM / grouped_mm overrides without
# unwrapping the wrapper tensor, avoiding __torch_dispatch__ recursion.
return cls.__torch_function__(func, types, args, kwargs or {})
Comment on lines +42 to +43
# required for TP - scatter_ is used to distribute weights
torch.ops.c10d.scatter_.default,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants