Skip to content

Add tiled GRPO lm-head loss#1735

Open
hamishivi wants to merge 13 commits into
mainfrom
hamishivi/tiled-grpo-lm-head-loss
Open

Add tiled GRPO lm-head loss#1735
hamishivi wants to merge 13 commits into
mainfrom
hamishivi/tiled-grpo-lm-head-loss

Conversation

@hamishivi

Copy link
Copy Markdown
Collaborator

Summary

  • Add an opt-in, memory-efficient GRPO loss path (--use_liger_grpo_loss, --liger_grpo_loss_chunk_size) that follows DeepSpeed's TiledFusedLogitsLoss pattern: the lm-head projection and scalar loss are recomputed tile-by-tile in a custom autograd Function, so the full-vocabulary logits are never materialized. This is a large peak-memory win for large vocabularies / long context.
  • The tiled kernel reproduces the existing DAPO/CISPO objective exactly — masked_mean(pg + beta * kl, mask, None, loss_denominator) * (world_size // sequence_parallel_size) — so a run can toggle the flag on/off without changing the objective.
  • Default-off: when use_liger_grpo_loss=False the normal loss path is completely unchanged. When on, a dedicated _train_sample_liger_grpo step runs the backbone (no lm-head) to get hidden states, then the tiled loss.
  • New helpers in grpo_utils: tiled_grpo_lm_head_loss/TiledGRPOLMHeadLoss, get_causal_lm_backbone_and_lm_head, forward_for_liger_hidden_states. Config validation restricts the path to loss_fn in {dapo, cispo}, kl_estimator=2, and record_entropy=False.

Scope

This is a focused port adapted to main's current loss set: the source branch's kernel also embedded DPPO/TVPO, sequence-denominator, and freeze/policy-mask branches that don't exist on main; those are intentionally not included here.

Validation

  • Correctness (CPU): new open_instruct/test_grpo_utils_tiled_loss.py asserts the tiled loss and gradients (w.r.t. hidden states and the lm-head weight) match a dense reference across DAPO/CISPO × {with, without} ref-KL × shards {1, 3, 100} × temperature {1.0, 0.7}. I separately reproduced this with a standalone script — all cases match to 1e-5.
  • Still needs a GPU run: the DeepSpeed ZeRO-3 integration (per-tile gradient reduction via param.ds_grad_is_ready) can only be exercised on GPU. Please do a GPU end-to-end smoke run before merging. I was unable to run GPU/DeepSpeed locally.

GPU_TESTS=bypass

Made with Cursor

hamishivi and others added 2 commits June 23, 2026 14:09
Add an opt-in, memory-efficient GRPO loss path (`--use_liger_grpo_loss`) that
follows DeepSpeed's TiledFusedLogitsLoss pattern: the lm-head projection and
scalar loss are recomputed tile-by-tile so the full-vocabulary logits are never
materialized, which is a large peak-memory win for big vocabularies / long
context.

The tiled kernel reproduces the existing DAPO/CISPO objective exactly
(`masked_mean(pg + beta * kl, mask, None, loss_denominator) * (world_size // sp)`),
so a run can toggle the flag without changing the objective. Default-off; the
normal loss path is unchanged when disabled.

The kernel math and gradients (w.r.t. hidden states and lm-head weight) are
validated on CPU against a dense reference across DAPO/CISPO, with/without ref
KL, multiple shard counts, and temperatures. The DeepSpeed ZeRO-3 integration
(per-tile grad reduction via ds_grad_is_ready) still needs a GPU end-to-end run.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. Feedback on the changes highlights a correctness bug in the cispo objective where the clipping fraction (clipfrac) is incorrectly logged as 0.0 because pg_losses2 is set equal to pg_losses. Additionally, a redundant .clone() call on selected_token_ids was identified in grpo_fast.py since the subsequent torch.where operation already returns a new tensor.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_fast.py Outdated
hamishivi and others added 3 commits June 23, 2026 14:32
torch.where already returns a new tensor, so cloning selected_token_ids first is
an unnecessary allocation/copy.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
The tiled path computed clipfrac as (pg_losses2 > pg_losses), which is always
False for cispo (pg_losses2 == pg_losses) so clipfrac was wrongly logged as 0.
Compute the clip fraction per objective: (pg2 > pg1) for dapo and
(ratio > 1 + clip_higher) for cispo, matching the non-tiled path. Also avoid the
redundant pg_losses2/kl_all allocations. Extend the unit test to assert clipfrac
matches the dense reference (with old logprobs chosen so cispo clipfrac is non-zero).

Also resolve ty type-check errors (cast backbone/lm_head, ignore the DeepSpeed
runtime ds_grad_is_ready attribute) so `make quality` passes.

Co-authored-by: Cursor <cursoragent@cursor.com>
@hamishivi

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, thereby significantly reducing peak memory usage. Feedback on the implementation identifies a critical issue with DeepSpeed ZeRO-3 gradient synchronization in TiledGRPOLMHeadLoss.backward. Because the lm_head parameters are not returned in the custom backward function, their backward hooks are never triggered during the outer backward pass, which will prevent gradient reduction across GPUs and cause silent weight divergence. A detailed solution is provided to pass and return the parameter gradients explicitly.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_utils.py Outdated
@hamishivi

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. The review feedback suggests several key improvements: adding a validation check to reject the unsupported combination of use_liger_grpo_loss and use_rho_correction, removing a performance-degrading torch.cuda.empty_cache() call in the training loop, retrieving lm_head parameters recursively to support PEFT adapters, and correcting # ty: ignore typos to standard # type: ignore comments.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_utils.py Outdated
Comment thread open_instruct/grpo_fast.py
Comment thread open_instruct/grpo_utils.py Outdated
Comment thread open_instruct/grpo_utils.py Outdated
hamishivi and others added 3 commits June 23, 2026 15:16
Thread the per-token truncated-importance-sampling weights (rho = pi^train_old /
pi^infer_old) through tiled_grpo_lm_head_loss: scale the policy loss per token and
zero clipfrac for dropped tokens, matching the non-tiled compute_grpo_loss path.
_train_sample_liger_grpo now computes rho via compute_rho_correction, logs its
histograms/metrics, and passes the weights to the kernel. Extends the unit test
(and standalone validation) to cover the rho path against a dense reference.

Co-authored-by: Cursor <cursoragent@cursor.com>
Use lm_head.parameters() (recurse=True) so wrapped/PEFT heads (e.g. LoRA) have
their trainable parameters included in the ZeRO-3 grad-ready bookkeeping.

Co-authored-by: Cursor <cursoragent@cursor.com>
ty honors the standard # type: ignore[unresolved-attribute] form, which is more
portable than the ty-specific spelling.

Co-authored-by: Cursor <cursoragent@cursor.com>
@hamishivi

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) to recompute the lm-head projection and loss tile-by-tile, reproducing the DAPO/CISPO objectives with significantly lower peak memory. The feedback highlights several important improvements: resolving an issue where PeftModel is not unwrapped (which would run the entire model and break training), optimizing DeepSpeed ZeRO-3 performance by gathering lm_head parameters outside the shard loop to avoid redundant communication overhead, removing a redundant requires_grad_ call on view tensors to prevent potential PyTorch errors, and parameterizing the hardcoded KL estimator index for better flexibility.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_utils.py Outdated
hamishivi and others added 2 commits June 23, 2026 15:49
Thread kl_estimator through tiled_grpo_lm_head_loss (kl = kl_all[kl_estimator])
instead of hardcoding estimator 2, and drop the kl_estimator==2 restriction on
use_liger_grpo_loss. Logs loss/kl_avg with the configured estimator. Extends the
unit test (and standalone validation) to sweep estimators 0-3 against a dense
reference.

