Skip to content

Add top-K GJS divergence mode to Tinker SDPO engine#41

Open
kfallah wants to merge 3 commits intomainfrom
feat/tinker-topk-gjs
Open

Add top-K GJS divergence mode to Tinker SDPO engine#41
kfallah wants to merge 3 commits intomainfrom
feat/tinker-topk-gjs

Conversation

@kfallah
Copy link
Copy Markdown
Owner

@kfallah kfallah commented Mar 7, 2026

Summary

  • Adds use_topk_divergence: bool = False toggle to TrainingConfig that switches the Tinker engine from scalar KL advantages (teacher_lp - student_lp) to distributional GJS advantages computed over top-K token distributions
  • New claas/training/gjs.py module with pure-Python GJS computation (compute_topk_gjs, extract_token_logprobs, slice_completion_topk) — mirrors the torch-based GJS in sdpo_loss.py for the non-GPU Tinker path
  • Refactors PreparedSample from TypedDict to discriminated-union dataclasses (ScalarPreparedSample | TopKPreparedSample) with corresponding ScalarBehavior | TopKBehavior for type-safe code paths

Test plan

  • test_gjs.py: 13 unit tests covering GJS mathematical properties (identical distributions → 0, JSD symmetry at α=0.5, non-negativity, floor logprob handling, α=1.0 degeneration, multi-position, length mismatch)
  • test_tinker_engine.py: 3 new tests (test_engine_distill_topk_mode, test_engine_distill_topk_multistep, test_engine_distill_scalar_mode_reports_divergence_mode)
  • All existing tests pass (127 passed, 29 skipped — skips are tinker SDK not installed locally, as expected)
  • ruff check clean, ty check shows only expected unresolved-import errors for GPU deps

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for top-K divergence mode in training configuration.
    • Introduced GJS divergence computation for top-K probability distributions.
    • Enhanced metrics reporting with divergence mode tracking (scalar_kl or topk_gjs).
  • Tests

    • Added comprehensive test coverage for top-K divergence computation and scalar divergence modes.

Adds a `use_topk_divergence` toggle to TrainingConfig that switches the
Tinker engine from scalar KL advantages (teacher_lp - student_lp) to
distributional GJS advantages computed over top-K token distributions.

When enabled, the engine fetches top-K teacher/student distributions via
`sample_async(topk_prompt_logprobs=K)` and computes per-token Generalized
Jensen-Shannon Divergence — the same formulation used by the local/Modal
engine's torch-based `sdpo_loss.py`, but implemented in pure Python for
the non-GPU Tinker path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Mar 7, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: e30ff075-29f2-4778-9316-37b00b976899

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

These changes introduce support for top-K divergence computation in the Tinker SDPO training engine. A new GJS module implements pure-Python top-K divergence calculations, configuration options enable the feature, new discriminated-union data structures handle both scalar and top-K processing modes, and the training engine integrates conditional logic to switch between computation paths.

Changes

Cohort / File(s) Summary
Configuration & Type Definitions
claas/core/types.py, claas/eval/configs/base.yaml, claas/eval/types.py
Added use_topk_divergence boolean flag to TrainingConfig, corresponding YAML configuration entry, and divergence_mode field to TinkerDistillMetrics for reporting computation mode.
GJS Divergence Module
claas/training/gjs.py
New module implementing pure-Python top-K GJS divergence computation with three public functions: compute_topk_gjs (per-position GJS over top-K distributions), extract_token_logprobs (lookup token logprobs in top-K sets), and slice_completion_topk (slice top-K data for completions).
Training Engine Integration
claas/training/engine/tinker/engine.py
Introduced discriminated-union data models (ScalarPreparedSample, TopKPreparedSample, ScalarBehavior, TopKBehavior) and new top-K computation functions. Refactored _prepare_sample_inputs, _build_sample_datum, and related flows with conditional logic to route between scalar logprob and top-K GJS computation paths based on use_topk flag.
Test Coverage
tests/test_gjs.py, tests/test_tinker_engine.py
Comprehensive test suite for GJS functions (168 lines) covering identity, symmetry, floor handling, and edge cases. Extended engine tests (188 lines) validating top-K sampling integration, multi-step training, divergence mode reporting, and end-to-end distillation flow with top-K data.

Sequence Diagram(s)

