diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md
index 25e4545802..f8d455a93a 100644
--- a/docs/source/Instruction/GKD.md
+++ b/docs/source/Instruction/GKD.md
@@ -139,9 +139,11 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
我们可以通过设置以下参数进行 GKD 训练:
+### 基础参数
+
| 参数 | 类型 | 默认值 | 取值范围 | 说明 |
|------|------|--------|---------|------|
-| `--teacher_model` | str | 必需 | - | 教师模型路径或模型 ID |
+| `--teacher_model` | str | None | - | 教师模型路径或模型 ID
*使用 `teacher_model_server` 时可省略 |
| `--beta` | float | 0.5 | [0.0, 1.0] | 散度插值系数
• 0.0: Forward KL
• 0.5: JSD (平衡)
• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率
• 0.0: 离线学习
• 0.5: 混合策略
• 1.0: 纯 On-Policy |
| `--seq_kd` | bool | False | True/False | 是否使用教师生成序列
• False: 非 on-policy 时使用数据集
• True: 非 on-policy 时使用教师生成 |
@@ -149,6 +151,79 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
| `--sft_alpha` | float | 0 | >= 0 | 混合一定比例的sft loss,对非student生成结果生效 |
| `--max_completion_length` | int | 512 | > 0 | 生成时的最大 token 数 |
+### Top-K KL 计算
+
+默认情况下,GKD 使用完整词表计算 KL 散度,词表较大时容易造成 OOM;此时可以使用 **Top-K** 模式减少显存占用和计算量。
+
+| 参数 | 类型 | 默认值 | 说明 |
+|------|------|--------|------|
+| `--gkd_logits_topk` | int | None | Top-K logits 数量
• None: 使用完整词表(默认)
• 正整数: 仅使用教师模型概率最高的 K 个 token 计算 KL |
+
+**Top-K 模式原理**:
+
+在 Top-K 模式下,选取**教师模型**输出概率最高的 K 个 token,在这个子集上计算两个模型分布的 KL 散度。
+$$
+D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
+$$
+
+其中 Top-K 索引来自教师模型:$\text{Top-K} = \text{argtop}_K(P_T)$,$\tilde{P}_T$ 和 $\tilde{P}_S$ 是在 Top-K 子集上**重新归一化**的概率分布:
+
+$$
+\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
+$$
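
上述 Top-K 重新归一化与 JSD 的计算过程可以用下面的纯 Python 草图验证(仅作示意,使用假设的玩具 logits,并非训练器的实际实现):

```python
import math

def topk_jsd(teacher_logits, student_logits, k, beta=0.5):
    # 选取教师 logits 最大的 k 个 token 索引
    topk_idx = sorted(range(len(teacher_logits)),
                      key=lambda i: teacher_logits[i], reverse=True)[:k]

    def renorm_log_probs(logits):
        # 在 top-k 子集上做 log_softmax,等价于在 Top-K 上重新归一化
        sub = [logits[i] for i in topk_idx]
        lse = math.log(sum(math.exp(x) for x in sub))
        return [x - lse for x in sub]

    t = renorm_log_probs(teacher_logits)
    s = renorm_log_probs(student_logits)
    t_p = [math.exp(x) for x in t]
    s_p = [math.exp(x) for x in s]
    # 混合分布 M = beta * P_T + (1 - beta) * P_S,再计算两个 KL 项
    m = [math.log(beta * tp + (1 - beta) * sp) for tp, sp in zip(t_p, s_p)]
    kl_t = sum(tp * (tl - ml) for tp, tl, ml in zip(t_p, t, m))
    kl_s = sum(sp * (sl - ml) for sp, sl, ml in zip(s_p, s, m))
    return beta * kl_t + (1 - beta) * kl_s

# 两个分布相同时散度为 0;不同则大于 0
print(topk_jsd([2.0, 1.0, 0.5, -1.0], [1.5, 1.2, 0.1, 0.0], k=3))
```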
+
+**使用示例**:
+
+```bash
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-7B-Instruct \
+ --teacher_model Qwen/Qwen2.5-72B-Instruct \
+ --gkd_logits_topk 64 \
+ --dataset your_dataset \
+ ...
+```
+
+> **注意**:Top-K 模式不能与 liger kernel 同时使用(`--use_liger_kernel`)。
+
+### 外部教师模型 API
+
+可以使用外部教师模型 API 服务来获取 logprobs(需同时设置 `gkd_logits_topk`),从而避免在训练进程中加载教师模型。当前仅支持 `vllm serve` 作为教师服务后端:训练端通过 `/v1/completions` 接口的 `prompt_logprobs` 参数获取输入 token 的对数概率。
+
+| 参数 | 类型 | 默认值 | 说明 |
+|------|------|--------|------|
+| `--teacher_model_server` | str | None | 教师模型服务地址
如:`http://localhost:8000` |
+| `--gkd_logits_topk` | int | **必需** | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 |
+
+
+**步骤 1:部署教师模型服务**
+
+```bash
+# 使用 vllm serve 部署教师模型
+CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
+ --port 8000 \
+ --max-logprobs 64 \
+ --gpu-memory-utilization 0.9
+```
+
+**步骤 2:启动 GKD 训练**
+
+```bash
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-7B \
+ --teacher_model_server http://localhost:8000 \
+ --gkd_logits_topk 64 \
+ --dataset your_dataset \
+ --lmbda 1.0 \
+ --beta 1.0 \
+ ...
+```
+
+> **vLLM max_logprobs 限制**:
+> - vLLM 默认 `max_logprobs=20`,可通过 `--max-logprobs N` 参数调整
+> - `gkd_logits_topk` 不能超过服务端的 `max_logprobs` 设置
+
## 采样加速
在 GKD 训练中,涉及到两种在线采样的情况:
@@ -168,6 +243,8 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh)
+使用 Teacher Server 的训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh)
+
### 方案 2:教师模型预采样
对于教师模型采样(`seq_kd=True`),推荐使用 **预采样** 方式:先用教师模型离线生成高质量数据,再进行训练。
diff --git a/docs/source/Megatron-SWIFT/GKD.md b/docs/source/Megatron-SWIFT/GKD.md
index 9cc30004d7..9a023cb456 100644
--- a/docs/source/Megatron-SWIFT/GKD.md
+++ b/docs/source/Megatron-SWIFT/GKD.md
@@ -33,7 +33,9 @@ Megatron GKD 当前已支持以下功能:
| 参数 | 类型 | 默认值 | 说明 |
|------|------|--------|------|
-| `--teacher_model` | str | 必需 | 教师模型路径或模型 ID |
+| `--teacher_model` | str | None | 教师模型路径或模型 ID
*使用 `teacher_model_server` 时可省略 |
+| `--teacher_model_server` | str | None | 教师模型服务地址(仅支持 `vllm serve`),如 `http://localhost:8000` |
+| `--gkd_logits_topk` | int | None | Top-K logits 数量,使用外部教师 API 时必须设置 |
| `--beta` | float | 0.5 | JSD 散度插值系数:
• 0.0: Forward KL
• 0.5: 对称 JSD
• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-Policy 学习触发概率:
• 0.0: 纯 Off-Policy
• 1.0: 纯 On-Policy |
| `--seq_kd` | bool | False | 是否使用教师生成的响应(当前暂不支持) |
@@ -71,3 +73,5 @@ GKD 支持三种训练模式,通过 `lmbda` 和 `seq_kd` 参数控制:
更多参数请参考[命令行文档](./Command-line-parameters.md)
训练脚本请参考 [Megatron GKD 脚本](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd)
+
+使用 Teacher Server 的训练脚本请参考 [这里](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh)
diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md
index 9bff8a81eb..cad177cae8 100644
--- a/docs/source_en/Instruction/GKD.md
+++ b/docs/source_en/Instruction/GKD.md
@@ -139,9 +139,11 @@ Set parameter `seq_kd=True`, when on-policy is not triggered, use teacher model
We can perform GKD training by setting the following parameters:
+### Basic Parameters
+
| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
-| `--teacher_model` | str | Required | - | Teacher model path or model ID |
+| `--teacher_model` | str | None | - | Teacher model path or model ID
*Can be omitted when using `teacher_model_server` |
| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient
• 0.0: Forward KL
• 0.5: JSD (balanced)
• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability
• 0.0: Pure Offline
• 0.5: Mixed strategy (**recommended**)
• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences
• False: Use dataset when not on-policy
• True: Use teacher generation when not on-policy |
@@ -149,6 +151,84 @@ We can perform GKD training by setting the following parameters:
| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
+### Top-K KL Computation
+
+By default, GKD computes the KL divergence over the full vocabulary, which can cause OOM for models with large vocabularies. In that case, use **Top-K** mode to reduce memory usage and computation.
+
+| Parameter | Type | Default | Description |
+|------|------|--------|------|
+| `--gkd_logits_topk` | int | None | Number of Top-K logits
• None: Use full vocabulary (default)
• Positive integer: Only use the K tokens with highest teacher probability for KL computation |
+
+**Top-K Mode Principle**:
+
+In Top-K mode, the top-K token indices are selected from the **teacher model**'s output distribution. Logits from both models are gathered at these indices and renormalized over the Top-K subset before computing the JSD.
+
+$$
+D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
+$$
+
+Where the Top-K indices come from the teacher model: $\text{Top-K} = \text{argtop}_K(P_T)$, and $\tilde{P}_T$ and $\tilde{P}_S$ are the probability distributions **renormalized** over the Top-K subset:
+
+$$
+\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
+$$
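
The renormalization and JSD above can be sketched in a few lines of plain Python (an illustrative toy example with made-up logits, not the trainer's actual implementation):

```python
import math

def topk_jsd(teacher_logits, student_logits, k, beta=0.5):
    # Select the k token indices with the highest teacher logits
    topk_idx = sorted(range(len(teacher_logits)),
                      key=lambda i: teacher_logits[i], reverse=True)[:k]

    def renorm_log_probs(logits):
        # log_softmax over the top-k subset == renormalizing over Top-K
        sub = [logits[i] for i in topk_idx]
        lse = math.log(sum(math.exp(x) for x in sub))
        return [x - lse for x in sub]

    t = renorm_log_probs(teacher_logits)
    s = renorm_log_probs(student_logits)
    t_p = [math.exp(x) for x in t]
    s_p = [math.exp(x) for x in s]
    # Mixture M = beta * P_T + (1 - beta) * P_S, then the two KL terms
    m = [math.log(beta * tp + (1 - beta) * sp) for tp, sp in zip(t_p, s_p)]
    kl_t = sum(tp * (tl - ml) for tp, tl, ml in zip(t_p, t, m))
    kl_s = sum(sp * (sl - ml) for sp, sl, ml in zip(s_p, s, m))
    return beta * kl_t + (1 - beta) * kl_s

# Identical distributions give zero divergence; differing ones give > 0
print(topk_jsd([2.0, 1.0, 0.5, -1.0], [1.5, 1.2, 0.1, 0.0], k=3))
```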
+
+**Usage Example**:
+
+```bash
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-7B-Instruct \
+ --teacher_model Qwen/Qwen2.5-14B-Instruct \
+ --gkd_logits_topk 64 \
+ --dataset your_dataset \
+ ...
+```
+
+> **Note**: Top-K mode cannot be used with liger kernel (`--use_liger_kernel`).
+
+### External Teacher Model API
+
+You can use an external teacher model API service to fetch logprobs (this requires setting `gkd_logits_topk`), which avoids loading the teacher model in the training process.
+
+| Parameter | Type | Default | Description |
+|------|------|--------|------|
+| `--teacher_model_server` | str | None | Teacher model service URL
e.g., `http://localhost:8000` |
+| `--gkd_logits_topk` | int | **Required** | Must be set when using the external API; corresponds to the number of top_logprobs returned by the API |
+
+**Supported Backends**:
+- `vllm serve` (recommended)
+
+> **Note**: Only `vllm serve` is supported as the teacher server backend. The training code sends raw token IDs via the `prompt` field and uses the `prompt_logprobs` parameter in the `/v1/completions` API to obtain input token log-probabilities. This is a vLLM-native feature.
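
For illustration, a request to the teacher server might be assembled as below. This is a hedged sketch: the model name and `max_tokens` value are assumptions, and the real client logic lives in `fetch_teacher_logprobs` in `swift/rlhf_trainers/gkd_trainer.py`:

```python
import json

def build_teacher_request(input_ids, topk, model="Qwen/Qwen2.5-14B-Instruct"):
    # `prompt` carries raw token IDs (not text); `prompt_logprobs` is the
    # vLLM extension returning top-k logprobs for each prompt position.
    return {
        "model": model,               # must match the model passed to `vllm serve`
        "prompt": input_ids,
        "prompt_logprobs": topk,      # must not exceed the server's max_logprobs
        "max_tokens": 1,              # minimal generation; only prompt scoring is consumed
        "temperature": 0.0,
    }

payload = build_teacher_request([151644, 872, 198], topk=64)
body = json.dumps(payload)  # POST this to {server}/v1/completions
```

Each element of the response's `prompt_logprobs` maps token IDs to logprobs for one input position, which the trainer reshapes into `[batch, seq_len-1, topk]` tensors.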
+
+**Step 1: Deploy Teacher Model Service**
+
+```bash
+# Deploy teacher model with vllm serve
+CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
+ --port 8000 \
+ --max-logprobs 64 \
+ --gpu-memory-utilization 0.9
+```
+
+**Step 2: Start GKD Training**
+
+```bash
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-7B \
+ --teacher_model_server http://localhost:8000 \
+ --gkd_logits_topk 64 \
+ --dataset your_dataset \
+ --lmbda 1.0 \
+ --beta 1.0 \
+ ...
+```
+
+> **vLLM max_logprobs Limitation**:
+> - vLLM defaults to `max_logprobs=20`; this can be raised via the `--max-logprobs N` flag
+> - `gkd_logits_topk` cannot exceed the server's `max_logprobs` setting
+
## Sampling Acceleration
In GKD training, there are two types of online sampling scenarios:
@@ -168,6 +248,8 @@ Use vLLM as the inference backend to accelerate student model sampling. Supports
Training script reference [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh), for related parameters, please refer to [GRPO vLLM Parameters](./Command-line-parameters.md#vllm_mode).
+A training script using the Teacher Server is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh).
+
### Solution 2: Teacher Model Pre-sampling
diff --git a/docs/source_en/Megatron-SWIFT/GKD.md b/docs/source_en/Megatron-SWIFT/GKD.md
index 37c08eea5a..0502b8485b 100644
--- a/docs/source_en/Megatron-SWIFT/GKD.md
+++ b/docs/source_en/Megatron-SWIFT/GKD.md
@@ -33,7 +33,9 @@ Megatron GKD currently supports the following features:
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
-| `--teacher_model` | str | Required | Path or model ID of the teacher model |
+| `--teacher_model` | str | None | Path or model ID of the teacher model
*Can be omitted when using `teacher_model_server` |
+| `--teacher_model_server` | str | None | Teacher model service URL (`vllm serve` only), e.g. `http://localhost:8000` |
+| `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using external API |
| `--beta` | float | 0.5 | JSD divergence interpolation coefficient:
• 0.0: Forward KL
• 0.5: Symmetric JSD
• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-Policy learning probability:
• 0.0: Pure Off-Policy
• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) |
@@ -71,3 +73,5 @@ GKD supports three training modes, controlled by `lmbda` and `seq_kd` parameters
For more parameters, please refer to [Command-line Parameters](./Command-line-parameters.md)
For training scripts, please refer to [Megatron GKD Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd)
+
+A training script using the Teacher Server is available [here](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh)
diff --git a/examples/megatron/rlhf/gkd/teacher_server.sh b/examples/megatron/rlhf/gkd/teacher_server.sh
new file mode 100644
index 0000000000..2ccb05b8a1
--- /dev/null
+++ b/examples/megatron/rlhf/gkd/teacher_server.sh
@@ -0,0 +1,43 @@
+# Teacher server must be running first:
+# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000 --max-logprobs 64
+
+CUDA_VISIBLE_DEVICES=1,2 \
+NPROC_PER_NODE=2 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+megatron rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-0.5B \
+ --teacher_model_server http://localhost:8000 \
+ --gkd_logits_topk 64 \
+ --dataset 'modelscope/gsm8k' \
+ --tensor_model_parallel_size 1 \
+ --pipeline_model_parallel_size 1 \
+ --context_parallel_size 1 \
+ --expert_model_parallel_size 1 \
+ --lmbda 1 \
+ --seq_kd false \
+ --beta 0.5 \
+ --torch_dtype bfloat16 \
+ --micro_batch_size 2 \
+ --global_batch_size 32 \
+ --train_iters 500 \
+ --lr 5e-5 \
+ --lr_warmup_fraction 0.1 \
+ --logging_steps 1 \
+ --save_steps 100 \
+ --save_total_limit 10 \
+ --max_length 2048 \
+ --max_completion_length 2048 \
+ --attention_backend flash \
+ --use_vllm true \
+ --vllm_mode colocate \
+ --vllm_gpu_memory_utilization 0.5 \
+ --vllm_tensor_parallel_size 1 \
+ --vllm_max_model_len 4096 \
+ --sleep_level 1 \
+ --finetune \
+ --no_save_optim \
+ --no_save_rng \
+ --temperature 1.0 \
+ --padding_free true \
+ --recompute_granularity selective
diff --git a/examples/train/rlhf/gkd/teacher_server.sh b/examples/train/rlhf/gkd/teacher_server.sh
new file mode 100644
index 0000000000..cbb56680f9
--- /dev/null
+++ b/examples/train/rlhf/gkd/teacher_server.sh
@@ -0,0 +1,44 @@
+# GKD Training with External Teacher Model Server (vLLM)
+# ===================== Step 1: Start Teacher Server =====================
+# Run in a separate terminal / GPU:
+#
+# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
+# --port 8000 \
+# --max-logprobs 64 \
+# --gpu-memory-utilization 0.9
+
+# ========================================================================
+
+NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+swift rlhf \
+ --rlhf_type gkd \
+ --model Qwen/Qwen2.5-0.5B \
+ --teacher_model_server http://localhost:8000 \
+ --gkd_logits_topk 64 \
+ --use_vllm true \
+ --vllm_mode colocate \
+ --vllm_gpu_memory_utilization 0.5 \
+ --vllm_tensor_parallel_size 1 \
+ --vllm_max_model_len 4096 \
+ --sleep_level 0 \
+ --dataset 'modelscope/gsm8k' \
+ --lmbda 1 \
+ --seq_kd false \
+ --beta 0.5 \
+ --torch_dtype bfloat16 \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --learning_rate 5e-5 \
+ --logging_steps 1 \
+ --save_steps 100 \
+ --save_total_limit 2 \
+ --max_length 2048 \
+ --max_completion_length 2048 \
+ --warmup_ratio 0.1 \
+ --save_only_model true \
+ --dataloader_num_workers 4 \
+ --dataset_num_proc 4 \
+ --attn_impl flash_attn \
+ --report_to tensorboard swanlab
diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py
index a3a63b7090..50017b92ad 100644
--- a/swift/arguments/rlhf_args.py
+++ b/swift/arguments/rlhf_args.py
@@ -39,7 +39,8 @@ class TeacherModelArguments:
Args:
teacher_model (Optional[str]): The model ID or a local path to the teacher model. This is required when
- `rlhf_type` is 'gkd'. Analogous to the main `model` argument. Defaults to None.
+ `rlhf_type` is 'gkd' and `teacher_model_server` is not set. Analogous to the main `model` argument.
+ Defaults to None.
teacher_adapters (List[str]): A list of paths to LoRA weights. These weights, often produced by SFT, are loaded
to form the teacher model. Defaults to an empty list (`[]`).
teacher_model_type (Optional[str]): The model type of the teacher model. If not specified, it's often inferred.
@@ -50,6 +51,10 @@ class TeacherModelArguments:
one of the following values: 'zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'. If not
provided, it defaults to using the same DeepSpeed configuration as the main training model. Analogous to
the main `deepspeed` argument.
+ teacher_model_server (Optional[str]): The URL of the teacher model server (e.g., 'http://localhost:8000').
+            When set, the teacher logprobs will be fetched from the external API service (currently `vllm serve` only)
+ instead of loading a local teacher model. This enables using larger teacher models or services hosted
+ remotely. When this is set, `teacher_model` is not required. Defaults to None.
"""
teacher_model: Optional[str] = None
teacher_adapters: List[str] = field(default_factory=list)
@@ -63,6 +68,13 @@ class TeacherModelArguments:
'DeepSpeed configuration for teacher model. '
'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload'
})
+ teacher_model_server: Optional[str] = field(
+ default=None,
+ metadata={
+ 'help':
+ 'URL of the teacher model server (e.g., http://localhost:8000). '
+ 'When set, teacher logprobs are fetched via API instead of loading a local model.'
+ })
@dataclass
@@ -196,6 +208,11 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo
gkd_loss + sft_alpha * sft_loss`. Defaults to 0.
lmbda (float): The lambda parameter for GKD, balancing policy and value losses. Defaults to 0.5.
seq_kd (bool): Whether to use sequence-level knowledge distillation for GKD. Defaults to False.
+ gkd_logits_topk (Optional[int]): The number of top-k logits to use for KL divergence computation in GKD.
+ If None, uses full vocabulary for KL computation (more accurate but memory-intensive).
+ If set to a positive integer, only top-k teacher logits are used (more efficient).
+ When using `teacher_model_server`, this is limited by the server's `max_logprobs` setting
+ (vLLM default is 20, can be increased with `--max-logprobs`). Defaults to None.
offload_teacher_model (bool): Whether to offload the teacher model to CPU memory to save VRAM during GKD
training. Defaults to False.
max_new_tokens (Optional[int]): A backward-compatibility argument. Please use `max_completion_length` instead.
@@ -233,6 +250,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo
sft_alpha: float = 0
lmbda: float = 0.5
seq_kd: bool = False
+ gkd_logits_topk: Optional[int] = None
offload_teacher_model: bool = False
# compat
max_new_tokens: Optional[int] = None # use max_completion_length instead
@@ -545,3 +563,25 @@ def _check_gkd(self):
if self.async_generate:
raise NotImplementedError('Currently, async_generate is not supported for GKD.')
+
+ # Validate teacher model configuration
+ if self.teacher_model is None and self.teacher_model_server is None:
+ raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set.')
+
+ if self.teacher_model is not None and self.teacher_model_server is not None:
+ raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set, not both.')
+
+ # When using teacher_model_server, gkd_logits_topk is required (API only returns top-k logprobs)
+ if self.teacher_model_server is not None:
+ if self.gkd_logits_topk is None:
+ raise ValueError('gkd_logits_topk is required when using teacher_model_server')
+
+ # Validate gkd_logits_topk
+ if self.gkd_logits_topk is not None and self.gkd_logits_topk <= 0:
+ raise ValueError(f'gkd_logits_topk must be a positive integer, got {self.gkd_logits_topk}')
+
+ if self.gkd_logits_topk is not None and self.use_liger_kernel:
+ raise ValueError('gkd_logits_topk is not supported when using liger kernel')
+
+ if self.teacher_model_server and self.seq_kd:
+ raise NotImplementedError('Sequential KD is not supported when using teacher_model_server')
diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py
index 082ba37ee4..033352bbdd 100644
--- a/swift/megatron/arguments/megatron_args.py
+++ b/swift/megatron/arguments/megatron_args.py
@@ -45,6 +45,14 @@ class RLHFMegatronArgumentsMixin:
teacher_model: Optional[str] = field(default=None)
teacher_model_type: Optional[str] = field(default=None)
teacher_model_revision: Optional[str] = field(default=None)
+ teacher_model_server: Optional[str] = field(
+ default=None,
+ metadata={
+ 'help':
+ 'URL of the teacher model server (e.g., http://localhost:8000). '
+ 'When set, teacher logprobs are fetched via API instead of loading a local model.'
+ })
+ gkd_logits_topk: Optional[int] = None
lmbda: float = 0.5 # On-policy probability: with prob lmbda, use student-generated responses
seq_kd: bool = False # Sequential KD: use teacher-generated responses when not on-policy
offload_teacher_model: bool = False # Offload teacher model to CPU to save GPU memory
@@ -194,6 +202,21 @@ def __post_init__(self):
if self.vllm_limit_mm_per_prompt is not None:
self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt)
self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)
+ if self.rlhf_type == 'gkd':
+ if self.teacher_model is None and self.teacher_model_server is None:
+ raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set.')
+
+ if self.teacher_model is not None and self.teacher_model_server is not None:
+ raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set, not both.')
+
+ # When using teacher_model_server, gkd_logits_topk is required (API only returns top-k logprobs)
+ if self.teacher_model_server is not None:
+ if self.gkd_logits_topk is None:
+ raise ValueError('gkd_logits_topk is required when using teacher_model_server')
+
+ # Validate gkd_logits_topk
+ if self.gkd_logits_topk is not None and self.gkd_logits_topk <= 0:
+ raise ValueError(f'gkd_logits_topk must be a positive integer, got {self.gkd_logits_topk}')
def _init_grpo(self):
diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py
index add6aea39d..09caf85106 100644
--- a/swift/megatron/trainers/gkd_trainer.py
+++ b/swift/megatron/trainers/gkd_trainer.py
@@ -41,10 +41,25 @@ def __init__(self, args: MegatronArguments, template, **kwargs):
self.lmbda = args.lmbda # On-policy probability
self.seq_kd = args.seq_kd # Sequential KD: use teacher-generated responses
self.offload_teacher_model = args.offload_teacher_model # Offload teacher to CPU
- self.teacher_bridge = args.megatron_model_meta.bridge_cls(args, attr_prefix='teacher_')
- self.teacher_config = self.teacher_bridge.processor.model_info.config
+ self.teacher_model_server = getattr(args, 'teacher_model_server', None)
+ self.use_teacher_api = self.teacher_model_server is not None
+ if args.teacher_model:
+ self.teacher_bridge = args.megatron_model_meta.bridge_cls(args, attr_prefix='teacher_')
+ self.teacher_config = self.teacher_bridge.processor.model_info.config
self.sft_alpha = getattr(args, 'sft_alpha', 0.0) # Weight for SFT loss
- assert args.teacher_model is not None, 'Teacher model path is required for GKD training'
+
+ # GKD top-k logits configuration
+ self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None)
+ # Check use_teacher_api based on args, not client existence
+ # (API client is only created on last rank, but all ranks need to know the mode)
+
+ # Validate teacher configuration
+ if not self.use_teacher_api:
+ assert args.teacher_model is not None, \
+ 'Teacher model path is required for GKD training (or set teacher_model_server for API mode)'
+ else:
+ logger.info(f'Using teacher model API for logprobs, top_logprobs={self.gkd_logits_topk}')
+
self.use_vllm = getattr(args, 'use_vllm', False)
super().__init__(args, template)
@@ -67,6 +82,9 @@ def train(self, train_dataset, val_dataset):
def prepare_model(self):
super().prepare_model()
+ if self.use_teacher_api:
+ logger.info('Skipping local teacher model loading - using external API for teacher logprobs')
+ return
args = self.args
vp_size = getattr(args, 'virtual_pipeline_model_parallel_size')
assert vp_size is None or vp_size == 1, 'GKD currently does not support VPP.'
@@ -83,12 +101,12 @@ def prepare_model(self):
def _offload_teacher_models(self):
"""Offload teacher models to CPU to save GPU memory."""
- if self.teacher_models:
+ if self.teacher_models and not self.use_teacher_api:
offload_megatron_model_to_cpu(self.teacher_models)
def _load_teacher_models_to_gpu(self):
"""Load teacher models back to GPU."""
- if self.teacher_models:
+ if self.teacher_models and not self.use_teacher_api:
load_megatron_model_to_gpu(self.teacher_models, load_grad=False)
@contextmanager
@@ -247,22 +265,49 @@ def resample_encode_failed_inputs(self, inputs: List[Dict], max_resample_rounds:
return valid_samples[:required_count]
def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None:
+ if self.use_teacher_api:
+ self._compute_teacher_logits_from_api(encoded_batches)
+ else:
+ self._compute_teacher_logits_local(encoded_batches, vp_stage)
+
+ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None:
teacher_model = self.teacher_models[vp_stage or 0]
+ topk = self.gkd_logits_topk
for encoded_batch in encoded_batches:
- # Deep copy to avoid modifying original batch
teacher_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in encoded_batch.items()}
teacher_batch.pop('data_source', None)
teacher_data = self._prepare_batch(teacher_batch)
teacher_data.pop('loss_scale', None)
- # Remove labels so returns logits instead of loss
teacher_data.pop('labels', None)
- # Teacher forward with args override for correct hidden_size
with self.load_teacher_model_context(), torch.no_grad():
teacher_logits = forward_step_helper(self.args, teacher_model, teacher_data)
if teacher_logits is not None:
teacher_logits = teacher_logits.detach()
- encoded_batch['teacher_logits'] = teacher_logits
+
+ if topk is not None and teacher_logits is not None:
+ topk_logits, topk_indices = torch.topk(teacher_logits, k=topk, dim=-1)
+ encoded_batch['teacher_api_logprobs'] = topk_logits
+ encoded_batch['teacher_api_indices'] = topk_indices
+ encoded_batch['teacher_logits'] = None
+ else:
+ encoded_batch['teacher_logits'] = teacher_logits
+
+ def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
+ """Fetch teacher logprobs from external API service."""
+ from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs
+ topk = self.gkd_logits_topk
+ for encoded_batch in encoded_batches:
+ input_ids = encoded_batch['input_ids']
+ teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
+ self.teacher_model_server, input_ids.tolist(), topk=topk)
+ # fetch_teacher_logprobs returns [batch, seq_len-1, topk] (shifted).
+ # Pad last position with -inf to match student [batch, seq_len, topk].
+ teacher_logprobs = F.pad(teacher_logprobs, (0, 0, 0, 1), value=float('-inf'))
+ teacher_indices = F.pad(teacher_indices, (0, 0, 0, 1), value=0)
+ encoded_batch['teacher_api_logprobs'] = teacher_logprobs.to(input_ids.device)
+ encoded_batch['teacher_api_indices'] = teacher_indices.to(input_ids.device)
+ encoded_batch['teacher_logits'] = None
def _replace_data_iterator(self, data_iterator):
num_microbatches = self.args.num_microbatches
@@ -350,6 +395,8 @@ def generalized_jsd_loss(
labels: torch.Tensor,
beta: float = 0.5,
chunk_size: int = 512,
+ teacher_topk_logprobs: torch.Tensor = None,
+ teacher_topk_indices: torch.Tensor = None,
) -> torch.Tensor:
args = self.args
mask = labels != -100
@@ -364,12 +411,20 @@ def generalized_jsd_loss(
if num_valid == 0:
return (student_logits.sum() * 0).reshape(())
+ # Top-k mode: direct computation without vocab parallel
+ if teacher_topk_logprobs is not None and teacher_topk_indices is not None:
+ total_loss = self._jsd_topk(student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta)
+ if args.context_parallel_size > 1:
+ torch.distributed.all_reduce(
+ total_loss, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())
+ return total_loss / num_valid
+
+ # Full vocabulary mode (original code)
# Align vocab size between student and teacher
student_logits, teacher_logits = self._align_vocab_size(student_logits, teacher_logits)
# Apply temperature scaling and mask
- student_logits_masked = (student_logits
- / self.temperature)[mask] # [local_num_valid_tokens, partition_vocab_size]
+ student_logits_masked = (student_logits / self.temperature)[mask]
teacher_logits_masked = (teacher_logits / self.temperature)[mask]
del student_logits, teacher_logits
@@ -389,36 +444,28 @@ def generalized_jsd_loss(
s_chunk = student_logits_masked[start_idx:end_idx]
t_chunk = teacher_logits_masked[start_idx:end_idx]
- # Compute log_softmax with vocab-parallel support
s_log_probs = vocab_parallel_log_softmax(s_chunk)
t_log_probs = vocab_parallel_log_softmax(t_chunk)
del s_chunk, t_chunk
if beta == 0:
- # JSD = KL(teacher || student)
jsd_chunk = vocab_parallel_kl_div(s_log_probs, t_log_probs)
elif beta == 1:
- # JSD = KL(student || teacher)
jsd_chunk = vocab_parallel_kl_div(t_log_probs, s_log_probs)
else:
- # Compute mixture log probabilities: m = beta * teacher + (1-beta) * student
- # log(m) = logsumexp(log(student) + log(1-beta), log(teacher) + log(beta))
mixture_log_probs = torch.logsumexp(
torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]),
dim=0,
)
-
kl_teacher = vocab_parallel_kl_div(mixture_log_probs, t_log_probs)
kl_student = vocab_parallel_kl_div(mixture_log_probs, s_log_probs)
del mixture_log_probs
-
jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student
del kl_teacher, kl_student
total_loss = total_loss + jsd_chunk.sum()
del jsd_chunk, s_log_probs, t_log_probs
- # Clean up masked logits
del student_logits_masked, teacher_logits_masked
# All-reduce total_loss across CP group for correct sum
@@ -428,18 +475,57 @@ def generalized_jsd_loss(
return total_loss / num_valid
+ def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta):
+ """Compute JSD on teacher's top-k distribution.
+
+ Both local and API teacher are handled uniformly: gather student logits at
+ teacher's top-k indices, scale by 1/T, and log_softmax over top-k subset.
+ By shift-invariance of log_softmax, this gives identical results whether
+ teacher_topk_logprobs contains raw logits (local) or raw logprobs (API).
+
+ """
+ s_scaled = student_logits / self.temperature
+ s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
+ t_topk = teacher_topk_logprobs / self.temperature
+
+ s_topk_masked = s_topk[mask]
+ t_topk_masked = t_topk[mask]
+
+ if s_topk_masked.numel() == 0:
+ return student_logits.new_zeros(())
+
+ t_log_p = F.log_softmax(t_topk_masked, dim=-1)
+ s_log_p = F.log_softmax(s_topk_masked, dim=-1)
+ t_p = torch.exp(t_log_p)
+
+ if beta == 0:
+ jsd = (t_p * (t_log_p - s_log_p)).sum(dim=-1)
+ elif beta == 1:
+ s_p = torch.exp(s_log_p)
+ jsd = (s_p * (s_log_p - t_log_p)).sum(dim=-1)
+ else:
+ s_p = torch.exp(s_log_p)
+ m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10)
+ jsd = beta * (t_p * (t_log_p - m_log_p)).sum(-1) + (1 - beta) * (s_p * (s_log_p - m_log_p)).sum(-1)
+
+ return jsd.sum()
+
def loss_func(self,
output_tensor: torch.Tensor,
*,
labels: torch.Tensor,
- teacher_logits: torch.Tensor,
+ teacher_logits: torch.Tensor = None,
+ teacher_api_logprobs: torch.Tensor = None,
+ teacher_api_indices: torch.Tensor = None,
data_source: DataSource = DataSource.DATASET):
"""Compute GKD loss (JSD + optional SFT loss).
Args:
output_tensor: Student model logits [batch, seq_len, vocab_size]
labels: Token labels for masking [batch, seq_len]
- teacher_logits: Teacher model logits [batch, seq_len, vocab_size]
+ teacher_logits: Teacher model logits [batch, seq_len, vocab_size] (for local teacher)
+ teacher_api_logprobs: Teacher log probabilities [batch, seq_len, topk] (for API mode)
+ teacher_api_indices: Teacher token indices [batch, seq_len, topk] (for API mode)
data_source: Data source (STUDENT/TEACHER/DATASET) for conditional SFT loss
"""
student_logits = output_tensor
@@ -449,6 +535,8 @@ def loss_func(self,
teacher_logits=teacher_logits,
labels=labels,
beta=self.beta,
+ teacher_topk_logprobs=teacher_api_logprobs,
+ teacher_topk_indices=teacher_api_indices,
)
loss = jsd_loss
@@ -493,6 +581,8 @@ def forward_step(self, data_iterator, model):
data = next(data_iterator)
data_source = data.pop('data_source', DataSource.DATASET)
teacher_logits = data.pop('teacher_logits', None)
+ teacher_api_logprobs = data.pop('teacher_api_logprobs', None)
+ teacher_api_indices = data.pop('teacher_api_indices', None)
data = self._prepare_batch(data, vp_stage)
data.pop('loss_scale', None)
@@ -503,4 +593,10 @@ def forward_step(self, data_iterator, model):
student_output = model(**data)
return student_output, partial(
- self.loss_func, labels=labels, teacher_logits=teacher_logits, data_source=data_source)
+ self.loss_func,
+ labels=labels,
+ teacher_logits=teacher_logits,
+ teacher_api_logprobs=teacher_api_logprobs,
+ teacher_api_indices=teacher_api_indices,
+ data_source=data_source,
+ )
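The three Top-K JSD branches added above (`beta == 0` → forward KL, `beta == 1` → reverse KL, otherwise the beta-weighted mixture) can be sanity-checked in isolation. A minimal sketch, assuming only `torch`, that mirrors the hunk's arithmetic; when teacher and student log-probs agree, the divergence should vanish for any beta:

```python
import torch
import torch.nn.functional as F


def topk_jsd(t_log_p, s_log_p, beta=0.5):
    """Generalized JSD over top-k log-probs, mirroring the branches in the patch."""
    t_p = torch.exp(t_log_p)
    if beta == 0:
        # forward KL(teacher || student)
        jsd = (t_p * (t_log_p - s_log_p)).sum(dim=-1)
    elif beta == 1:
        # reverse KL(student || teacher)
        s_p = torch.exp(s_log_p)
        jsd = (s_p * (s_log_p - t_log_p)).sum(dim=-1)
    else:
        # beta-weighted mixture; 1e-10 guards log(0)
        s_p = torch.exp(s_log_p)
        m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10)
        jsd = beta * (t_p * (t_log_p - m_log_p)).sum(-1) \
            + (1 - beta) * (s_p * (s_log_p - m_log_p)).sum(-1)
    return jsd.sum()


# identical distributions -> (near-)zero divergence for every beta
torch.manual_seed(0)
log_p = F.log_softmax(torch.randn(4, 8), dim=-1)
for beta in (0.0, 0.5, 1.0):
    assert topk_jsd(log_p, log_p, beta).abs().item() < 1e-4
```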
diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py
index b22f3ab3c9..fb020f5df8 100644
--- a/swift/pipelines/train/rlhf.py
+++ b/swift/pipelines/train/rlhf.py
@@ -230,8 +230,12 @@ def _get_trainer_kwargs(self):
trainer_kwargs['reward_funcs'] = self.args.reward_funcs
if self.args.chord_sft_dataset:
trainer_kwargs['chord_sft_dataset'], _ = self._prepare_chord_sft_dataset()
- if self.args.rlhf_type == 'gkd' and self.args.teacher_deepspeed:
- trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed
+ if self.args.rlhf_type == 'gkd':
+ if self.args.teacher_deepspeed:
+ trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed
+ trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk
+ if self.args.teacher_model_server:
+ trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server
return trainer_kwargs
diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py
index 205583f7d9..f02672d57b 100644
--- a/swift/rlhf_trainers/gkd_trainer.py
+++ b/swift/rlhf_trainers/gkd_trainer.py
@@ -51,12 +51,17 @@ class DataSource(str, Enum):
DATASET = 'dataset' # Off-policy: use dataset responses
+teacher_model_server_model_name = None
+
+
class GKDTrainer(RolloutTrainerMixin, SwiftMixin, HFGKDTrainer):
def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
- teacher_model = kwargs.pop('teacher_model')
+ teacher_model = kwargs.pop('teacher_model', None)
teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None)
self.vllm_client = kwargs.pop('vllm_client', None)
+ self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None)
+ teacher_model_server = kwargs.pop('teacher_model_server', None)
super().__init__(model, None, *_args, **kwargs)
args = kwargs['args']
self.lmbda = args.lmbda
@@ -66,32 +71,36 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
self._total_train_tokens = 0
+ self.teacher_model_server = teacher_model_server
+ self.use_teacher_api = teacher_model_server is not None
+
# Initialize logging components
self._prepare_logging()
- # Initialize liger loss
+ # Initialize liger loss if enabled
self._prepare_liger_loss()
self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation
self.is_teacher_ds3 = None
# Initialize teacher model
- if self.is_deepspeed_enabled:
- if teacher_deepspeed_config is not None:
- self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3
- if not self.is_teacher_ds3:
- self.teacher_ds3_gather_for_generation = False
- self.teacher_model = prepare_deepspeed(
- teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args)
+ if teacher_model is not None:
+ if self.is_deepspeed_enabled:
+ if teacher_deepspeed_config is not None:
+ self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3
+ if not self.is_teacher_ds3:
+ self.teacher_ds3_gather_for_generation = False
+ self.teacher_model = prepare_deepspeed(
+ teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args)
+ else:
+ self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
+ elif self.is_fsdp_enabled:
+ from .utils import prepare_fsdp
+ self.teacher_model = prepare_fsdp(teacher_model, self.accelerator)
else:
- self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
- elif self.is_fsdp_enabled:
- from .utils import prepare_fsdp
- self.teacher_model = prepare_fsdp(teacher_model, self.accelerator)
- else:
- self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
- self.teacher_model.eval()
- if self.args.offload_teacher_model:
- self.offload_model(self.accelerator.unwrap_model(self.teacher_model))
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
+ self.teacher_model.eval()
+ if self.args.offload_teacher_model:
+ self.offload_model(self.accelerator.unwrap_model(self.teacher_model))
# Initialize rollout infrastructure for vLLM support
self.prepare_rollout()
@@ -103,7 +112,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
else:
self.maybe_activation_offload_context = nullcontext()
- self._trl_version_gte_0_24 = version.parse(trl.__version__) >= version.parse('0.24')
# Initialize resample data iterator for truncation_strategy 'raise'('delete')
if self.template.truncation_strategy == 'raise':
@@ -156,6 +164,10 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# Get data source: DataSource.STUDENT, DataSource.TEACHER, or DataSource.DATASET
data_source = inputs.pop('_data_source', DataSource.DATASET)
+ # Get teacher logprobs from API if available (set in training_step)
+ teacher_api_logprobs = inputs.pop('_teacher_api_logprobs', None)
+ teacher_api_indices = inputs.pop('_teacher_api_indices', None)
+
model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}}
# If generate is used, then use_logits_to_keep must be set to False.
use_logits_to_keep = self.get_use_logits_to_keep(True)
@@ -222,11 +234,49 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
)
# loss / grad norm is unexpectedly large, normalize by sequence length
# https://github.com/linkedin/Liger-Kernel/blob/v0.6.3/src/liger_kernel/chunked_loss/jsd_loss.py#L9-L39
- loss /= student_hidden.shape[1]
+ # loss /= student_hidden.shape[1]
# Release hidden states after loss computation
del student_hidden, teacher_hidden, true_labels
+ outputs_student = None
+ elif self.use_teacher_api:
+ assert teacher_api_logprobs is not None
+ if self.args.sft_alpha > 0:
+ model_inputs['labels'] = inputs['labels']
+ outputs_student = model(**model_inputs)
+
+ # teacher_api shape: [batch, seq_len-1, topk]
+ # teacher[i] = P(token[i+1] | token[0..i]), matching logits[i].
+ # But teacher has seq_len-1 positions (no logits[seq_len-1] equivalent).
+ # Pad a -inf row at the end so teacher becomes [batch, seq_len, topk]
+ # (the last position will be masked out by shifted_labels = -100).
+ teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf'))
+ teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0)
+ # Now teacher is [batch, seq_len, topk], same as full logits.
+ # Apply logits_to_keep to truncate teacher the same way model truncates logits.
+ logits_to_keep = inputs.get('logits_to_keep')
+ if logits_to_keep is not None:
+ if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool:
+ teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep]
+ teacher_api_indices = teacher_api_indices[:, logits_to_keep]
+ else:
+ n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep)
+ teacher_api_logprobs = teacher_api_logprobs[:, -n:]
+ teacher_api_indices = teacher_api_indices[:, -n:]
+ shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1)
+
+ loss = self.generalized_jsd_loss(
+ student_logits=outputs_student.logits,
+ labels=shifted_labels,
+ beta=self.beta,
+ temperature=self.temperature,
+ teacher_topk_logprobs=teacher_api_logprobs,
+ teacher_topk_indices=teacher_api_indices,
+ )
+
+ if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT:
+ loss = loss + self.args.sft_alpha * outputs_student.loss
else:
- # Standard loss computation
+ # Standard loss computation (local teacher model)
if self.args.sft_alpha > 0:
model_inputs['labels'] = inputs['labels']
# compute student output
@@ -239,35 +289,47 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
outputs_teacher = self.teacher_model(**model_inputs)
shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1)
- mask = shifted_labels != -100
- shifted_student_logits = outputs_student.logits[mask][None]
- shifted_teacher_logits = outputs_teacher.logits[mask][None]
-
- # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct.
- stu_dim = shifted_student_logits.shape[-1]
- tea_dim = shifted_teacher_logits.shape[-1]
- if stu_dim < tea_dim:
- shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0)
- shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:]
- elif stu_dim > tea_dim:
- shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0)
- shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:]
-
- # compute loss
- loss = self.generalized_jsd_loss(
- student_logits=shifted_student_logits,
- teacher_logits=shifted_teacher_logits,
- beta=self.beta,
- )
+
+ if self.gkd_logits_topk is not None:
+ # Top-k mode with local teacher
+ loss = self.generalized_jsd_loss(
+ student_logits=outputs_student.logits,
+ teacher_logits=outputs_teacher.logits,
+ labels=shifted_labels,
+ beta=self.beta,
+ temperature=self.temperature,
+ topk=self.gkd_logits_topk,
+ )
+ else:
+ # Full vocabulary mode
+ mask = shifted_labels != -100
+ shifted_student_logits = outputs_student.logits[mask][None]
+ shifted_teacher_logits = outputs_teacher.logits[mask][None]
+
+ # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct.
+ stu_dim = shifted_student_logits.shape[-1]
+ tea_dim = shifted_teacher_logits.shape[-1]
+ if stu_dim < tea_dim:
+ shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0)
+ shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:]
+ elif stu_dim > tea_dim:
+ shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0)
+ shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:]
+
+ # compute loss
+ loss = self.generalized_jsd_loss(
+ student_logits=shifted_student_logits,
+ teacher_logits=shifted_teacher_logits,
+ beta=self.beta,
+ temperature=self.temperature,
+ )
+
# Add SFT loss if enabled (skip for student-generated responses)
if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT:
loss = loss + self.args.sft_alpha * outputs_student.loss
# Return loss
if return_outputs:
- if self.use_liger_gkd_loss:
- # outputs has been released in liger loss computation to reduce peak memory
- outputs_student = None
return (loss, outputs_student)
else:
return loss
@@ -388,13 +450,37 @@ def training_step(self,
# Mark data source for downstream processing (e.g., conditional SFT loss)
encoded_inputs['_data_source'] = data_source
+ # Fetch teacher logprobs from API if using external teacher service
+ if self.use_teacher_api:
+ teacher_logprobs, teacher_indices = self._fetch_teacher_logprobs_from_api(encoded_inputs)
+ encoded_inputs['_teacher_api_logprobs'] = teacher_logprobs
+ encoded_inputs['_teacher_api_indices'] = teacher_indices
+
with self.template.forward_context(self.model, encoded_inputs):
loss = HFSFTTrainer.training_step(self, model, encoded_inputs, num_items_in_batch)
return loss
+ def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tensor]):
+ """Fetch teacher logprobs from external API service.
+
+ Returns:
+            Tuple of (teacher_logprobs, teacher_indices) tensors of shape [batch, seq_len - 1, topk]
+ """
+ input_ids = encoded_inputs['input_ids']
+ teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
+ self.teacher_model_server, input_ids.tolist(), topk=self.gkd_logits_topk)
+ return teacher_logprobs.to(input_ids.device), teacher_indices.to(input_ids.device)
+
def prediction_step(self, model, inputs, *args, **kwargs):
# Prediction uses full messages
encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=False)
+
+ # Fetch teacher logprobs from API if using external teacher service (for eval)
+ if self.use_teacher_api:
+ teacher_logprobs, teacher_indices = self._fetch_teacher_logprobs_from_api(encoded_inputs)
+ encoded_inputs['_teacher_api_logprobs'] = teacher_logprobs
+ encoded_inputs['_teacher_api_indices'] = teacher_indices
+
with self.template.forward_context(self.model, encoded_inputs):
return super().prediction_step(model, encoded_inputs, *args, **kwargs)
@@ -459,7 +545,7 @@ def _prepare_liger_loss(self):
raise ImportError(
'Liger kernel is not installed. Please install liger-kernel by running: pip install liger-kernel')
assert self.args.sft_alpha == 0, 'SFT loss is not supported with liger loss'
-
+ assert self.gkd_logits_topk is None, 'Top-k mode is not supported with liger loss'
self.liger_jsd_loss = LigerFusedLinearJSDLoss(
beta=self.beta,
ignore_index=-100,
@@ -471,24 +557,45 @@ def _prepare_liger_loss(self):
@staticmethod
def generalized_jsd_loss(
student_logits,
- teacher_logits,
+ teacher_logits=None,
labels=None,
beta=0.5,
temperature=1.0,
chunk_size=512,
+ topk=None,
+ teacher_topk_logprobs=None,
+ teacher_topk_indices=None,
):
- # Apply temperature scaling
+ # Top-k mode: reduce logits to [*, k] before the standard JSD pipeline
+ if teacher_topk_logprobs is not None and teacher_topk_indices is not None:
+ # API teacher: gather student logits at teacher's top-k indices, then both
+ # get scaled by 1/T and re-normalized over top-k via downstream log_softmax.
+ # vLLM logprobs = log_softmax(logits) at T=1; treating them as logit-like
+ # scores and dividing by T then re-normalizing is equivalent to
+ # log_softmax(logits/T) over top-k (by shift-invariance of softmax).
+ s_scaled = student_logits / temperature
+ student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
+ teacher_logits = teacher_topk_logprobs / temperature
+ del s_scaled
+ temperature = 1.0
+ elif topk is not None and teacher_logits is not None:
+ # Local teacher: select top-k from teacher, gather corresponding student logits
+ t_scaled = teacher_logits / temperature
+ s_scaled = student_logits / temperature
+ teacher_logits, topk_idx = torch.topk(t_scaled, k=topk, dim=-1)
+ student_logits = torch.gather(s_scaled, dim=-1, index=topk_idx)
+ del t_scaled, s_scaled, topk_idx
+ temperature = 1.0
+
student_logits = student_logits / temperature
teacher_logits = teacher_logits / temperature
- # Apply masking if labels provided
if labels is not None:
mask = labels != -100
student_logits = student_logits[mask]
teacher_logits = teacher_logits[mask]
num_valid = mask.sum()
else:
- # Flatten to [num_tokens, vocab_size]
student_logits = student_logits.view(-1, student_logits.size(-1))
teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1))
num_valid = student_logits.size(0)
@@ -499,7 +606,6 @@ def generalized_jsd_loss(
num_valid_int = num_valid if isinstance(num_valid, int) else num_valid.item()
total_loss = student_logits.new_zeros(())
- # Precompute beta tensor once if needed
if beta != 0 and beta != 1:
beta_t = torch.tensor(beta, dtype=student_logits.dtype, device=student_logits.device)
log_beta = torch.log(beta_t)
@@ -507,7 +613,6 @@ def generalized_jsd_loss(
else:
beta_t = log_beta = log_1_minus_beta = None
- # Process in chunks to reduce peak memory
for start_idx in range(0, num_valid_int, chunk_size):
end_idx = min(start_idx + chunk_size, num_valid_int)
s_chunk = student_logits[start_idx:end_idx]
@@ -526,11 +631,9 @@ def generalized_jsd_loss(
torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]),
dim=0,
)
-
kl_teacher = F.kl_div(mixture_log_probs, t_log_probs, reduction='none', log_target=True)
kl_student = F.kl_div(mixture_log_probs, s_log_probs, reduction='none', log_target=True)
del mixture_log_probs
-
jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student
del kl_teacher, kl_student
@@ -606,3 +709,83 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
row = [table[header][i] for header in headers]
rows.append(row)
swanlab.log({'completions': swanlab.echarts.Table().add(headers, rows)})
+
+
+def fetch_teacher_logprobs(base_url, input_ids, topk=20, timeout=300.0):
+ """Fetch top-k prompt logprobs from a vLLM-compatible /v1/completions endpoint.
+
+ Uses prompt_logprobs to get logprobs for input tokens without generating.
+ vLLM prompt_logprobs are always raw (temperature=1) log-probabilities from the model;
+ the temperature parameter in the API only affects token sampling, not prompt_logprobs.
+
+ Args:
+ base_url: vLLM server URL (e.g., 'http://localhost:8000').
+ input_ids: List of token ID sequences.
+ topk: Number of top log probabilities per token.
+ timeout: Request timeout in seconds.
+
+ Returns:
+ (logprobs, indices) tensors of shape [batch, max_seq_len - 1, topk].
+ The shift is because prompt_logprobs[0] is always None (first token has no
+ conditional probability), so position i in the output corresponds to
+ P(token_{i+1} | token_0..token_i), aligning with model logits[i].
+ """
+ import logging
+ import requests
+ from concurrent.futures import ThreadPoolExecutor
+
+ _logger = logging.getLogger(__name__)
+ base_url = base_url.rstrip('/')
+ batch_size = len(input_ids)
+ max_seq_len = max(len(ids) for ids in input_ids)
+ url = f'{base_url}/v1/completions'
+ global teacher_model_server_model_name
+ if teacher_model_server_model_name is None:
+ try:
+ resp = requests.get(f'{base_url}/v1/models', timeout=10)
+ model = resp.json()['data'][0]['id'] if resp.ok else 'default'
+ except Exception:
+ model = 'default'
+ teacher_model_server_model_name = model
+ else:
+ model = teacher_model_server_model_name
+
+ # prompt_logprobs[0] is always None (no conditional prob for the first token),
+ # prompt_logprobs[i] = P(token_i | token_0..token_{i-1}) which aligns with logits[i-1].
+ # So we skip position 0 and the result has shape [batch, max_seq_len-1, topk],
+ # aligning with student logits which predict the next token at each position.
+ out_len = max_seq_len - 1
+ logprobs_out = torch.full((batch_size, out_len, topk), float('-inf'), dtype=torch.float32)
+ indices_out = torch.zeros((batch_size, out_len, topk), dtype=torch.long)
+
+ def _fetch_one(batch_idx):
+ payload = {
+ 'model': model,
+ 'prompt': input_ids[batch_idx],
+ 'max_tokens': 1,
+ 'temperature': 0,
+ 'prompt_logprobs': topk,
+ }
+ try:
+ resp = requests.post(url, json=payload, timeout=timeout)
+ resp.raise_for_status()
+ prompt_logprobs_list = resp.json()['choices'][0].get('prompt_logprobs', [])
+ # Skip position 0 (always None), shift left so pos 1 -> output pos 0
+ for raw_pos in range(1, len(prompt_logprobs_list)):
+ pos_lp = prompt_logprobs_list[raw_pos]
+ if pos_lp is None:
+ continue
+ out_pos = raw_pos - 1
+ if out_pos >= out_len:
+ break
+ sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1]['logprob'])[:topk]
+ for k, (tid_str, info) in enumerate(sorted_items):
+ indices_out[batch_idx, out_pos, k] = int(tid_str)
+ logprobs_out[batch_idx, out_pos, k] = info['logprob']
+ except Exception as e:
+ _logger.error(f'Failed to get teacher logprobs for sequence {batch_idx}: {e}')
+
+ with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
+ list(pool.map(_fetch_one, range(batch_size)))
+
+ return logprobs_out, indices_out
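The comment in `generalized_jsd_loss` leans on softmax shift-invariance: because vLLM prompt logprobs are `log_softmax(logits)` at T=1 (i.e. logits shifted by a per-position constant), dividing them by T and re-normalizing over the top-k set yields the same distribution as computing `log_softmax(logits / T)` restricted to that set. A small numerical check of this claim, assuming only `torch`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(32)            # hypothetical teacher logits at one position
log_p = F.log_softmax(logits, 0)    # what the API returns (T=1 log-probs)
T, k = 0.7, 8
topv, idx = torch.topk(log_p, k)

# re-normalizing scaled log-probs over the top-k set ...
a = F.log_softmax(topv / T, dim=0)
# ... matches re-normalizing the scaled raw logits over the same set,
# since the log-partition constant cancels inside log_softmax
b = F.log_softmax(logits[idx] / T, dim=0)
assert torch.allclose(a, b, atol=1e-5)
```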
diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py
index 2484df4e77..402bf7dae6 100644
--- a/swift/rlhf_trainers/utils.py
+++ b/swift/rlhf_trainers/utils.py
@@ -566,10 +566,16 @@ def profiling_context(trainer, name: str):
profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration}
- if 'wandb' in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
+ is_main_process = False
+ if hasattr(trainer, 'accelerator'):
+ is_main_process = trainer.accelerator.is_main_process
+ elif hasattr(trainer, 'is_main_process'):
+ is_main_process = trainer.is_main_process
+
+ if 'wandb' in trainer.args.report_to and wandb.run is not None and is_main_process:
wandb.log(profiling_metrics)
- if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and trainer.accelerator.is_main_process:
+ if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and is_main_process:
swanlab.log(profiling_metrics)
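For reference, the alignment step performed in `compute_loss` for the API teacher — padding the `[batch, seq_len-1, topk]` tensors with one `-inf` row so they line up with student logits of length `seq_len`, with the padded last position later masked by the shifted labels — can be sketched as follows (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

batch, seq_len, topk = 2, 5, 4
# hypothetical API output: one position fewer than the input sequence
t_lp = torch.randn(batch, seq_len - 1, topk)
t_idx = torch.randint(0, 100, (batch, seq_len - 1, topk))

# pad one row at the end of the sequence dimension so the teacher
# tensors match the student logits' [batch, seq_len, ...] layout
t_lp = F.pad(t_lp, (0, 0, 0, 1), value=float('-inf'))
t_idx = F.pad(t_idx, (0, 0, 0, 1), value=0)

assert t_lp.shape == (batch, seq_len, topk)
assert torch.isinf(t_lp[:, -1]).all()  # padded row; masked out downstream
```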