The augmented corpus is $\tilde{\mathcal{D}} = \{\tilde{c}^1, \tilde{c}^2, \ldots, \tilde{c}^N\}$.

### 2) Thinking SFT Mid-training

We perform supervised fine-tuning (SFT) mid-training on half of the augmented corpus, which we call $\tilde{\mathcal{D}}_{\text{SFT}}$, using standard next-token prediction. Given a base model $\mathcal{M}_{\text{base}}$ parameterized by $\theta$, we optimize the following objective: $\mathcal{L}_{\text{SFT}}(\theta) = -\mathbb{E}_{\tilde{c}^i \sim \tilde{\mathcal{D}}_{\text{SFT}}} \left[ \sum_{j=1}^{|\tilde{c}^i|} \log P_\theta(\tilde{c}^i_j \mid \tilde{c}^i_{<j}) \right]$

where $\tilde{c}^i_j$ denotes the $j$-th token in the augmented chunk $\tilde{c}^i$, and $\tilde{c}^i_{<j}$ represents all preceding tokens. Importantly, the loss is computed over the entire augmented sequence, including both the original content tokens $x_j$ and the generated thought tokens $\tau_j$. This allows the model to learn to produce intermediate reasoning steps alongside the original content.
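
As a concrete sketch, the per-chunk loss is ordinary negative log-likelihood over every position of the augmented sequence, thought tokens included. The toy `logprob_fn` below (a uniform distribution over a 4-token vocabulary) stands in for the model and is purely illustrative:

```python
import math

def sft_loss(token_ids, logprob_fn):
    # Next-token prediction loss over the FULL augmented chunk:
    # both original content tokens and inserted thought tokens are scored.
    total = 0.0
    for j in range(1, len(token_ids)):
        total -= logprob_fn(token_ids[:j], token_ids[j])  # -log P(t_j | t_<j)
    return total

# Toy "model": uniform distribution over a 4-token vocabulary.
uniform = lambda prefix, tok: math.log(1.0 / 4)

chunk = [2, 0, 3, 1]            # an augmented chunk (content + thought tokens)
loss = sft_loss(chunk, uniform)  # 3 predicted positions, each contributing log 4
```

In practice this is just cross-entropy with no loss masking on the thought spans, which is what lets the model learn to emit reasoning alongside content.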

This SFT mid-training phase serves as an intermediate step between initial pretraining and RL mid-training.

While SFT mid-training encourages the model to imitate the teacher's reasoning patterns, it does not directly optimize for the utility of the generated thoughts. To address this, we introduce a reinforcement learning mid-training phase to further refine the model's reasoning capabilities on pretraining data.

Given the second half of the augmented pretraining corpus $\tilde{\mathcal{D}}_{RL}$, we process each chunk $\tilde{c}^i$ by splitting it into a prefix $p^i$ and a suffix $s^i$:
$\tilde{c}^i = [p^i, s^i]$ where $p^i$ consists of the initial $l$ tokens and $s^i$ contains the remaining tokens, with $l < |\tilde{c}^i|$. For each prefix $p^i$, the model being mid-trained $\mathcal{M}_{\text{mid}}$ is tasked with generating a sequence of "thinking" tokens $\hat{\tau}^i$ followed by a predicted suffix $\hat{s}^i$: $[\hat{\tau}^i, \hat{s}^i] = \mathcal{M}_{\text{mid}}(p^i)$, where $\hat{\tau}^i$ represents the model's intermediate reasoning steps and $\hat{s}^i$ is its prediction of the ground truth suffix $s^i$.
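
The prefix/suffix split is a simple slice; a minimal sketch (the helper name is ours, not from the paper):

```python
def split_chunk(chunk_tokens, prefix_len):
    # Split an augmented chunk c = [p, s] into a prefix p of the first l tokens
    # and a suffix s holding the remainder; requires 0 < l < |c|.
    assert 0 < prefix_len < len(chunk_tokens)
    return chunk_tokens[:prefix_len], chunk_tokens[prefix_len:]

chunk = list(range(10))
prefix, suffix = split_chunk(chunk, 4)  # p = first 4 tokens, s = remaining 6
```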

To evaluate the quality of the generated suffix, we employ an LLM as a judge. The judge, $\mathcal{M}_{\text{judge}}$, receives both the generated suffix $\hat{s}^i$ and the ground truth $s^i$, and outputs a binary reward $r^i \in \{0, 1\}$ indicating whether $\hat{s}^i$ matches $s^i$ sufficiently well according to predefined criteria (e.g., semantic similarity, factual correctness, or task completion): $r^i = \mathcal{M}_{\text{judge}}(\hat{s}^i, s^i)$.
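
The judge interface reduces to a function returning a binary reward. Since the real judge is an LLM call, the token-overlap rule below is only a testable stand-in with the same $\{0, 1\}$ interface; the overlap metric and threshold are our assumptions, not the paper's criteria:

```python
def proxy_judge(pred_suffix: str, gold_suffix: str, threshold: float = 0.5) -> int:
    # Stand-in for M_judge: a real implementation would prompt an LLM with both
    # suffixes and the predefined criteria. Here we return 1 iff the predicted
    # suffix covers at least `threshold` of the gold suffix's unique tokens.
    pred = set(pred_suffix.lower().split())
    gold = set(gold_suffix.lower().split())
    if not gold:
        return 0
    overlap = len(pred & gold) / len(gold)
    return int(overlap >= threshold)

r = proxy_judge("the cat sat on the mat", "the cat sat on a mat")
```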

The RL objective is then to maximize the expected reward over the augmented corpus:

$$
\mathcal{L}_{\text{RLVR}}(\theta) = -\mathbb{E}_{p^i \sim \tilde{\mathcal{D}}_{RL}} [ \mathbb{E}_{[\hat{\tau}^i, \hat{s}^i] \sim \mathcal{M}_{\text{mid}}(\cdot \mid p^i)} [r^i] ]
$$

where $\theta$ are the parameters of the model. We optimize this objective using DrGRPO.
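
As we understand Dr.GRPO, its group-relative advantage subtracts only the group-mean reward, dropping vanilla GRPO's per-group standard-deviation normalization (and, in the policy loss, per-sequence length normalization). A minimal sketch of the advantage computation:

```python
def drgrpo_advantages(rewards):
    # Dr.GRPO-style advantage for one group of rollouts sharing a prefix:
    # subtract the group-mean baseline only -- no division by the reward std,
    # unlike vanilla GRPO.
    mean_r = sum(rewards) / len(rewards)
    return [r - mean_r for r in rewards]

# Group of 4 rollouts for one prefix p^i, with binary judge rewards:
advs = drgrpo_advantages([1, 0, 0, 1])  # [0.5, -0.5, -0.5, 0.5]
```

Rollouts that beat the group average get positive advantage, pushing the policy toward thoughts that make the suffix easier to reproduce.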
By incorporating RL mid-training, our method encourages the model not only to imitate the teacher's reasoning patterns but also to produce thoughts that actually help it predict the upcoming tokens.

### RL Post-Training

The final stage of the pipeline is to run standard post-training. Given a set of questions $\mathcal{Q}$ from a post-training dataset, the model being post-trained, $\mathcal{M}_{\text{post}}$, generates thoughts $\tau^i$ and an answer $\hat{y}^i$ for each question $Q^i \in \mathcal{Q}$. We employ a rule-based reward model, $\mathcal{M}_{\text{RLVR}}$, to score the responses against the ground truth $y^i$: $r^i = \mathcal{M}_{\text{RLVR}}(\hat{y}^i, y^i)$.

$$
\mathcal{L}_{\text{RLVR}}(\theta) = -\mathbb{E}_{Q^i \sim \mathcal{Q}} [ \mathbb{E}_{\hat{y}^i \sim \mathcal{M}_{\text{post}}(\cdot \mid Q^i)} [r^i] ]
$$

where $\theta$ are the parameters of the model. We optimize this using DrGRPO.
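
A rule-based verifier can be as simple as extracting the final answer and exact-matching it against the ground truth. The `\boxed{...}` answer convention below is an assumption for illustration, not necessarily the format used here:

```python
import re

def rule_based_reward(response: str, gold_answer: str) -> int:
    # Sketch of a verifiable reward M_RLVR: pull the final answer out of an
    # (assumed) \boxed{...} span and exact-match it against the ground truth.
    m = re.search(r"\\boxed\{([^}]*)\}", response)
    if m is None:
        return 0  # no parseable answer -> zero reward
    return int(m.group(1).strip() == gold_answer.strip())

r = rule_based_reward("... thinking ... the answer is \\boxed{42}", "42")
```

Because the reward is computed by a fixed rule rather than a model, it is cheap and not susceptible to judge drift, at the cost of requiring answers in a checkable format.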