sequenceDiagram
    participant Engine as Training Engine
    participant Sampler as Student Sampler
    participant GJS as GJS Module
    participant Training as Training Step
    
    Engine->>Engine: Check use_topk flag
    
    alt Top-K Mode
        Engine->>Sampler: sample_async (fetch top-K distributions)
        Sampler-->>Engine: top-K token lists + logprobs
        Engine->>GJS: compute_topk_gjs(teacher_topk, student_topk, alpha)
        GJS-->>Engine: per-position GJS divergences
        Engine->>Engine: Build TopKPreparedSample & TopKBehavior
    else Scalar Mode
        Engine->>Sampler: compute_logprobs_async
        Sampler-->>Engine: scalar logprobs
        Engine->>Engine: _slice_completion_logprobs (existing path)
        Engine->>Engine: Build ScalarPreparedSample & ScalarBehavior
    end
    
    Engine->>Engine: _build_sample_datum (union handling)
    Engine->>Training: training_step(prepared, behavior, alpha)
    Training-->>Engine: loss, advantages, metadata
    Engine->>Engine: Log divergence_mode (topk_gjs / scalar_kl)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 Whisker-twitching with glee...

Top-K divergence hops into view,
Scalar and union types, both tried and true,
GJS computes with whiskers so bright,
Teacher and student dance left and right,
The Tinker engine learns to divide!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly and accurately summarizes the main change: adding a top-K GJS divergence mode to the Tinker SDPO engine, which is the primary objective of this PR.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch feat/tinker-topk-gjs

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

- Use explicit `list[BehaviorSignal](...)` cast for top-K behavior list
  assignment (ty enforces list invariance)
- Replace `**core_kwargs` dict splat with explicit keyword args in
  dataclass constructors (ty can't narrow dict value union types)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: de7a9d36d0

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +518 to +523
elif isinstance(prepared, TopKPreparedSample) and isinstance(behavior, ScalarBehavior):
# Initial step in top-K mode: behavior is still scalar from rollout cache
# Fall back to scalar KL using teacher_logprobs derived from top-K
raw_advantages = [
t - s
for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Apply GJS advantages on the first top-K distillation step

This fallback makes use_topk_divergence=True run scalar teacher_logprob - student_logprob advantages whenever behavior is still ScalarBehavior, which is always true for step 1; with steps_per_batch=1, the whole update never uses GJS even though metadata reports "divergence_mode": "topk_gjs". That means experiments configured for top-K divergence silently train with the old objective in a common single-step setting.

Useful? React with 👍 / 👎.


# Extract response token IDs from full_tokens
response_tokens = prepared.full_tokens[prepared.prompt_len:]
scalar_logprobs = extract_token_logprobs(student_topk, response_tokens)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep exact student logprobs for importance-sampling ratios

Here the student logprobs used downstream in loss_fn_inputs["logprobs"] are reconstructed from top-K and floored when a rollout token is not in the returned top-K set, so any miss becomes -20.0 regardless of the true probability. In multi-step top-K mode, this can severely skew IS ratios/advantages whenever updated student rankings drop a response token out of top-K (especially for smaller teacher_top_k), producing incorrect training signals.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@claas/eval/types.py`:
- Line 134: The Tinker engine returns divergence_mode in metadata but the
TinkerDistillMetrics dataclass's divergence_mode field (default "scalar_kl") is
not being set from metadata; update the TinkerDistillMetrics instantiation in
runner.py (where TinkerDistillMetrics is constructed) to pass
divergence_mode=metadata.get("divergence_mode", "scalar_kl") so the value from
metadata (e.g., "topk_gjs" or "scalar_kl") is used instead of always using the
default.

In `@claas/training/engine/tinker/engine.py`:
- Around line 325-336: The type mismatch arises because
_compute_student_topk_for_batch is annotated to return list[TopKBehavior] which
is not compatible with list[BehaviorSignal]; update the return type annotation
of the _compute_student_topk_for_batch function to list[BehaviorSignal] (or a
more general Sequence[BehaviorSignal] if preferred) so its declared output
matches the variable behavior_signals and resolves the invariant list type error
involving TopKBehavior and BehaviorSignal.
- Around line 436-479: Using core_kwargs = dict(...) then unpacking with
**core_kwargs loses type information and breaks static analysis; instead, pass
the shared parameters explicitly into the PreparedSample constructors to
preserve types. Replace the **core_kwargs usage in both TopKPreparedSample(...)
and ScalarPreparedSample(...) with explicit keyword arguments:
full_tokens=full_tokens, input_tokens=input_tokens, target_tokens=target_tokens,
prompt_len=prompt_len, completion_len=completion_len,
teacher_scored_text=teacher_scored_text, and keep the existing
teacher_logprobs=... and teacher_topk=... args for TopKPreparedSample;
alternatively define a TypedDict for core_kwargs and annotate its type where
constructed if you prefer keeping a single container, but the simplest fix is to
inline the explicit keywords in TopKPreparedSample and ScalarPreparedSample to
restore type inference.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: f8a23e6b-058b-4687-9048-eb3b60674b35

📥 Commits

Reviewing files that changed from the base of the PR and between 838bf90 and de7a9d3.

📒 Files selected for processing (7)
  • claas/core/types.py
  • claas/eval/configs/base.yaml
  • claas/eval/types.py
  • claas/training/engine/tinker/engine.py
  • claas/training/gjs.py
  • tests/test_gjs.py
  • tests/test_tinker_engine.py

completion_len: int = 0
batch_size: int = 0
steps_per_batch_applied: int = 1
divergence_mode: str = "scalar_kl"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify how TinkerDistillMetrics is instantiated in runner.py
rg -n "TinkerDistillMetrics\(" claas/eval/runner.py -A 15

Repository: kfallah/CLaaS

Length of output: 1601


🏁 Script executed:

#!/bin/bash
# Check the TinkerDistillMetrics dataclass definition in types.py
rg -n "class TinkerDistillMetrics" claas/eval/types.py -A 20

Repository: kfallah/CLaaS

Length of output: 711


🏁 Script executed:

#!/bin/bash
# Check where metadata is populated to see if divergence_mode is passed
rg -n "divergence_mode" claas/ -B 2 -A 2

Repository: kfallah/CLaaS

Length of output: 681


🏁 Script executed:

#!/bin/bash
# Get more context around the explicit TinkerDistillMetrics instantiation
rg -n "def.*:" claas/eval/runner.py | head -20

Repository: kfallah/CLaaS

Length of output: 671


🏁 Script executed:

#!/bin/bash
# Get the function containing the TinkerDistillMetrics instantiation at line 75
sed -n '60,95p' claas/eval/runner.py

Repository: kfallah/CLaaS

Length of output: 1433


Runner.py does not extract divergence_mode from metadata.

The Tinker engine returns divergence_mode in its metadata response (either "topk_gjs" or "scalar_kl"), but the TinkerDistillMetrics instantiation at claas/eval/runner.py:75-85 doesn't extract it. The field will silently default to "scalar_kl" even if the engine returns a different value.

Update the instantiation to include:

divergence_mode=metadata.get("divergence_mode", "scalar_kl"),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/eval/types.py` at line 134, The Tinker engine returns divergence_mode
in metadata but the TinkerDistillMetrics dataclass's divergence_mode field
(default "scalar_kl") is not being set from metadata; update the
TinkerDistillMetrics instantiation in runner.py (where TinkerDistillMetrics is
constructed) to pass divergence_mode=metadata.get("divergence_mode",
"scalar_kl") so the value from metadata (e.g., "topk_gjs" or "scalar_kl") is
used instead of always using the default.

Comment on lines +325 to +336
if use_topk:
behavior_signals = await _compute_student_topk_for_batch(
student_sampling=student_sampling,
prepared_samples=prepared_samples,
top_k=teacher_top_k,
)
else:
scalar_logprobs = await _compute_student_logprobs_for_batch(
student_sampling=student_sampling,
prepared_samples=prepared_samples,
)
behavior_signals = [ScalarBehavior(logprobs=lps) for lps in scalar_logprobs]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Type annotation mismatch causes static analysis failure.

The pipeline reports list[TopKBehavior] is not assignable to list[BehaviorSignal]. This is a variance issue—lists are invariant in Python's type system.

The fix is to widen the return type annotation of _compute_student_topk_for_batch from list[TopKBehavior] to list[BehaviorSignal]:

🔧 Proposed fix
 async def _compute_student_topk_for_batch(
     *,
     student_sampling: Any,
     prepared_samples: list[PreparedSample],
     top_k: int,
-) -> list[TopKBehavior]:
+) -> list[BehaviorSignal]:
🧰 Tools
🪛 GitHub Actions: CI

[error] 326-331: invalid-assignment: Object of type list[TopKBehavior] is not assignable to list[ScalarBehavior | TopKBehavior]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/tinker/engine.py` around lines 325 - 336, The type
mismatch arises because _compute_student_topk_for_batch is annotated to return
list[TopKBehavior] which is not compatible with list[BehaviorSignal]; update the
return type annotation of the _compute_student_topk_for_batch function to
list[BehaviorSignal] (or a more general Sequence[BehaviorSignal] if preferred)
so its declared output matches the variable behavior_signals and resolves the
invariant list type error involving TopKBehavior and BehaviorSignal.

- Extract SampleCore, ScalarPreparedSample, TopKPreparedSample,
  ScalarBehavior, TopKBehavior into claas/training/engine/tinker/types.py
- Change use_topk_divergence default from False to True in TrainingConfig
  and base.yaml — top-K GJS is now the default divergence mode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant