Fix token-weighted loss under gradient accumulation in finetune.py #1736

Open
Chessing234 wants to merge 2 commits into allenai:main from Chessing234:Chessing234/fix/token-weighted-loss-gradient-accumulation

Conversation

@Chessing234

Summary

Fixes #1728

Test plan

  • Run SFT with gradient_accumulation_steps > 1 and verify loss/gradients match token-weighted expectation
  • Confirm no change when gradient_accumulation_steps == 1

GPU_TESTS=bypass

Made with Cursor

Chessing234 and others added 2 commits June 24, 2026 17:11
When gradient_accumulation_steps > 1, scale each microbatch loss by its
supervised token count and renormalize gradients so all tokens are weighted
equally instead of averaging per-microbatch means.

Fixes allenai#1728

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
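
The difference the commit message describes can be illustrated with plain arithmetic. The following is a minimal sketch, not the code from finetune.py: the per-token loss values and both helper names are hypothetical, and it only contrasts averaging per-microbatch means with weighting every supervised token equally.

```python
# Hypothetical per-token losses for two microbatches of unequal length.
microbatches = [
    [2.0, 2.0],                      # 2 supervised tokens
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # 6 supervised tokens
]


def mean_of_microbatch_means(batches):
    """Average the per-microbatch mean losses (every microbatch counts equally)."""
    means = [sum(b) / len(b) for b in batches]
    return sum(means) / len(means)


def token_weighted_mean(batches):
    """Scale each microbatch mean by its token count, then renormalize by the total."""
    weighted_sum = sum((sum(b) / len(b)) * len(b) for b in batches)
    total_tokens = sum(len(b) for b in batches)
    return weighted_sum / max(total_tokens, 1)


print(mean_of_microbatch_means(microbatches))  # 1.5
print(token_weighted_mean(microbatches))       # 1.25
```

With a single microbatch the two quantities coincide, which is consistent with the test-plan item expecting no change when gradient_accumulation_steps == 1.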

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces token-weighted loss under gradient accumulation in finetune.py to replace the averaging of per-microbatch means. The review feedback identifies two critical issues with the proposed implementation: first, the gradient scaling loop is incompatible with DeepSpeed because parameter gradients are managed internally, which can lead to gradient explosion; second, using local token counts for scaling in a distributed (DDP) environment will cause gradient divergence across ranks. The reviewer suggests disabling this scaling when DeepSpeed is active and using accelerator.reduce to compute a global token count for correct gradient scaling across ranks.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/finetune.py
Comment on lines +826 to +829
if args.sequence_parallel_size == 1 and args.gradient_accumulation_steps > 1:
    pred_tokens_f = pred_tokens_in_batch.float()
    loss = loss * pred_tokens_f
    accum_pred_tokens += pred_tokens_f

critical

Critical Issue: DeepSpeed Compatibility & Gradient Explosion

Under DeepSpeed ZeRO-2/ZeRO-3, parameter gradients are partitioned and reduced on-the-fly during the backward pass, meaning param.grad is typically None (or managed internally) after accelerator.backward().

As a result, the gradient scaling loop at the end of the accumulation step (lines 857-859) is silently skipped. Since the loss is multiplied by pred_tokens_in_batch here, but never divided back by the accumulated tokens, the gradients will be ~1000x to 8000x larger than they should be, leading to immediate gradient explosion and NaN loss.

To prevent this, we should disable this token-weighted loss scaling when DeepSpeed is active and fall back to the standard behavior.
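
The failure mode can be sketched without DeepSpeed at all. The example below is a hypothetical stand-in: FakeParam and rescale_grads are illustration-only names, and it assumes, as the comment above states, that the gradients are not visible through .grad at the point where the loop runs. It shows only that a loop guarded by `is not None` does nothing in that case, so the multiplication by the token count is never undone.

```python
class FakeParam:
    """Hypothetical stand-in for a parameter; only the .grad attribute matters."""

    def __init__(self, grad):
        self.grad = grad


def rescale_grads(params, grad_scale):
    """Mirror the shape of the PR's loop: scale every gradient that is not None."""
    rescaled = 0
    for p in params:
        if p.grad is not None:
            p.grad *= grad_scale
            rescaled += 1
    return rescaled


# Gradients inflated by a hypothetical token count of 4096.
visible = [FakeParam(4096.0), FakeParam(-8192.0)]
# Assumed situation under ZeRO: .grad is None when the loop runs.
hidden = [FakeParam(None), FakeParam(None)]

grad_scale = 1.0 / 4096.0
n_visible = rescale_grads(visible, grad_scale)
n_hidden = rescale_grads(hidden, grad_scale)

print(n_visible, [p.grad for p in visible])  # 2 [1.0, -2.0]
print(n_hidden)                              # 0, nothing was renormalized
```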

Suggested change
- if args.sequence_parallel_size == 1 and args.gradient_accumulation_steps > 1:
-     pred_tokens_f = pred_tokens_in_batch.float()
-     loss = loss * pred_tokens_f
-     accum_pred_tokens += pred_tokens_f
+ if (
+     args.sequence_parallel_size == 1
+     and args.gradient_accumulation_steps > 1
+     and accelerator.state.deepspeed_plugin is None
+ ):
+     pred_tokens_f = pred_tokens_in_batch.float()
+     loss = loss * pred_tokens_f
+     accum_pred_tokens += pred_tokens_f

Comment thread open_instruct/finetune.py
Comment on lines +851 to +860
if (
    accelerator.sync_gradients
    and args.sequence_parallel_size == 1
    and args.gradient_accumulation_steps > 1
):
    grad_scale = args.gradient_accumulation_steps / torch.clamp(accum_pred_tokens, min=1.0)
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(grad_scale)
    accum_pred_tokens.zero_()

critical

Critical Issue: DDP Gradient Divergence & Correct Scaling

There are two major issues with the current gradient scaling logic:

  1. DDP Gradient Divergence: accum_pred_tokens is accumulated locally on each rank. Since different ranks process different numbers of tokens, accum_pred_tokens will differ across ranks. Because DDP averages gradients across ranks during accelerator.backward(), multiplying the synchronized gradients by a local grad_scale will cause the gradients (and thus model weights) to immediately diverge across ranks, ruining training.
  2. DeepSpeed Compatibility: As noted above, we must skip this scaling when DeepSpeed is active to avoid operating on None gradients.

To fix the DDP divergence, we must compute the global sum of accumulated tokens across all ranks using accelerator.reduce(..., reduction="sum"), and then scale the averaged gradients by (gradient_accumulation_steps * num_processes) / global_accum_pred_tokens.
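
The proposed scale factor can be checked with a small pure-Python simulation. This is a sketch under stated assumptions, not the training loop itself: gradients are reduced to scalars, the per-token values are hypothetical, backward is assumed to divide each microbatch loss by gradient_accumulation_steps, and DDP is assumed to average the accumulated gradients across ranks. The helper accumulated_grad is an illustration-only name.

```python
G = 2  # gradient_accumulation_steps

# per_rank_microbatches[rank][microbatch] = hypothetical per-token gradient values
per_rank_microbatches = [
    [[2.0, 2.0], [2.0, 2.0]],  # rank 0: 4 supervised tokens in total
    [[1.0] * 6, [1.0] * 6],    # rank 1: 12 supervised tokens in total
]
W = len(per_rank_microbatches)  # num_processes


def accumulated_grad(microbatches, accum_steps):
    """Accumulate (mean * token_count) / accum_steps over the microbatches of one rank."""
    total = 0.0
    for mb in microbatches:
        mean_grad = sum(mb) / len(mb)
        total += (mean_grad * len(mb)) / accum_steps
    return total


local_grads = [accumulated_grad(mbs, G) for mbs in per_rank_microbatches]
local_tokens = [sum(len(mb) for mb in mbs) for mbs in per_rank_microbatches]

# Assumed DDP behavior: every rank ends up with the average of the local gradients.
synced_grad = sum(local_grads) / W

# Scaling by the global token count gives one value shared by all ranks.
global_tokens = sum(local_tokens)
global_scaled = synced_grad * (G * W) / max(global_tokens, 1)

# Scaling by the local token count gives a different value on each rank.
local_scaled = [synced_grad * G / max(n, 1) for n in local_tokens]

# Reference: plain mean over every supervised token on every rank.
all_tokens = [g for mbs in per_rank_microbatches for mb in mbs for g in mb]
reference = sum(all_tokens) / len(all_tokens)

print(global_scaled, reference)  # 1.25 1.25
print(local_scaled)              # the two ranks disagree
```

In this toy setup the global scale reproduces the mean over all supervised tokens, while the local scale leaves rank 0 and rank 1 with different gradients, which is the divergence described above.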

Suggested change
- if (
-     accelerator.sync_gradients
-     and args.sequence_parallel_size == 1
-     and args.gradient_accumulation_steps > 1
- ):
-     grad_scale = args.gradient_accumulation_steps / torch.clamp(accum_pred_tokens, min=1.0)
-     for param in model.parameters():
-         if param.grad is not None:
-             param.grad.mul_(grad_scale)
-     accum_pred_tokens.zero_()
+ if (
+     accelerator.sync_gradients
+     and args.sequence_parallel_size == 1
+     and args.gradient_accumulation_steps > 1
+     and accelerator.state.deepspeed_plugin is None
+ ):
+     global_accum_pred_tokens = accelerator.reduce(accum_pred_tokens, reduction="sum")
+     grad_scale = (args.gradient_accumulation_steps * accelerator.num_processes) / torch.clamp(global_accum_pred_tokens, min=1.0)
+     for param in model.parameters():
+         if param.grad is not None:
+             param.grad.mul_(grad_scale)
+     accum_pred_tokens.zero_()


Development

Successfully merging this pull request may close these issues.

reduce_loss=sum removed in #1024 without deprecation warning: LR miscalibration and mean-of-microbatch-means under gradient accumulation
