Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions projects/thinking_midtraining/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,17 @@ This SFT mid-training phase serves as an intermediate step between initial pretr
While SFT mid-training encourages the model to imitate the teacher's reasoning patterns, it does not directly optimize for the utility of the generated thoughts. To address this, we introduce a reinforcement learning mid-training phase to further refine the model's reasoning capabilities on pretraining data.

Given the second half of the augmented pretraining corpus $\tilde{\mathcal{D}}_{RL}$, we process each chunk $\tilde{c}^i$ by splitting it into a prefix $p^i$ and a suffix $s^i$:
$\tilde{c}^i = [p^i, s^i]$ where $p^i$ consists of the initial $l$ tokens and $s^i$ contains the remaining tokens, with $l < |\tilde{c}^i|$. For each prefix $p^i$, the model being mid-trained $\mathcal{M}_{\text{mid}}$ is tasked with generating a sequence of "thinking" tokens $\hat{\tau}^i$ followed by a predicted suffix $\hat{s}^i$: $[\hat{\tau}^i, \hat{s}^i] = \mathcal{M}_{\text{mid}}(p^i)$, where $\hat{\tau}^i$ represents the model's intermediate reasoning steps and $\hat{s}^i$ is its prediction of the ground truth suffix $s^i$.
$\tilde{c}^i = [p^i, s^i]$ where $p^i$ consists of the initial $l$ tokens and $s^i$ contains the remaining tokens, with $l < |\tilde{c}^i|$.

To evaluate the quality of the generated suffix, we employ a LLM as a judge. The judge, $\mathcal{M}_{\text{judge}}$ receives both the generated suffix $\hat{s}^i$ and the ground truth $s^i$, and outputs a binary reward $r^i \in \{0, 1\}$ indicating whether $\hat{s}^i$ matches $s^i$ sufficiently well according to predefined criteria (e.g., semantic similarity, factual correctness, or task completion): $r^i = \mathcal{M}_{\text{judge}}(\hat{s}^i, s^i)$.
For each prefix $p^i$, the model being mid-trained $\mathcal{M}_{\text{mid}}$ is tasked with generating a sequence of "thinking" tokens $\hat{\tau}^i$ followed by a predicted suffix $\hat{s}^i$:

$$[\hat{\tau}^i, \hat{s}^i] = \mathcal{M}_{\text{mid}}(p^i)$$

where $\hat{\tau}^i$ represents the model's intermediate reasoning steps and $\hat{s}^i$ is its prediction of the ground truth suffix $s^i$.

To evaluate the quality of the generated suffix, we employ a LLM as a judge. The judge, $\mathcal{M}_{\text{judge}}$ receives both the generated suffix $\hat{s}^i$ and the ground truth $s^i$, and outputs a binary reward $r^i \in \{0, 1\}$ indicating whether $\hat{s}^i$ matches $s^i$ sufficiently well according to predefined criteria (e.g., semantic similarity, factual correctness, or task completion):

$$r^i = \mathcal{M}_{\text{judge}}(\hat{s}^i, s^i)$$



Expand All @@ -94,7 +102,11 @@ By incorporating RL mid-training, our method encourages the model not only to im

### RL Post-Training

The final stage of the pipeline is to run standard post-training. Given a set of questions $\mathcal{Q}$ from a post-training dataset, the model being post-trained $\mathcal{M}_{\text{post}}$ generates thoughts $\tau$ and answer $\hat{y}^i$ for each question $Q^i \in \mathcal{Q}$. We employ a rule-based reward model, $\mathcal{M}_{\text{RLVR}}$ to score the responses compare to the ground truth $y^i$: $r^i = \mathcal{M}_{\text{RLVR}}(\hat{y}^i, y^i)$.
The final stage of the pipeline is to run standard post-training. Given a set of questions $\mathcal{Q}$ from a post-training dataset, the model being post-trained $\mathcal{M}_{\text{post}}$ generates thoughts $\tau$ and answer $\hat{y}^i$ for each question $Q^i \in \mathcal{Q}$.

We employ a rule-based reward model, $\mathcal{M}_{\text{RLVR}}$ to score the responses compare to the ground truth $y^i$:

$$r^i = \mathcal{M}_{\text{RLVR}}(\hat{y}^i, y^i)$$

$$
\mathcal{L}_{\text{RLVR}}(\theta) = -\mathbb{E}_{p^i \sim \mathcal{P}} [ \mathbb{E}_{\hat{y}^i \sim \mathcal{M}_{\text{post}}(\cdot \mid Q^i)} [r^i] ]
Expand Down
Loading