Add tiled GRPO lm-head loss #1735
Add an opt-in, memory-efficient GRPO loss path (`--use_liger_grpo_loss`) that follows DeepSpeed's TiledFusedLogitsLoss pattern: the lm-head projection and scalar loss are recomputed tile-by-tile so the full-vocabulary logits are never materialized, which is a large peak-memory win for big vocabularies / long context. The tiled kernel reproduces the existing DAPO/CISPO objective exactly (`masked_mean(pg + beta * kl, mask, None, loss_denominator) * (world_size // sp)`), so a run can toggle the flag without changing the objective. Default-off; the normal loss path is unchanged when disabled. The kernel math and gradients (w.r.t. hidden states and lm-head weight) are validated on CPU against a dense reference across DAPO/CISPO, with/without ref KL, multiple shard counts, and temperatures. The DeepSpeed ZeRO-3 integration (per-tile grad reduction via ds_grad_is_ready) still needs a GPU end-to-end run. Co-authored-by: Cursor <cursoragent@cursor.com>
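Why tiling leaves the objective unchanged can be illustrated with a small plain-Python sketch. It is not the PR's kernel: `masked_mean_sketch`, `dense_loss`, and `tiled_loss` are hypothetical stand-ins, per-token `pg` and `kl` values are given as plain lists, and the sketch assumes `loss_denominator` is a fixed number known before the tile loop starts.

```python
def masked_mean_sketch(values, mask, denominator=None):
    """Stand-in for masked_mean: masked sum divided by a fixed denominator
    (or by the number of unmasked tokens when no denominator is given)."""
    total = sum(v * m for v, m in zip(values, mask))
    denom = denominator if denominator is not None else sum(mask)
    return total / denom


def dense_loss(pg, kl, mask, beta, loss_denominator, world_size, sp):
    """Reference: build every per-token loss first, then reduce once."""
    per_token = [p + beta * k for p, k in zip(pg, kl)]
    return masked_mean_sketch(per_token, mask, loss_denominator) * (world_size // sp)


def tiled_loss(pg, kl, mask, beta, loss_denominator, world_size, sp, chunk_size):
    """Tiled: each tile contributes a partial sum that is already scaled by
    the global denominator, so the partial results simply add up."""
    scale = (world_size // sp) / loss_denominator
    total = 0.0
    for start in range(0, len(pg), chunk_size):
        stop = start + chunk_size
        tile_sum = sum(
            (p + beta * k) * m
            for p, k, m in zip(pg[start:stop], kl[start:stop], mask[start:stop])
        )
        total += tile_sum * scale
    return total
```

Because the denominator is global rather than per-tile, the result does not depend on `chunk_size` beyond floating-point rounding.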
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. Feedback on the changes highlights a correctness bug in the cispo objective where the clipping fraction (clipfrac) is incorrectly logged as 0.0 because pg_losses2 is set equal to pg_losses. Additionally, a redundant .clone() call on selected_token_ids was identified in grpo_fast.py since the subsequent torch.where operation already returns a new tensor.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
torch.where already returns a new tensor, so cloning selected_token_ids first is an unnecessary allocation/copy. Co-authored-by: Cursor <cursoragent@cursor.com>
The tiled path computed clipfrac as (pg_losses2 > pg_losses), which is always False for cispo (pg_losses2 == pg_losses) so clipfrac was wrongly logged as 0. Compute the clip fraction per objective: (pg2 > pg1) for dapo and (ratio > 1 + clip_higher) for cispo, matching the non-tiled path. Also avoid the redundant pg_losses2/kl_all allocations. Extend the unit test to assert clipfrac matches the dense reference (with old logprobs chosen so cispo clipfrac is non-zero). Also resolve ty type-check errors (cast backbone/lm_head, ignore the DeepSpeed runtime ds_grad_is_ready attribute) so `make quality` passes. Co-authored-by: Cursor <cursoragent@cursor.com>
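The clipfrac fix described above can be sketched in plain Python. This is an illustration only: the function names are hypothetical, the DAPO branch assumes the standard PPO-style pair `pg1 = -A * ratio` and `pg2 = -A * clamp(ratio, 1 - clip_lower, 1 + clip_higher)`, and tensors are replaced by lists.

```python
def clamp(x, lo, hi):
    return max(lo, min(hi, x))


def masked_fraction(flags, mask):
    """Fraction of unmasked positions whose flag is True."""
    kept = sum(mask)
    if kept == 0:
        return 0.0
    return sum(1.0 for f, m in zip(flags, mask) if m and f) / kept


def dapo_clipfrac(ratios, advantages, mask, clip_lower, clip_higher):
    """DAPO: a token counts as clipped when the clipped loss exceeds the unclipped one."""
    flags = []
    for r, a in zip(ratios, advantages):
        pg1 = -a * r
        pg2 = -a * clamp(r, 1.0 - clip_lower, 1.0 + clip_higher)
        flags.append(pg2 > pg1)
    return masked_fraction(flags, mask)


def cispo_clipfrac(ratios, mask, clip_higher):
    """CISPO: a token counts as clipped when its ratio exceeds the upper bound."""
    flags = [r > 1.0 + clip_higher for r in ratios]
    return masked_fraction(flags, mask)


def buggy_cispo_clipfrac(pg_losses, mask):
    """The original bug: pg_losses2 equals pg_losses, so the comparison is never True."""
    pg_losses2 = pg_losses
    flags = [p2 > p1 for p1, p2 in zip(pg_losses, pg_losses2)]
    return masked_fraction(flags, mask)
```

With ratios above `1 + clip_higher` on unmasked tokens, `cispo_clipfrac` is non-zero while `buggy_cispo_clipfrac` always returns 0, which is the situation the extended unit test is described as covering.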
/gemini review
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, thereby significantly reducing peak memory usage. Feedback on the implementation identifies a critical issue with DeepSpeed ZeRO-3 gradient synchronization in TiledGRPOLMHeadLoss.backward. Because the lm_head parameters are not returned in the custom backward function, their backward hooks are never triggered during the outer backward pass, which will prevent gradient reduction across GPUs and cause silent weight divergence. A detailed solution is provided to pass and return the parameter gradients explicitly.
/gemini review
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. The review feedback suggests several key improvements: adding a validation check to reject the unsupported combination of use_liger_grpo_loss and use_rho_correction, removing a performance-degrading torch.cuda.empty_cache() call in the training loop, retrieving lm_head parameters recursively to support PEFT adapters, and correcting # ty: ignore typos to standard # type: ignore comments.
Thread the per-token truncated-importance-sampling weights (rho = pi^train_old / pi^infer_old) through tiled_grpo_lm_head_loss: scale the policy loss per token and zero clipfrac for dropped tokens, matching the non-tiled compute_grpo_loss path. _train_sample_liger_grpo now computes rho via compute_rho_correction, logs its histograms/metrics, and passes the weights to the kernel. Extends the unit test (and standalone validation) to cover the rho path against a dense reference. Co-authored-by: Cursor <cursoragent@cursor.com>
Use lm_head.parameters() (recurse=True) so wrapped/PEFT heads (e.g. LoRA) have their trainable parameters included in the ZeRO-3 grad-ready bookkeeping. Co-authored-by: Cursor <cursoragent@cursor.com>
ty honors the standard # type: ignore[unresolved-attribute] form, which is more portable than the ty-specific spelling. Co-authored-by: Cursor <cursoragent@cursor.com>
/gemini review
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) to recompute the lm-head projection and loss tile-by-tile, reproducing the DAPO/CISPO objectives with significantly lower peak memory. The feedback highlights several important improvements: resolving an issue where PeftModel is not unwrapped (which would run the entire model and break training), optimizing DeepSpeed ZeRO-3 performance by gathering lm_head parameters outside the shard loop to avoid redundant communication overhead, removing a redundant requires_grad_ call on view tensors to prevent potential PyTorch errors, and parameterizing the hardcoded KL estimator index for better flexibility.
Thread kl_estimator through tiled_grpo_lm_head_loss (kl = kl_all[kl_estimator]) instead of hardcoding estimator 2, and drop the kl_estimator==2 restriction on use_liger_grpo_loss. Logs loss/kl_avg with the configured estimator. Extends the unit test (and standalone validation) to sweep estimators 0-3 against a dense reference. Co-authored-by: Cursor <cursoragent@cursor.com>
_unwrap_causal_lm now also descends through PEFT PeftModel wrappers (via get_base_model), so get_causal_lm_backbone_and_lm_head returns the real backbone and lm_head instead of treating the full causal LM as the backbone (which would materialize logits and break the tiled loss). Adds unit tests for plain, DeepSpeed-, PEFT-, and DeepSpeed+PEFT-wrapped models. Co-authored-by: Cursor <cursoragent@cursor.com>
/gemini review
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, reproducing the DAPO/CISPO objective with much lower peak memory. The changes include the core implementation of the tiled loss, helper functions for causal LM backbone extraction, and comprehensive unit tests. The review feedback highlights two performance and memory optimization opportunities: removing the torch.cuda.empty_cache() call inside the training loop to prevent CPU-GPU synchronization overhead, and performing temperature scaling in-place on the logits tensor to reduce memory allocation.
Mirrors single_gpu_on_beaker.sh but enables --use_liger_grpo_loss under ZeRO-3 (deepspeed_stage 3) so the tiled loss path can be smoke-tested on GPU. Co-authored-by: Cursor <cursoragent@cursor.com>
Switch the single-GPU GRPO debug test to ZeRO-3 and --use_liger_grpo_loss so the tiled lm-head loss path is exercised on GPU. Co-authored-by: Cursor <cursoragent@cursor.com>
/gemini review
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. The feedback highlights several excellent optimization and robustness opportunities: using PyTorch's native log_softmax for numerical stability, safeguarding shard division against zero-token edge cases, avoiding unnecessary gradient computations in backward when gradients are not required, and removing a performance-blocking torch.cuda.empty_cache() call from the inner training loop.
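Two of the robustness points in this review can be illustrated with a plain-Python sketch. It is not the PR's code: the function names are hypothetical, `stable_log_softmax` shows the max-subtraction (log-sum-exp) trick that fused log-softmax implementations typically rely on, and `num_shards` shows one possible guard for a zero-token batch.

```python
import math


def naive_log_softmax(logits):
    """log(softmax(x)) computed literally; math.exp overflows for large logits."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]


def stable_log_softmax(logits):
    """Subtract the max before exponentiating so every exponent is <= 0."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def num_shards(num_tokens, chunk_size):
    """Number of tiles for num_tokens tokens; never returns 0, so a later
    division by the shard count stays safe when the batch has no tokens."""
    return max(1, math.ceil(num_tokens / chunk_size))
```

For logits around 1000, `naive_log_softmax` raises `OverflowError` while `stable_log_softmax` returns finite values.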
/gemini review
Code Review
This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. It includes the implementation of the tiled loss path, helper functions for model unwrapping, comprehensive unit tests, and an updated debug script. Feedback highlights that under specific configurations (e.g., num_mini_batches == 1 and use_vllm_logprobs == False on the first epoch), the fallback branch still materializes full-vocabulary logits, which could cause unexpected memory spikes and defeat the purpose of the tiled loss.
In the single-minibatch, non-vLLM, first-epoch fallback, _train_sample_liger_grpo computed the current-policy old logprobs via forward_for_logprobs, which materializes [B, T, vocab] logits and defeats the tiled loss's memory savings. Compute them tile-by-tile from backbone hidden states via a new tiled_token_logprobs helper instead. Adds a unit test validating the helper against a dense log_softmax+gather reference. Co-authored-by: Cursor <cursoragent@cursor.com>
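The idea behind computing token logprobs tile-by-tile can be sketched in plain Python. This is an illustration, not the `tiled_token_logprobs` helper itself: all names below are hypothetical, hidden states and the lm-head weight are lists of lists, and "materializing" logits just means building the full `[n_tokens, vocab]` nested list at once.

```python
import math


def lm_head(hidden_rows, weight):
    """Project hidden states [n, d] with weight [vocab, d] into logits [n, vocab]."""
    return [[sum(h * w for h, w in zip(row, w_row)) for w_row in weight]
            for row in hidden_rows]


def token_logprob(logits_row, token_id, temperature):
    """log_softmax(logits / temperature)[token_id], using a stable log-sum-exp."""
    scaled = [x / temperature for x in logits_row]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return scaled[token_id] - lse


def dense_token_logprobs(hidden, weight, token_ids, temperature=1.0):
    """Reference: builds logits for every token before gathering."""
    logits = lm_head(hidden, weight)
    return [token_logprob(row, t, temperature) for row, t in zip(logits, token_ids)]


def tiled_token_logprobs_sketch(hidden, weight, token_ids, chunk_size, temperature=1.0):
    """Tiled: only a [chunk_size, vocab] slice of logits exists at any time."""
    out = []
    for start in range(0, len(hidden), chunk_size):
        stop = start + chunk_size
        tile_logits = lm_head(hidden[start:stop], weight)
        out.extend(
            token_logprob(row, t, temperature)
            for row, t in zip(tile_logits, token_ids[start:stop])
        )
    return out
```

Each token's logprob depends only on its own row of logits, so splitting along the token dimension changes peak memory but not the values, which is what a dense log_softmax+gather comparison checks.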
Summary
- Add an opt-in, memory-efficient GRPO loss path (`--use_liger_grpo_loss`, `--liger_grpo_loss_chunk_size`) that follows DeepSpeed's `TiledFusedLogitsLoss` pattern: the lm-head projection and scalar loss are recomputed tile-by-tile in a custom autograd `Function`, so the full-vocabulary logits are never materialized. This is a large peak-memory win for large vocabularies / long context.
- The tiled kernel reproduces the existing DAPO/CISPO objective exactly, `masked_mean(pg + beta * kl, mask, None, loss_denominator) * (world_size // sequence_parallel_size)`, so a run can toggle the flag on/off without changing the objective.
- Default-off: with `use_liger_grpo_loss=False` the normal loss path is completely unchanged. When on, a dedicated `_train_sample_liger_grpo` step runs the backbone (no lm-head) to get hidden states, then the tiled loss.
- `grpo_utils`: `tiled_grpo_lm_head_loss` / `TiledGRPOLMHeadLoss`, `get_causal_lm_backbone_and_lm_head`, `forward_for_liger_hidden_states`. Config validation restricts the path to `loss_fn in {dapo, cispo}`, `kl_estimator=2`, and `record_entropy=False`.

Scope
This is a focused port adapted to `main`'s current loss set: the source branch's kernel also embedded DPPO/TVPO, sequence-denominator, and freeze/policy-mask branches that don't exist on `main`; those are intentionally not included here.

Validation
- `open_instruct/test_grpo_utils_tiled_loss.py` asserts the tiled loss and gradients (w.r.t. hidden states and the lm-head weight) match a dense reference across DAPO/CISPO × {with, without} ref-KL × shards {1, 3, 100} × temperature {1.0, 0.7}. I separately reproduced this with a standalone script; all cases match to 1e-5.
- The DeepSpeed ZeRO-3 integration (per-tile grad reduction via `param.ds_grad_is_ready`) can only be exercised on GPU. Please do a GPU end-to-end smoke run before merging. I was unable to run GPU/DeepSpeed locally.

GPU_TESTS=bypass
Made with Cursor