Co-authored-by: Cursor <cursoragent@cursor.com>
_unwrap_causal_lm now also descends through PEFT PeftModel wrappers (via
get_base_model), so get_causal_lm_backbone_and_lm_head returns the real backbone
and lm_head instead of treating the full causal LM as the backbone (which would
materialize logits and break the tiled loss). Adds unit tests for plain,
DeepSpeed-, PEFT-, and DeepSpeed+PEFT-wrapped models.

Co-authored-by: Cursor <cursoragent@cursor.com>
@hamishivi

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, reproducing the DAPO/CISPO objective with much lower peak memory. The changes include the core implementation of the tiled loss, helper functions for causal LM backbone extraction, and comprehensive unit tests. The review feedback highlights two performance and memory optimization opportunities: removing the torch.cuda.empty_cache() call inside the training loop to prevent CPU-GPU synchronization overhead, and performing temperature scaling in-place on the logits tensor to reduce memory allocation.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_fast.py
Comment thread open_instruct/grpo_utils.py
hamishivi and others added 2 commits June 23, 2026 16:07
Mirrors single_gpu_on_beaker.sh but enables --use_liger_grpo_loss under ZeRO-3
(deepspeed_stage 3) so the tiled loss path can be smoke-tested on GPU.

Co-authored-by: Cursor <cursoragent@cursor.com>
Switch the single-GPU GRPO debug test to ZeRO-3 and --use_liger_grpo_loss so the
tiled lm-head loss path is exercised on GPU.

Co-authored-by: Cursor <cursoragent@cursor.com>
@hamishivi

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. The feedback highlights several excellent optimization and robustness opportunities: using PyTorch's native log_softmax for numerical stability, safeguarding shard division against zero-token edge cases, avoiding unnecessary gradient computations in backward when gradients are not required, and removing a performance-blocking torch.cuda.empty_cache() call from the inner training loop.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_utils.py
Comment thread open_instruct/grpo_fast.py
@hamishivi

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in tiled GRPO lm-head loss (--use_liger_grpo_loss) that recomputes the lm-head projection and loss tile-by-tile to avoid materializing full-vocabulary logits, significantly reducing peak memory usage. It includes the implementation of the tiled loss path, helper functions for model unwrapping, comprehensive unit tests, and an updated debug script. Feedback highlights that under specific configurations (e.g., num_mini_batches == 1 and use_vllm_logprobs == False on the first epoch), the fallback branch still materializes full-vocabulary logits, which could cause unexpected memory spikes and defeat the purpose of the tiled loss.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread open_instruct/grpo_fast.py
In the single-minibatch, non-vLLM, first-epoch fallback, _train_sample_liger_grpo
computed the current-policy old logprobs via forward_for_logprobs, which
materializes [B, T, vocab] logits and defeats the tiled loss's memory savings.
Compute them tile-by-tile from backbone hidden states via a new
tiled_token_logprobs helper instead. Adds a unit test validating the helper
against a dense log_softmax+gather reference.

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant