Add On-Policy Distillation (OPD) for OLMo-core GRPO#1740
Conversation
- forward_for_logprobs_and_topk: gather OPD student top-k logprobs from raw T=1 logits to match teacher raw_logprobs; keep GRPO sampled-token logprobs at the rollout temperature for the importance ratio. - opd_validation.py: hard-fail on invalid pure-OPD config (beta/ref policy) and on teacher/student tokenizer vocab-id mismatch; wired into grpo.py. - grpo.py: load the teacher vLLM scorer with the teacher's own tokenizer/revision instead of the student's. - test_olmo_core_train_modules.py: end-to-end OPD learner test on a real model (forward -> raw top-k gather -> forward_kl_topk -> backward + grads) and a temperature-invariance check; test_opd_validation.py covers the validators. - docs: OPD temperature notes + walkthrough page. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Documentation Changes Detected📄
|
There was a problem hiding this comment.
Code Review
This pull request implements On-Policy Distillation (OPD) for the OLMo-core GRPO stack, introducing a reusable teacher-scoring layer and a sparse forward-KL top-k distillation loss. It adds the distillkit module for sparse distillation math, implements the OPDTeacherScorerRayActor for online teacher scoring, and extends the data pipeline, collation, and sequence-parallel splitting to support 3D teacher tensors. The code review identified three important issues: a potential NaN loss and gradient propagation in the forward KL calculation due to non-finite student logprobs, a potential division-by-zero error in the teacher scoring orchestration if the responses list is empty, and a device mismatch during the student logprob gather if the teacher token IDs are not explicitly moved to the student logits' device.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| finite_teacher, safe_teacher_logprobs - torch.log(denom), torch.zeros_like(safe_teacher_logprobs) | ||
| ) | ||
|
|
||
| loss = (teacher_probs * (safe_teacher_logprobs - student_topk_logprobs.float())).sum(dim=-1) |
There was a problem hiding this comment.
If student_topk_logprobs contains non-finite values (such as -inf for zero-probability tokens or padding positions), the subtraction safe_teacher_logprobs - student_topk_logprobs can produce inf. When multiplied by teacher_probs (which is 0.0 where finite_teacher is False), this results in 0.0 * inf = nan, causing NaN loss and gradient propagation. To prevent this, mask student_topk_logprobs to be finite where finite_teacher is False before performing the subtraction.
| loss = (teacher_probs * (safe_teacher_logprobs - student_topk_logprobs.float())).sum(dim=-1) | |
| safe_student_logprobs = torch.where( | |
| finite_teacher, student_topk_logprobs.float(), torch.zeros_like(student_topk_logprobs.float()) | |
| ) | |
| loss = (teacher_probs * (safe_teacher_logprobs - safe_student_logprobs)).sum(dim=-1) |
| if not self.teacher_scorers: | ||
| raise ValueError("OPD teacher scoring requested but no teacher scorers are configured.") |
There was a problem hiding this comment.
If responses is empty, num_scorers will be 0, which leads to a ZeroDivisionError when calculating chunk_size. Add a guard clause to return early if responses is empty.
| if not self.teacher_scorers: | |
| raise ValueError("OPD teacher scoring requested but no teacher scorers are configured.") | |
| if not self.teacher_scorers: | |
| raise ValueError("OPD teacher scoring requested but no teacher scorers are configured.") | |
| if not responses: | |
| return [], [], {} |
| f"student_logits prefix shape {student_logits.shape[:-1]} does not match " | ||
| f"teacher_token_ids prefix shape {teacher_token_ids.shape[:-1]}" | ||
| ) | ||
| safe_token_ids = teacher_token_ids.clamp(min=0, max=student_logits.shape[-1] - 1) |
There was a problem hiding this comment.
The teacher_token_ids tensor might be on a different device (e.g., CPU) than student_logits (e.g., GPU). To prevent device mismatch errors during torch.gather, explicitly move teacher_token_ids to the device of student_logits before clamping.
| safe_token_ids = teacher_token_ids.clamp(min=0, max=student_logits.shape[-1] - 1) | |
| safe_token_ids = teacher_token_ids.to(student_logits.device).clamp(min=0, max=student_logits.shape[-1] - 1) |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: aad100afcd
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| prompt_logprobs=topk, | ||
| logprobs=topk, | ||
| flat_logprobs=True, | ||
| max_tokens=1, |
There was a problem hiding this comment.
Leave room for teacher scoring's dummy token
When a rollout reaches the configured limit (max_prompt_token_length + response_length), OPD sends the entire query+response as the vLLM prompt but this sampling config still asks vLLM to generate one more token. Since grpo.py passes the teacher max_model_len as exactly that same query+response limit, full-length/truncated rollouts have no remaining context for this dummy token and teacher scoring can reject or drop the request before producing prompt_logprobs, stalling OPD runs that keep length-finished samples.
Useful? React with 👍 / 👎.
| if not self.streaming_config.opd_use_task_rewards: | ||
| task_loss = torch.zeros_like(task_loss) | ||
| loss = masked_mean(task_loss, response_mask, None, loss_denominator) |
There was a problem hiding this comment.
Only suppress task loss when OPD is active
If a user supplies --opd_use_task_rewards false without also enabling OPD, the config is still accepted whenever a normal reward is enabled, but this branch zeros the GRPO loss and the OPD block below does not add any replacement loss. In that configuration training either backpropagates a loss with no gradient or silently performs no useful task update, so the guard should also require opd_enabled (or the config should reject this combination).
Useful? React with 👍 / 👎.
- forward_kl_topk_from_logprobs: zero the student term where the teacher entry is -inf so 0*inf can't produce NaN; add regression test. - _score_opd_teacher: return early on empty responses (avoid div-by-zero). - gather_student_logprobs_at_teacher_topk: move teacher ids to the student logits device before gather. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…g guard - grpo.py: teacher scorer max_model_len += 1 so vLLM's mandatory 1-token generation has room when a full-length query+response is scored (P1). - data_loader.py: reject opd_use_task_rewards=False unless opd_enabled=True, so the GRPO task loss is never silently zeroed with no OPD loss to replace it (P2). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ather helper - forward_for_logprobs_and_topk: compute the GRPO sampled-token logprob via the fused model_utils.log_softmax_and_gather (as the non-OPD path does) and the teacher top-k logprobs via gather(raw_logits) - logsumexp(raw_logits). Neither path now materializes a full-vocab [B,T,V] log-softmax tensor (the temperature split previously allocated two), and the GRPO logprob is back on the fused path. - distillkit: delete the unused, divergent gather_student_logprobs_at_teacher_topk helper (the live path inlines its own gather); drop it from __init__ exports and the walkthrough doc. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Summary
Adds On-Policy Distillation (OPD) to the OLMo-core GRPO stack as a reusable teacher-scoring + distillation-loss layer. The student keeps sampling its own rollouts; a frozen teacher scores the exact prefixes the student visited; the learner regresses the student toward the teacher's top-k distribution via a direct teacher-top-k forward KL. Works as GRPO + OPD or as pure OPD (task reward zeroed, online rollout infra retained). Disabled by default (
opd_enabled=False) — no behavior change to existing runs.Design doc:
docs/algorithms/opd_design.md(+opd_flow_diagram.html,opd_walkthrough.html).What's included
open_instruct/distillkit/): reusable sparse teacher signal +forward_kl_topk_from_logprobs, and vLLMprompt_logprobs→ dense[T, K]extraction (handles vLLM's first-token alignment).open_instruct/opd_utils.py):OPDTeacherScorerRayActor(a frozen vLLM engine that scores, not generates) + placement/creation.teacher_topk_token_ids/teacher_topk_logprobs[B, T, K]carried throughpack_sequences, collation,CollatedBatchData, and the Ulysses SP splitter (slicesT, keepsK).olmo_core_train_modules.py+grpo_utils.forward_for_logprobs_and_topk): gathers student top-k logprobs from the same forward used for GRPO (no second forward), adds the masked OPD loss.grpo.py,opd_validation.py): teacher scorer setup; hard-fail on invalid pure-OPD config and on teacher/student tokenizer vocab-id mismatch.Correctness decisions worth noting
raw_logprobs, so the objective is a coherentKL(teacher_raw ‖ student_raw).--beta 0.0 --load_ref_policy false.Testing
Unit: DistillKit loss/extraction, pack/collate, SP split of
[B,T,K], validators, and the raw-vs-tempered logprob split. New end-to-end OPD learner test on a real model (forward → raw top-k gather →forward_kl_topk→ backward + gradients) and a temperature-invariance check.Runs:
TestOPDLearnerLossEndToEnd(2 passed): BeakerRe-run after addressing Gemini + Codex review (OPD forward memory fix, dead-helper removal): full suite 54 passed, 1 skipped.
GPU_TESTS=01KW2VF2BRBH1H1NHH0Y13NCEV
🤖 Generated with Claude Code