Fix token-weighted loss under gradient accumulation in finetune.py#1736
Fix token-weighted loss under gradient accumulation in finetune.py#1736Chessing234 wants to merge 2 commits into
Conversation
When gradient_accumulation_steps > 1, scale each microbatch loss by its supervised token count and renormalize gradients so all tokens are weighted equally instead of averaging per-microbatch means. Fixes allenai#1728 Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Code Review
This pull request introduces token-weighted loss under gradient accumulation in finetune.py to replace the averaging of per-microbatch means. The review feedback identifies two critical issues with the proposed implementation: first, the gradient scaling loop is incompatible with DeepSpeed because parameter gradients are managed internally, which can lead to gradient explosion; second, using local token counts for scaling in a distributed (DDP) environment will cause gradient divergence across ranks. The reviewer suggests disabling this scaling when DeepSpeed is active and using accelerator.reduce to compute a global token count for correct gradient scaling across ranks.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if args.sequence_parallel_size == 1 and args.gradient_accumulation_steps > 1: | ||
| pred_tokens_f = pred_tokens_in_batch.float() | ||
| loss = loss * pred_tokens_f | ||
| accum_pred_tokens += pred_tokens_f |
There was a problem hiding this comment.
Critical Issue: DeepSpeed Compatibility & Gradient Explosion
Under DeepSpeed ZeRO-2/ZeRO-3, parameter gradients are partitioned and reduced on-the-fly during the backward pass, meaning param.grad is typically None (or managed internally) after accelerator.backward().
As a result, the gradient scaling loop at the end of the accumulation step (lines 857-859) is silently skipped. Since the loss is multiplied by pred_tokens_in_batch here, but never divided back by the accumulated tokens, the gradients will be ~1000x to 8000x larger than they should be, leading to immediate gradient explosion and NaN loss.
To prevent this, we should disable this token-weighted loss scaling when DeepSpeed is active and fall back to the standard behavior.
| if args.sequence_parallel_size == 1 and args.gradient_accumulation_steps > 1: | |
| pred_tokens_f = pred_tokens_in_batch.float() | |
| loss = loss * pred_tokens_f | |
| accum_pred_tokens += pred_tokens_f | |
| if ( | |
| args.sequence_parallel_size == 1 | |
| and args.gradient_accumulation_steps > 1 | |
| and accelerator.state.deepspeed_plugin is None | |
| ): | |
| pred_tokens_f = pred_tokens_in_batch.float() | |
| loss = loss * pred_tokens_f | |
| accum_pred_tokens += pred_tokens_f |
| if ( | ||
| accelerator.sync_gradients | ||
| and args.sequence_parallel_size == 1 | ||
| and args.gradient_accumulation_steps > 1 | ||
| ): | ||
| grad_scale = args.gradient_accumulation_steps / torch.clamp(accum_pred_tokens, min=1.0) | ||
| for param in model.parameters(): | ||
| if param.grad is not None: | ||
| param.grad.mul_(grad_scale) | ||
| accum_pred_tokens.zero_() |
There was a problem hiding this comment.
Critical Issue: DDP Gradient Divergence & Correct Scaling
There are two major issues with the current gradient scaling logic:
- DDP Gradient Divergence:
accum_pred_tokensis accumulated locally on each rank. Since different ranks process different numbers of tokens,accum_pred_tokenswill differ across ranks. Because DDP averages gradients across ranks duringaccelerator.backward(), multiplying the synchronized gradients by a localgrad_scalewill cause the gradients (and thus model weights) to immediately diverge across ranks, ruining training. - DeepSpeed Compatibility: As noted above, we must skip this scaling when DeepSpeed is active to avoid operating on
Nonegradients.
To fix the DDP divergence, we must compute the global sum of accumulated tokens across all ranks using accelerator.reduce(..., reduction="sum"), and then scale the averaged gradients by (gradient_accumulation_steps * num_processes) / global_accum_pred_tokens.
| if ( | |
| accelerator.sync_gradients | |
| and args.sequence_parallel_size == 1 | |
| and args.gradient_accumulation_steps > 1 | |
| ): | |
| grad_scale = args.gradient_accumulation_steps / torch.clamp(accum_pred_tokens, min=1.0) | |
| for param in model.parameters(): | |
| if param.grad is not None: | |
| param.grad.mul_(grad_scale) | |
| accum_pred_tokens.zero_() | |
| if ( | |
| accelerator.sync_gradients | |
| and args.sequence_parallel_size == 1 | |
| and args.gradient_accumulation_steps > 1 | |
| and accelerator.state.deepspeed_plugin is None | |
| ): | |
| global_accum_pred_tokens = accelerator.reduce(accum_pred_tokens, reduction="sum") | |
| grad_scale = (args.gradient_accumulation_steps * accelerator.num_processes) / torch.clamp(global_accum_pred_tokens, min=1.0) | |
| for param in model.parameters(): | |
| if param.grad is not None: | |
| param.grad.mul_(grad_scale) | |
| accum_pred_tokens.zero_() |
Summary
gradient_accumulation_steps > 1, scale each microbatch loss by its supervised token count and renormalize accumulated gradients so all tokens are weighted equally (matching the sequence-parallel path and the oldreduce_loss=sumbehavior).reduce_loss=sumremoved in #1024 without deprecation warning: LR miscalibration and mean-of-microbatch-means under gradient accumulation #1728.Fixes #1728
Test plan
gradient_accumulation_steps > 1and verify loss/gradients match token-weighted expectationgradient_accumulation_steps == 1GPU_TESTS=bypass
Made with Cursor