Add top-K GJS divergence mode to Tinker SDPO engine (#41)
Conversation
Adds a `use_topk_divergence` toggle to TrainingConfig that switches the Tinker engine from scalar KL advantages (teacher_lp - student_lp) to distributional GJS advantages computed over top-K token distributions. When enabled, the engine fetches top-K teacher/student distributions via `sample_async(topk_prompt_logprobs=K)` and computes per-token Generalized Jensen-Shannon Divergence — the same formulation used by the local/Modal engine's torch-based `sdpo_loss.py`, but implemented in pure Python for the non-GPU Tinker path. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
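The per-token GJS described above can be sketched in pure Python. This is a minimal illustration, not the actual `compute_topk_gjs` signature: the function name, the `{token_id: logprob}` dict representation, and the -20.0 floor for tokens absent from a top-K set are assumptions drawn from the review discussion in this thread.

```python
import math

FLOOR_LOGPROB = -20.0  # assumed floor for tokens missing from a top-K set


def gjs_divergence(p_logprobs: dict[int, float],
                   q_logprobs: dict[int, float],
                   alpha: float = 0.5) -> float:
    """Generalized Jensen-Shannon divergence between two top-K token
    distributions, each given as {token_id: logprob}.

    GJS_a(p, q) = a * KL(p || m) + (1 - a) * KL(q || m),
    where m = a * p + (1 - a) * q. At alpha = 0.5 this is the symmetric JSD.
    """
    tokens = set(p_logprobs) | set(q_logprobs)
    # Convert to probabilities over the union of tokens, flooring misses,
    # then renormalize so each side sums to 1.
    p = {t: math.exp(p_logprobs.get(t, FLOOR_LOGPROB)) for t in tokens}
    q = {t: math.exp(q_logprobs.get(t, FLOOR_LOGPROB)) for t in tokens}
    p_sum, q_sum = sum(p.values()), sum(q.values())
    div = 0.0
    for t in tokens:
        pi, qi = p[t] / p_sum, q[t] / q_sum
        mi = alpha * pi + (1 - alpha) * qi
        div += alpha * pi * math.log(pi / mi)
        div += (1 - alpha) * qi * math.log(qi / mi)
    return div
```

Because every token in the union receives at least the floor probability, both distributions stay strictly positive and the mixture never hits zero, so no log-of-zero guard is needed in this sketch.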
📝 Walkthrough

These changes introduce support for top-K divergence computation in the Tinker SDPO training engine. A new GJS module implements pure-Python top-K divergence calculations, configuration options enable the feature, new discriminated-union data structures handle both scalar and top-K processing modes, and the training engine integrates conditional logic to switch between computation paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Engine as Training Engine
    participant Sampler as Student Sampler
    participant GJS as GJS Module
    participant Training as Training Step
    Engine->>Engine: Check use_topk flag
    alt Top-K Mode
        Engine->>Sampler: sample_async (fetch top-K distributions)
        Sampler-->>Engine: top-K token lists + logprobs
        Engine->>GJS: compute_topk_gjs(teacher_topk, student_topk, alpha)
        GJS-->>Engine: per-position GJS divergences
        Engine->>Engine: Build TopKPreparedSample & TopKBehavior
    else Scalar Mode
        Engine->>Sampler: compute_logprobs_async
        Sampler-->>Engine: scalar logprobs
        Engine->>Engine: _slice_completion_logprobs (existing path)
        Engine->>Engine: Build ScalarPreparedSample & ScalarBehavior
    end
    Engine->>Engine: _build_sample_datum (union handling)
    Engine->>Training: training_step(prepared, behavior, alpha)
    Training-->>Engine: loss, advantages, metadata
    Engine->>Engine: Log divergence_mode (topk_gjs / scalar_kl)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed
- Use explicit `list[BehaviorSignal](...)` cast for top-K behavior list assignment (ty enforces list invariance) - Replace `**core_kwargs` dict splat with explicit keyword args in dataclass constructors (ty can't narrow dict value union types) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
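The second point, that a checker cannot narrow dict value unions through a `**` splat, can be shown with a toy dataclass. The names here are invented stand-ins for the real `ScalarPreparedSample`/`TopKPreparedSample` constructors.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    count: int
    name: str


# A mixed-value dict is inferred as dict[str, int | str], so a strict checker
# like ty cannot prove that count receives an int and name receives a str:
kwargs = dict(count=3, name="demo")
splatted = Sample(**kwargs)  # flagged by the checker, though fine at runtime

# Explicit keyword arguments keep per-field type information intact:
explicit = Sample(count=3, name="demo")
```

Both calls produce identical objects at runtime; the difference is purely what the static analyzer can verify.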
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: de7a9d36d0
```python
elif isinstance(prepared, TopKPreparedSample) and isinstance(behavior, ScalarBehavior):
    # Initial step in top-K mode: behavior is still scalar from rollout cache
    # Fall back to scalar KL using teacher_logprobs derived from top-K
    raw_advantages = [
        t - s
        for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True)
```
Apply GJS advantages on the first top-K distillation step
This fallback makes use_topk_divergence=True run scalar teacher_logprob - student_logprob advantages whenever behavior is still ScalarBehavior, which is always true for step 1; with steps_per_batch=1, the whole update never uses GJS even though metadata reports "divergence_mode": "topk_gjs". That means experiments configured for top-K divergence silently train with the old objective in a common single-step setting.
```python
# Extract response token IDs from full_tokens
response_tokens = prepared.full_tokens[prepared.prompt_len:]
scalar_logprobs = extract_token_logprobs(student_topk, response_tokens)
```
Keep exact student logprobs for importance-sampling ratios
Here the student logprobs used downstream in loss_fn_inputs["logprobs"] are reconstructed from top-K and floored when a rollout token is not in the returned top-K set, so any miss becomes -20.0 regardless of the true probability. In multi-step top-K mode, this can severely skew IS ratios/advantages whenever updated student rankings drop a response token out of top-K (especially for smaller teacher_top_k), producing incorrect training signals.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@claas/eval/types.py`:
- Line 134: The Tinker engine returns divergence_mode in metadata but the
TinkerDistillMetrics dataclass's divergence_mode field (default "scalar_kl") is
not being set from metadata; update the TinkerDistillMetrics instantiation in
runner.py (where TinkerDistillMetrics is constructed) to pass
divergence_mode=metadata.get("divergence_mode", "scalar_kl") so the value from
metadata (e.g., "topk_gjs" or "scalar_kl") is used instead of always using the
default.
In `@claas/training/engine/tinker/engine.py`:
- Around line 325-336: The type mismatch arises because
_compute_student_topk_for_batch is annotated to return list[TopKBehavior] which
is not compatible with list[BehaviorSignal]; update the return type annotation
of the _compute_student_topk_for_batch function to list[BehaviorSignal] (or a
more general Sequence[BehaviorSignal] if preferred) so its declared output
matches the variable behavior_signals and resolves the invariant list type error
involving TopKBehavior and BehaviorSignal.
- Around line 436-479: Using core_kwargs = dict(...) then unpacking with
**core_kwargs loses type information and breaks static analysis; instead, pass
the shared parameters explicitly into the PreparedSample constructors to
preserve types. Replace the **core_kwargs usage in both TopKPreparedSample(...)
and ScalarPreparedSample(...) with explicit keyword arguments:
full_tokens=full_tokens, input_tokens=input_tokens, target_tokens=target_tokens,
prompt_len=prompt_len, completion_len=completion_len,
teacher_scored_text=teacher_scored_text, and keep the existing
teacher_logprobs=... and teacher_topk=... args for TopKPreparedSample;
alternatively define a TypedDict for core_kwargs and annotate its type where
constructed if you prefer keeping a single container, but the simplest fix is to
inline the explicit keywords in TopKPreparedSample and ScalarPreparedSample to
restore type inference.
ℹ️ Review info
⚙️ Run configuration: Organization UI · Review profile: CHILL · Plan: Pro · Run ID: f8a23e6b-058b-4687-9048-eb3b60674b35
📒 Files selected for processing (7)
- claas/core/types.py
- claas/eval/configs/base.yaml
- claas/eval/types.py
- claas/training/engine/tinker/engine.py
- claas/training/gjs.py
- tests/test_gjs.py
- tests/test_tinker_engine.py
```python
completion_len: int = 0
batch_size: int = 0
steps_per_batch_applied: int = 1
divergence_mode: str = "scalar_kl"
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Verify how TinkerDistillMetrics is instantiated in runner.py
rg -n "TinkerDistillMetrics\(" claas/eval/runner.py -A 15
```

Repository: kfallah/CLaaS
Length of output: 1601

🏁 Script executed:

```shell
#!/bin/bash
# Check the TinkerDistillMetrics dataclass definition in types.py
rg -n "class TinkerDistillMetrics" claas/eval/types.py -A 20
```

Repository: kfallah/CLaaS
Length of output: 711

🏁 Script executed:

```shell
#!/bin/bash
# Check where metadata is populated to see if divergence_mode is passed
rg -n "divergence_mode" claas/ -B 2 -A 2
```

Repository: kfallah/CLaaS
Length of output: 681

🏁 Script executed:

```shell
#!/bin/bash
# Get more context around the explicit TinkerDistillMetrics instantiation
rg -n "def.*:" claas/eval/runner.py | head -20
```

Repository: kfallah/CLaaS
Length of output: 671

🏁 Script executed:

```shell
#!/bin/bash
# Get the function containing the TinkerDistillMetrics instantiation at line 75
sed -n '60,95p' claas/eval/runner.py
```

Repository: kfallah/CLaaS
Length of output: 1433
Runner.py does not extract divergence_mode from metadata.
The Tinker engine returns divergence_mode in its metadata response (either "topk_gjs" or "scalar_kl"), but the TinkerDistillMetrics instantiation at claas/eval/runner.py:75-85 doesn't extract it. The field will silently default to "scalar_kl" even if the engine returns a different value.
Update the instantiation to include:
```python
divergence_mode=metadata.get("divergence_mode", "scalar_kl"),
```
```python
if use_topk:
    behavior_signals = await _compute_student_topk_for_batch(
        student_sampling=student_sampling,
        prepared_samples=prepared_samples,
        top_k=teacher_top_k,
    )
else:
    scalar_logprobs = await _compute_student_logprobs_for_batch(
        student_sampling=student_sampling,
        prepared_samples=prepared_samples,
    )
    behavior_signals = [ScalarBehavior(logprobs=lps) for lps in scalar_logprobs]
```
Type annotation mismatch causes static analysis failure.
The pipeline reports list[TopKBehavior] is not assignable to list[BehaviorSignal]. This is a variance issue—lists are invariant in Python's type system.
The fix is to widen the return type annotation of _compute_student_topk_for_batch from list[TopKBehavior] to list[BehaviorSignal]:
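Why invariance bites here can be shown with a stand-in class pair (these names are illustrative, not from the PR):

```python
class Base: ...
class Sub(Base): ...


def extend_with_base(xs: list[Base]) -> None:
    # Perfectly legal for a list[Base]: append a plain Base
    xs.append(Base())


subs: list[Sub] = [Sub()]
# extend_with_base(subs)  # a checker rejects this call: it could smuggle a
#                         # plain Base into what the caller sees as list[Sub]

# Declaring the wider element type up front, as the proposed fix does for
# _compute_student_topk_for_batch, makes the assignment well-typed:
widened: list[Base] = [Sub()]
extend_with_base(widened)
```

The same reasoning applies to the engine: annotating the producer as `list[BehaviorSignal]` from the start sidesteps the invariance error with no runtime change.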
🔧 Proposed fix

```diff
 async def _compute_student_topk_for_batch(
     *,
     student_sampling: Any,
     prepared_samples: list[PreparedSample],
     top_k: int,
-) -> list[TopKBehavior]:
+) -> list[BehaviorSignal]:
```
+) -> list[BehaviorSignal]:🧰 Tools
🪛 GitHub Actions: CI
[error] 326-331: invalid-assignment: Object of type list[TopKBehavior] is not assignable to list[ScalarBehavior | TopKBehavior]
- Extract SampleCore, ScalarPreparedSample, TopKPreparedSample, ScalarBehavior, TopKBehavior into claas/training/engine/tinker/types.py - Change use_topk_divergence default from False to True in TrainingConfig and base.yaml — top-K GJS is now the default divergence mode Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- Adds a `use_topk_divergence: bool = False` toggle to `TrainingConfig` that switches the Tinker engine from scalar KL advantages (teacher_lp - student_lp) to distributional GJS advantages computed over top-K token distributions
- New `claas/training/gjs.py` module with pure-Python GJS computation (`compute_topk_gjs`, `extract_token_logprobs`, `slice_completion_topk`) — mirrors the torch-based GJS in `sdpo_loss.py` for the non-GPU Tinker path
- Refactors `PreparedSample` from a TypedDict to discriminated-union dataclasses (`ScalarPreparedSample | TopKPreparedSample`) with corresponding `ScalarBehavior | TopKBehavior` for type-safe code paths

Test plan
- `test_gjs.py`: 13 unit tests covering GJS mathematical properties (identical distributions → 0, JSD symmetry at α=0.5, non-negativity, floor logprob handling, α=1.0 degeneration, multi-position, length mismatch)
- `test_tinker_engine.py`: 3 new tests (`test_engine_distill_topk_mode`, `test_engine_distill_topk_multistep`, `test_engine_distill_scalar_mode_reports_divergence_mode`)
- `ruff check` clean; `ty check` shows only expected unresolved-import errors for GPU deps

🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes
New Features
Tests