Add On-Policy Distillation (OPD) for OLMo-core GRPO #1740

Open
farhatkevin wants to merge 15 commits into main from codex/opd-olmo-core-grpo

@farhatkevin farhatkevin commented Jun 26, 2026

Collaborator

Summary

Adds On-Policy Distillation (OPD) to the OLMo-core GRPO stack as a reusable teacher-scoring + distillation-loss layer. The student keeps sampling its own rollouts; a frozen teacher scores the exact prefixes the student visited; the learner regresses the student toward the teacher's top-k distribution via a direct teacher-top-k forward KL. Works as GRPO + OPD or as pure OPD (task reward zeroed, online rollout infra retained). Disabled by default (opd_enabled=False) — no behavior change to existing runs.

Design doc: docs/algorithms/opd_design.md (+ opd_flow_diagram.html, opd_walkthrough.html).

What's included

  • DistillKit (open_instruct/distillkit/): reusable sparse teacher signal + forward_kl_topk_from_logprobs, and vLLM prompt_logprobs → dense [T, K] extraction (handles vLLM's first-token alignment).
  • Teacher scorer (open_instruct/opd_utils.py): OPDTeacherScorerRayActor (a frozen vLLM engine that scores, not generates) + placement/creation.
  • Batch plumbing: optional teacher_topk_token_ids/teacher_topk_logprobs [B, T, K] carried through pack_sequences, collation, CollatedBatchData, and the Ulysses SP splitter (slices T, keeps K).
  • Learner (olmo_core_train_modules.py + grpo_utils.forward_for_logprobs_and_topk): gathers student top-k logprobs from the same forward used for GRPO (no second forward), adds the masked OPD loss.
  • Config/validation (grpo.py, opd_validation.py): teacher scorer setup; hard-fail on invalid pure-OPD config and on teacher/student tokenizer vocab-id mismatch.
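
As a rough illustration of the prompt_logprobs → dense [T, K] step listed above, here is a minimal sketch. It assumes a simplified stand-in for the vLLM output: a list whose first entry is None (the first token has no preceding context to be scored against) and whose remaining entries map token id to a plain float logprob. The function name, the -1 padding id, and the -inf padding logprob are assumptions for illustration, not the DistillKit API.

```python
import math

PAD_ID = -1               # assumed padding id for empty top-k slots
PAD_LOGPROB = -math.inf   # assumed padding logprob (zero probability)


def dense_topk_from_prompt_logprobs(prompt_logprobs, k):
    """Convert per-position {token_id: logprob} dicts into dense [T, K] lists.

    prompt_logprobs[0] is expected to be None; that row is filled with padding
    so that every position still has exactly k slots.
    """
    token_ids, logprobs = [], []
    for entry in prompt_logprobs:
        # Keep the k most likely tokens, highest logprob first.
        ranked = sorted((entry or {}).items(), key=lambda kv: kv[1], reverse=True)[:k]
        pad = k - len(ranked)
        token_ids.append([tok for tok, _ in ranked] + [PAD_ID] * pad)
        logprobs.append([lp for _, lp in ranked] + [PAD_LOGPROB] * pad)
    return token_ids, logprobs


# Three positions, K=2: the first row is all padding, the rest are truncated to top-2.
ids, lps = dense_topk_from_prompt_logprobs(
    [None, {11: -0.1, 7: -2.5, 3: -4.0}, {5: -0.7, 9: -0.9}], k=2
)
```

Padding with -inf keeps the padded slots at zero teacher probability, which is the case the loss has to mask explicitly.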

Correctness decisions worth noting

  • Temperature: GRPO sampled-token logprobs use the rollout temperature (importance ratio); OPD student top-k logprobs use raw T=1 logits to match the teacher's raw_logprobs, so the objective is a coherent KL(teacher_raw ‖ student_raw).
  • Teacher tokenizer: the teacher vLLM is loaded with the teacher's own tokenizer (not the student's); scoring runs on pre-tokenized ids, with id-level vocab compatibility validated at startup.
  • Pure OPD: config validation hard-fails unless the run sets --beta 0.0 --load_ref_policy false.
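
For readers who want the objective written out, one plausible form of the masked teacher-top-k forward KL is sketched below. The notation is mine, not the design doc's: $v_{t,k}$ is the teacher's k-th top token at position $t$, $m_t$ is the response mask, $p_T$ and $p_S$ are the raw (T=1) teacher and student distributions, and $\tilde{p}_T$ is the teacher distribution renormalized over its top-k entries. The renormalization and the plain masked-mean normalizer are assumptions; the actual loss denominator in the learner may differ.

```latex
\mathcal{L}_{\mathrm{OPD}}
  = \frac{1}{\sum_{t} m_t} \sum_{t} m_t \sum_{k=1}^{K}
    \tilde{p}_{T}(v_{t,k} \mid x_{<t})
    \left[ \log \tilde{p}_{T}(v_{t,k} \mid x_{<t}) - \log p_{S}(v_{t,k} \mid x_{<t}) \right],
\qquad
\tilde{p}_{T}(v_{t,k} \mid x_{<t})
  = \frac{p_{T}(v_{t,k} \mid x_{<t})}{\sum_{j=1}^{K} p_{T}(v_{t,j} \mid x_{<t})}
```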

Testing

Unit: DistillKit loss/extraction, pack/collate, SP split of [B,T,K], validators, and the raw-vs-tempered logprob split. New end-to-end OPD learner test on a real model (forward → raw top-k gather → forward_kl_topk → backward + gradients) and a temperature-invariance check.

Runs:

  1. Full GPU test suite (54 passed, 1 skipped): Beaker
  2. New end-to-end OPD learner test TestOPDLearnerLossEndToEnd (2 passed): Beaker

Re-run after addressing Gemini + Codex review (OPD forward memory fix, dead-helper removal): full suite 54 passed, 1 skipped.

GPU_TESTS=01KW2VF2BRBH1H1NHH0Y13NCEV

🤖 Generated with Claude Code

farhatkevin and others added 12 commits June 22, 2026 00:02
- forward_for_logprobs_and_topk: gather OPD student top-k logprobs from raw T=1
  logits to match teacher raw_logprobs; keep GRPO sampled-token logprobs at the
  rollout temperature for the importance ratio.
- opd_validation.py: hard-fail on invalid pure-OPD config (beta/ref policy) and on
  teacher/student tokenizer vocab-id mismatch; wired into grpo.py.
- grpo.py: load the teacher vLLM scorer with the teacher's own tokenizer/revision
  instead of the student's.
- test_olmo_core_train_modules.py: end-to-end OPD learner test on a real model
  (forward -> raw top-k gather -> forward_kl_topk -> backward + grads) and a
  temperature-invariance check; test_opd_validation.py covers the validators.
- docs: OPD temperature notes + walkthrough page.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@github-actions

Contributor

Documentation Changes Detected

📄 sitemap.xml
--- site-base/sitemap.xml	2026-06-26 18:16:58.884529105 +0000
+++ site-pr/sitemap.xml	2026-06-26 18:16:55.993999023 +0000
@@ -65,6 +65,10 @@
          <lastmod>2026-06-26</lastmod>
     </url>
     <url>
+         <loc>https://github.com/allenai/open-instruct/algorithms/opd_design/</loc>
+         <lastmod>2026-06-26</lastmod>
+    </url>
+    <url>
📄 sitemap.xml.gz
Binary files site-base/sitemap.xml.gz and site-pr/sitemap.xml.gz differ

Showing first 10 lines of diff for each changed file (up to 5 files, excluding search indices).

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request implements On-Policy Distillation (OPD) for the OLMo-core GRPO stack, introducing a reusable teacher-scoring layer and a sparse forward-KL top-k distillation loss. It adds the distillkit module for sparse distillation math, implements the OPDTeacherScorerRayActor for online teacher scoring, and extends the data pipeline, collation, and sequence-parallel splitting to support 3D teacher tensors. The code review identified three important issues: a potential NaN loss and gradient propagation in the forward KL calculation due to non-finite student logprobs, a potential division-by-zero error in the teacher scoring orchestration if the responses list is empty, and a device mismatch during the student logprob gather if the teacher token IDs are not explicitly moved to the student logits' device.

Comment thread open_instruct/distillkit/losses.py Outdated
finite_teacher, safe_teacher_logprobs - torch.log(denom), torch.zeros_like(safe_teacher_logprobs)
)

loss = (teacher_probs * (safe_teacher_logprobs - student_topk_logprobs.float())).sum(dim=-1)

high

If student_topk_logprobs contains non-finite values (such as -inf for zero-probability tokens or padding positions), the subtraction safe_teacher_logprobs - student_topk_logprobs can produce inf. When multiplied by teacher_probs (which is 0.0 where finite_teacher is False), this results in 0.0 * inf = nan, causing NaN loss and gradient propagation. To prevent this, mask student_topk_logprobs to be finite where finite_teacher is False before performing the subtraction.

Suggested change
- loss = (teacher_probs * (safe_teacher_logprobs - student_topk_logprobs.float())).sum(dim=-1)
+ safe_student_logprobs = torch.where(
+     finite_teacher, student_topk_logprobs.float(), torch.zeros_like(student_topk_logprobs.float())
+ )
+ loss = (teacher_probs * (safe_teacher_logprobs - safe_student_logprobs)).sum(dim=-1)
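
The failure mode described in this comment can be reproduced with plain floats. The sketch below is a scalar stand-in for a single position, not the DistillKit code: both function names are hypothetical, and renormalizing the teacher over its finite top-k entries is an assumption inferred from the `torch.log(denom)` term in the diff context.

```python
import math


def forward_kl_topk_naive(teacher_lps, student_lps):
    """Zero teacher probability in padded slots, but leave the student term unmasked."""
    finite = [math.isfinite(t) for t in teacher_lps]
    denom = sum(math.exp(t) for t, f in zip(teacher_lps, finite) if f)
    total = 0.0
    for t, s, f in zip(teacher_lps, student_lps, finite):
        t_norm = t - math.log(denom) if f else 0.0  # "safe" teacher logprob
        prob = math.exp(t_norm) if f else 0.0
        total += prob * (t_norm - s)  # padded slot: 0.0 * inf evaluates to nan
    return total


def forward_kl_topk_masked(teacher_lps, student_lps):
    """Skip padded slots entirely, so the student value there is never touched."""
    finite = [math.isfinite(t) for t in teacher_lps]
    denom = sum(math.exp(t) for t, f in zip(teacher_lps, finite) if f)
    total = 0.0
    for t, s, f in zip(teacher_lps, student_lps, finite):
        if not f:
            continue  # same effect as replacing the student term with 0 in this slot
        t_norm = t - math.log(denom)
        total += math.exp(t_norm) * (t_norm - s)
    return total


# The last slot is padding on both sides (-inf teacher, -inf student).
teacher = [math.log(0.6), math.log(0.2), -math.inf]
student = [math.log(0.5), math.log(0.3), -math.inf]
naive = forward_kl_topk_naive(teacher, student)
masked = forward_kl_topk_masked(teacher, student)
```

Here `naive` is NaN while `masked` stays finite. A student logprob of -inf in a slot where the teacher is finite still yields an infinite loss in both versions, which is the intended forward-KL behavior rather than a padding artifact.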

Comment on lines +1415 to +1416
if not self.teacher_scorers:
    raise ValueError("OPD teacher scoring requested but no teacher scorers are configured.")

medium

If responses is empty, num_scorers will be 0, which leads to a ZeroDivisionError when calculating chunk_size. Add a guard clause to return early if responses is empty.

Suggested change
  if not self.teacher_scorers:
      raise ValueError("OPD teacher scoring requested but no teacher scorers are configured.")
+ if not responses:
+     return [], [], {}
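
A minimal sketch of why the guard matters, assuming the orchestration caps the number of active scorers at the number of responses before computing a chunk size. The function name and the chunking scheme are hypothetical; only the early return mirrors the suggested change.

```python
import math


def chunk_responses(responses, num_available_scorers):
    """Split responses into contiguous chunks, one per active scorer."""
    if num_available_scorers <= 0:
        raise ValueError("no teacher scorers are configured")
    if not responses:
        return []  # nothing to score; also avoids dividing by zero below
    # Assumption: never use more scorers than there are responses.
    num_scorers = min(num_available_scorers, len(responses))
    chunk_size = math.ceil(len(responses) / num_scorers)
    return [responses[i:i + chunk_size] for i in range(0, len(responses), chunk_size)]


empty = chunk_responses([], 4)
split = chunk_responses(["r0", "r1", "r2", "r3", "r4"], 2)
```

Without the early return, an empty list would make `num_scorers` zero under this assumption and the `chunk_size` division would raise ZeroDivisionError.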

Comment thread open_instruct/distillkit/losses.py Outdated
f"student_logits prefix shape {student_logits.shape[:-1]} does not match "
f"teacher_token_ids prefix shape {teacher_token_ids.shape[:-1]}"
)
safe_token_ids = teacher_token_ids.clamp(min=0, max=student_logits.shape[-1] - 1)

medium

The teacher_token_ids tensor might be on a different device (e.g., CPU) than student_logits (e.g., GPU). To prevent device mismatch errors during torch.gather, explicitly move teacher_token_ids to the device of student_logits before clamping.

Suggested change
- safe_token_ids = teacher_token_ids.clamp(min=0, max=student_logits.shape[-1] - 1)
+ safe_token_ids = teacher_token_ids.to(student_logits.device).clamp(min=0, max=student_logits.shape[-1] - 1)

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: aad100afcd

prompt_logprobs=topk,
logprobs=topk,
flat_logprobs=True,
max_tokens=1,

P1: Leave room for teacher scoring's dummy token

When a rollout reaches the configured limit (max_prompt_token_length + response_length), OPD sends the entire query+response as the vLLM prompt but this sampling config still asks vLLM to generate one more token. Since grpo.py passes the teacher max_model_len as exactly that same query+response limit, full-length/truncated rollouts have no remaining context for this dummy token and teacher scoring can reject or drop the request before producing prompt_logprobs, stalling OPD runs that keep length-finished samples.
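
The arithmetic behind this finding can be sketched as follows. `fits` is a deliberately simplified model of the constraint the review describes (prompt tokens plus generated tokens must stay within the context window), not vLLM's actual admission logic, and both helper names and the example lengths are hypothetical.

```python
def teacher_max_model_len(max_prompt_token_length, response_length, dummy_tokens=1):
    """Context budget for the scorer: the full rollout plus the mandatory generated token(s)."""
    return max_prompt_token_length + response_length + dummy_tokens


def fits(prompt_len, max_model_len, max_tokens=1):
    """Simplified check: the scored prefix and the generated tokens must fit in the context."""
    return prompt_len + max_tokens <= max_model_len


# Illustrative lengths only: a rollout that exactly fills the student's limit.
rollout_limit = 2048 + 1024
fits_without_headroom = fits(rollout_limit, max_model_len=rollout_limit)
fits_with_headroom = fits(rollout_limit, max_model_len=teacher_max_model_len(2048, 1024))
```

Under this model a full-length rollout does not fit when the teacher's limit equals the rollout limit, and fits once one extra token of headroom is added.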

Comment on lines +543 to +545
if not self.streaming_config.opd_use_task_rewards:
    task_loss = torch.zeros_like(task_loss)
loss = masked_mean(task_loss, response_mask, None, loss_denominator)

P2: Only suppress task loss when OPD is active

If a user supplies --opd_use_task_rewards false without also enabling OPD, the config is still accepted whenever a normal reward is enabled, but this branch zeros the GRPO loss and the OPD block below does not add any replacement loss. In that configuration training either backpropagates a loss with no gradient or silently performs no useful task update, so the guard should also require opd_enabled (or the config should reject this combination).
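
One way to reject the combination at config time is sketched below. The `OPDFlags` dataclass and `validate_opd_flags` are hypothetical stand-ins for the real config object; the flag names come from this thread, and the `opd_use_task_rewards=True` default is an assumption.

```python
from dataclasses import dataclass


@dataclass
class OPDFlags:
    opd_enabled: bool = False           # disabled by default, per the PR summary
    opd_use_task_rewards: bool = True   # assumed default


def validate_opd_flags(flags):
    """Refuse to zero the task loss when no OPD loss will replace it."""
    if not flags.opd_use_task_rewards and not flags.opd_enabled:
        raise ValueError(
            "opd_use_task_rewards=False requires opd_enabled=True; "
            "otherwise the GRPO task loss is zeroed with nothing to replace it."
        )


validate_opd_flags(OPDFlags())                                              # plain GRPO
validate_opd_flags(OPDFlags(opd_enabled=True, opd_use_task_rewards=False))  # pure OPD
```

Failing at startup is cheaper than discovering a zero-gradient run after rollouts have already been generated.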

farhatkevin and others added 3 commits June 26, 2026 11:21
- forward_kl_topk_from_logprobs: zero the student term where the teacher entry
  is -inf so 0*inf can't produce NaN; add regression test.
- _score_opd_teacher: return early on empty responses (avoid div-by-zero).
- gather_student_logprobs_at_teacher_topk: move teacher ids to the student
  logits device before gather.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…g guard

- grpo.py: teacher scorer max_model_len += 1 so vLLM's mandatory 1-token
  generation has room when a full-length query+response is scored (P1).
- data_loader.py: reject opd_use_task_rewards=False unless opd_enabled=True,
  so the GRPO task loss is never silently zeroed with no OPD loss to replace
  it (P2).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ather helper

- forward_for_logprobs_and_topk: compute the GRPO sampled-token logprob via the
  fused model_utils.log_softmax_and_gather (as the non-OPD path does) and the
  teacher top-k logprobs via gather(raw_logits) - logsumexp(raw_logits). Neither
  path now materializes a full-vocab [B,T,V] log-softmax tensor (the temperature
  split previously allocated two), and the GRPO logprob is back on the fused path.
- distillkit: delete the unused, divergent gather_student_logprobs_at_teacher_topk
  helper (the live path inlines its own gather); drop it from __init__ exports and
  the walkthrough doc.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
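
The identity this commit relies on is log_softmax(z)[i] = z[i] - logsumexp(z), so the top-k logprobs only need K gathered logits and one scalar per position. A plain-Python sketch for a single position follows; the function names are hypothetical and lists stand in for tensors.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over one position's logits."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def topk_logprobs_via_gather(raw_logits, token_ids):
    """Gather K raw logits, then subtract a single logsumexp scalar."""
    lse = logsumexp(raw_logits)
    return [raw_logits[i] - lse for i in token_ids]


def topk_logprobs_via_full_log_softmax(raw_logits, token_ids):
    """Reference path: build all V log-probabilities first, then gather K of them."""
    lse = logsumexp(raw_logits)
    full = [x - lse for x in raw_logits]  # materializes V entries
    return [full[i] for i in token_ids]


raw_logits = [2.0, -1.0, 0.5, 3.5, 0.0]   # V = 5
teacher_topk_ids = [3, 0]                 # K = 2
fast = topk_logprobs_via_gather(raw_logits, teacher_topk_ids)
reference = topk_logprobs_via_full_log_softmax(raw_logits, teacher_topk_ids)
```

Both paths return the same values; the gather path simply never holds the V-sized intermediate, which at [B, T, V] scale is the memory the commit message refers to.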