79 changes: 78 additions & 1 deletion docs/source/Instruction/GKD.md

We can perform GKD training by setting the following parameters:

### Basic Parameters

| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
| `--teacher_model` | str | None | - | Teacher model path or model ID<br>*Can be omitted when using `teacher_model_server` |
| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-policy learning trigger probability<br>• 0.0: pure offline<br>• 0.5: mixed strategy<br>• 1.0: pure on-policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: use the dataset when not on-policy<br>• True: use teacher generations when not on-policy |
| `--temperature` | float | 0.9 | > 0 | Sampling temperature during generation; controls randomness |
| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
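
How `lmbda` and `seq_kd` interact can be sketched as a per-step choice of the batch's data source (an illustrative sketch only — the function and label names are hypothetical, not ms-swift internals):

```python
import random

def choose_batch_source(lmbda: float, seq_kd: bool) -> str:
    """Pick where the next training batch's completions come from.

    With probability `lmbda` the student generates its own completions
    (on-policy); otherwise the batch comes from the dataset, or from
    teacher generations when `seq_kd` is enabled.
    """
    if random.random() < lmbda:
        return "student_generated"   # on-policy step
    return "teacher_generated" if seq_kd else "dataset"
```

With `lmbda=0.5`, roughly half the steps train on student samples and half on the fixed dataset (or on teacher generations when `seq_kd=True`).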

### Top-K KL Computation

By default, GKD computes the KL divergence over the full vocabulary, which can easily cause OOM; in that case, **Top-K** mode can reduce both memory usage and computation.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--gkd_logits_topk` | int | None | Number of Top-K logits<br>• None: use the full vocabulary (default)<br>• Positive integer: compute the KL only over the K tokens with the highest teacher-model probability |

**Top-K Mode Principle**:

In Top-K mode, the K tokens with the highest **teacher-model** output probability are selected, and the KL divergence between the two models' distributions is computed over this subset.
$$
D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
$$

Where the Top-K indices come from the teacher model: $\text{Top-K} = \text{argtop}_K(P_T)$, and $\tilde{P}_T$ and $\tilde{P}_S$ are the probability distributions **renormalized** over the Top-K subset:

$$
\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
$$
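
The renormalization step can be sketched in a few lines of Python (illustrative only; the actual trainer operates on log-probabilities):

```python
def renormalize_topk(p_teacher, p_student, k):
    """Renormalize both distributions over the teacher's Top-K token indices."""
    # Top-K indices always come from the teacher distribution
    idx = sorted(range(len(p_teacher)), key=lambda i: p_teacher[i], reverse=True)[:k]
    z_t = sum(p_teacher[i] for i in idx)
    z_s = sum(p_student[i] for i in idx)
    return ({i: p_teacher[i] / z_t for i in idx},
            {i: p_student[i] / z_s for i in idx})

p_t = [0.5, 0.3, 0.15, 0.05]
p_s = [0.25, 0.25, 0.25, 0.25]
tilde_t, tilde_s = renormalize_topk(p_t, p_s, k=2)
# Both renormalized distributions sum to 1 over the Top-K subset {0, 1}
```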

**Usage Example**:

```bash
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-7B-Instruct \
--teacher_model Qwen/Qwen2.5-72B-Instruct \
--gkd_logits_topk 64 \
--dataset your_dataset \
...
```

> **Note**: Top-K mode cannot be combined with the liger kernel (`--use_liger_kernel`).

### External Teacher Model API

When `gkd_logits_topk` is set, you can use an external teacher-model API service to fetch logprobs, which avoids loading the teacher model in the training process.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--teacher_model_server` | str | None | Teacher model service URL<br>e.g., `http://localhost:8000` |
| `--gkd_logits_topk` | int | **Required** | Must be set when using the external API; corresponds to the number of top_logprobs returned by the API |


**Step 1: Deploy the Teacher Model Service**

```bash
# Deploy the teacher model with vllm serve
CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
--port 8000 \
--max-logprobs 64 \
--gpu-memory-utilization 0.9
```

**Step 2: Start GKD Training**

```bash
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-7B \
--teacher_model_server http://localhost:8000 \
--gkd_logits_topk 64 \
--dataset your_dataset \
--lmbda 1.0 \
--beta 1.0 \
...
```

> **vLLM max_logprobs limit**:
> - vLLM defaults to `max_logprobs=20`; adjust it with the `--max-logprobs N` flag
> - `gkd_logits_topk` must not exceed the server's `max_logprobs` setting
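
A minimal client-side guard for this constraint might look as follows (a hypothetical helper, not part of ms-swift):

```python
def check_topk_limit(gkd_logits_topk: int, server_max_logprobs: int) -> None:
    """Fail fast when the requested top-k exceeds the vLLM server's limit."""
    if gkd_logits_topk > server_max_logprobs:
        raise ValueError(
            f"gkd_logits_topk={gkd_logits_topk} exceeds the server's "
            f"max_logprobs={server_max_logprobs}; restart vllm serve with "
            f"--max-logprobs {gkd_logits_topk} or higher"
        )
```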

## Sampling Acceleration

GKD training involves two kinds of online sampling:

A reference training script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh).

A training script that uses the Teacher Server is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh).

### Solution 2: Teacher Model Pre-sampling

For teacher-model sampling (`seq_kd=True`), **pre-sampling** is recommended: first generate high-quality data offline with the teacher model, then train on it.
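
The pre-sampling workflow can be sketched as packaging offline teacher generations into a JSONL dataset; the `messages` schema below follows the common conversational format accepted by `--dataset`, while the helper itself is illustrative:

```python
import json

def build_seqkd_dataset(prompts, teacher_answers, path):
    """Write (prompt, teacher answer) pairs as a messages-style JSONL file."""
    with open(path, "w", encoding="utf-8") as f:
        for prompt, answer in zip(prompts, teacher_answers):
            record = {"messages": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```

The resulting file can then be passed to training via `--dataset <path>.jsonl`.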
6 changes: 5 additions & 1 deletion docs/source/Megatron-SWIFT/GKD.md

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--teacher_model` | str | - | Teacher model path or model ID<br>*Can be omitted when using `teacher_model_server` |
| `--teacher_model_server` | str | None | Teacher model service URL (`vllm serve` only), e.g. `http://localhost:8000` |
| `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using the external teacher API |
| `--beta` | float | 0.5 | JSD interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: symmetric JSD<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-policy learning trigger probability:<br>• 0.0: pure off-policy<br>• 1.0: pure on-policy |
| `--seq_kd` | bool | False | Whether to use teacher-generated responses (not yet supported) |
For more parameters, see the [command-line documentation](./Command-line-parameters.md).

For training scripts, see the [Megatron GKD scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd).

For a training script that uses the Teacher Server, see [this script](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh).
84 changes: 83 additions & 1 deletion docs/source_en/Instruction/GKD.md

We can perform GKD training by setting the following parameters:

### Basic Parameters

| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
| `--teacher_model` | str | None | - | Teacher model path or model ID<br>*Can be omitted when using `teacher_model_server` |
| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL <br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Offline<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
| `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
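
The effect of `beta` can be illustrated with a small pure-Python sketch of the interpolated divergence (the mixture weighting below follows one common implementation convention, e.g. TRL's GKD trainer; real implementations work on log-probabilities over the full vocabulary):

```python
import math

def kl(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as probability lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def interpolated_divergence(p_teacher, p_student, beta):
    """beta=0 -> forward KL(P_T || P_S); beta=1 -> reverse KL(P_S || P_T);
    otherwise a generalized JSD against the mixture M = beta*P_T + (1-beta)*P_S."""
    if beta == 0.0:
        return kl(p_teacher, p_student)
    if beta == 1.0:
        return kl(p_student, p_teacher)
    m = [beta * t + (1 - beta) * s for t, s in zip(p_teacher, p_student)]
    return beta * kl(p_teacher, m) + (1 - beta) * kl(p_student, m)
```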

### Top-K KL Computation

By default, GKD computes KL divergence over the full vocabulary. For models with large vocabularies, you can use **Top-K** mode to reduce memory usage and computation.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--gkd_logits_topk` | int | None | Number of Top-K logits<br>• None: Use full vocabulary (default)<br>• Positive integer: Only use the K tokens with highest teacher probability for KL computation |

**Top-K Mode Principle**:

In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It uses the teacher model's top-k indices to gather logits from both models, then renormalizes over the top-k subset before computing the JSD.

$$
D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
$$

Where the Top-K indices come from the teacher model: $\text{Top-K} = \text{argtop}_K(P_T)$, and $\tilde{P}_T$ and $\tilde{P}_S$ are the probability distributions **renormalized** over the Top-K subset:

$$
\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
$$
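
Combining the two formulas above, a minimal sketch of the Top-K JSD (pure Python, illustrative; the mixture $\tilde{M} = \beta\tilde{P}_T + (1-\beta)\tilde{P}_S$ is assumed, and production code works in log-space for numerical stability):

```python
import math

def topk_jsd(p_teacher, p_student, k, beta=0.5, eps=1e-12):
    """Generalized JSD restricted to the teacher's Top-K tokens (illustrative)."""
    # 1. Top-K indices come from the teacher distribution
    idx = sorted(range(len(p_teacher)), key=lambda i: p_teacher[i], reverse=True)[:k]
    # 2. Gather and renormalize both distributions over the Top-K subset
    z_t = sum(p_teacher[i] for i in idx)
    z_s = sum(p_student[i] for i in idx)
    t = [p_teacher[i] / z_t for i in idx]
    s = [p_student[i] / z_s for i in idx]
    # 3. Interpolated mixture and beta-weighted KL terms
    m = [beta * ti + (1 - beta) * si for ti, si in zip(t, s)]
    def kl(p, q):
        return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))
    return beta * kl(t, m) + (1 - beta) * kl(s, m)
```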

**Usage Example**:

```bash
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-7B-Instruct \
--teacher_model Qwen/Qwen2.5-14B-Instruct \
--gkd_logits_topk 64 \
--dataset your_dataset \
...
```

> **Note**: Top-K mode cannot be used with liger kernel (`--use_liger_kernel`).

### External Teacher Model API

When `gkd_logits_topk` is set, you can use an external teacher model API service to fetch logprobs, which avoids loading the teacher model in the training process.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--teacher_model_server` | str | None | Teacher model service URL<br>e.g., `http://localhost:8000` |
| `--gkd_logits_topk` | int | **Required** | Must be set when using external API; corresponds to the top_logprobs returned by the API |

**Supported Backends**:
- `vllm serve` (recommended)

> **Note**: Only `vllm serve` is supported as the teacher server backend. The training code sends raw token IDs via the `prompt` field and uses the `prompt_logprobs` parameter in the `/v1/completions` API to obtain input token log-probabilities. This is a vLLM-native feature.
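
The request described above can be sketched with Python's standard library (illustrative; `prompt_logprobs` is a vLLM-specific extra parameter and the response shape may vary across vLLM versions — the model name and token IDs below are placeholders):

```python
import json
import urllib.request

def build_prompt_logprobs_request(token_ids, top_k):
    """Payload for fetching per-token teacher logprobs over a tokenized prompt."""
    return {
        "model": "Qwen/Qwen2.5-14B-Instruct",  # must match the served model name
        "prompt": token_ids,                   # raw token IDs, not text
        "max_tokens": 1,                       # only prompt-side logprobs are needed
        "prompt_logprobs": top_k,              # vLLM extra param: top-k logprobs per prompt token
    }

def fetch_prompt_logprobs(server, token_ids, top_k):
    """POST to the vLLM /v1/completions endpoint and return the prompt logprobs."""
    req = urllib.request.Request(
        f"{server}/v1/completions",
        data=json.dumps(build_prompt_logprobs_request(token_ids, top_k)).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)["choices"][0]["prompt_logprobs"]
```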

**Step 1: Deploy Teacher Model Service**

```bash
# Deploy teacher model with vllm serve
CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
--port 8000 \
--max-logprobs 64 \
--gpu-memory-utilization 0.9
```

**Step 2: Start GKD Training**

```bash
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-7B \
--teacher_model_server http://localhost:8000 \
--gkd_logits_topk 64 \
--dataset your_dataset \
--lmbda 1.0 \
--beta 1.0 \
...
```

> **vLLM max_logprobs limitation**:
> - vLLM defaults to `max_logprobs=20`; adjust it with the `--max-logprobs N` flag
> - `gkd_logits_topk` cannot exceed the server's `max_logprobs` setting

## Sampling Acceleration

In GKD training, there are two types of online sampling scenarios:

A reference training script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh); for related parameters, please refer to [GRPO vLLM Parameters](./Command-line-parameters.md#vllm_mode).

A training script that uses the Teacher Server is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh).


### Solution 2: Teacher Model Pre-sampling

6 changes: 5 additions & 1 deletion docs/source_en/Megatron-SWIFT/GKD.md

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--teacher_model` | str | - | Path or model ID of the teacher model<br>*Can be omitted when using `teacher_model_server` |
| `--teacher_model_server` | str | None | Teacher model service URL (`vllm serve` only), e.g. `http://localhost:8000` |
| `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using external API |
| `--beta` | float | 0.5 | JSD divergence interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: Symmetric JSD<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-Policy learning probability:<br>• 0.0: Pure Off-Policy<br>• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) |
For more parameters, please refer to [Command-line Parameters](./Command-line-parameters.md)

For training scripts, please refer to [Megatron GKD Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd)

For a training script that uses the Teacher Server, see [this script](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh).
43 changes: 43 additions & 0 deletions examples/megatron/rlhf/gkd/teacher_server.sh
# Teacher server must be running first:
# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000 --max-logprobs 64

CUDA_VISIBLE_DEVICES=1,2 \
NPROC_PER_NODE=2 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
megatron rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-0.5B \
--teacher_model_server http://localhost:8000 \
--gkd_logits_topk 64 \
--dataset 'modelscope/gsm8k' \
--tensor_model_parallel_size 1 \
--pipeline_model_parallel_size 1 \
--context_parallel_size 1 \
--expert_model_parallel_size 1 \
--lmbda 1 \
--seq_kd false \
--beta 0.5 \
--torch_dtype bfloat16 \
--micro_batch_size 2 \
--global_batch_size 32 \
--train_iters 500 \
--lr 5e-5 \
--lr_warmup_fraction 0.1 \
--logging_steps 1 \
--save_steps 100 \
--save_total_limit 10 \
--max_length 2048 \
--max_completion_length 2048 \
--attention_backend flash \
--use_vllm true \
--vllm_mode colocate \
--vllm_gpu_memory_utilization 0.5 \
--vllm_tensor_parallel_size 1 \
--vllm_max_model_len 4096 \
--sleep_level 1 \
--finetune \
--no_save_optim \
--no_save_rng \
--temperature 1.0 \
--padding_free true \
--recompute_granularity selective
44 changes: 44 additions & 0 deletions examples/train/rlhf/gkd/teacher_server.sh
# GKD Training with External Teacher Model Server (vLLM)
# ===================== Step 1: Start Teacher Server =====================
# Run in a separate terminal / GPU:
#
# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
# --port 8000 \
# --max-logprobs 64 \
# --gpu-memory-utilization 0.9

# ========================================================================

NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen2.5-0.5B \
--teacher_model_server http://localhost:8000 \
--gkd_logits_topk 64 \
--use_vllm true \
--vllm_mode colocate \
--vllm_gpu_memory_utilization 0.5 \
--vllm_tensor_parallel_size 1 \
--vllm_max_model_len 4096 \
--sleep_level 0 \
--dataset 'modelscope/gsm8k' \
--lmbda 1 \
--seq_kd false \
--beta 0.5 \
--torch_dtype bfloat16 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--learning_rate 5e-5 \
--logging_steps 1 \
--save_steps 100 \
--save_total_limit 2 \
--max_length 2048 \
--max_completion_length 2048 \
--warmup_ratio 0.1 \
--save_only_model true \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--attn_impl flash_attn \
--report_to tensorboard swanlab