Skip to content

Trim non-pipeline input grads before caching in bwd_cache#431

Open
aditvenk wants to merge 1 commit into
mainfrom
user/avenkataraman/pp-fix
Open

Trim non-pipeline input grads before caching in bwd_cache#431
aditvenk wants to merge 1 commit into
mainfrom
user/avenkataraman/pp-fix

Conversation

@aditvenk
Copy link
Copy Markdown
Contributor

The backward graph may produce gradients for inputs beyond the pipeline activations (e.g. labels when loss is fused into the last stage). get_bwd_send_ops zips bwd_cache with grad_send_info using strict=True, and grad_send_info only has entries for pipeline activation inputs, so extra grads cause a ValueError.

Mirror the trimming that upstream PipelineStage does at torch/distributed/pipelining/stage.py:997.

Authored with Claude.

The backward graph may produce gradients for inputs beyond the pipeline
activations (e.g. labels when loss is fused into the last stage).
get_bwd_send_ops zips bwd_cache with grad_send_info using strict=True,
and grad_send_info only has entries for pipeline activation inputs, so
extra grads cause a ValueError.

Mirror the trimming that upstream PipelineStage does at
torch/distributed/pipelining/stage.py:997.

Authored with Claude.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 22, 2026
@aditvenk aditvenk requested a review from xmfan April 22, 2026 03:25
@aditvenk aditvenk closed this Apr 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants