diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md index 25e4545802..f8d455a93a 100644 --- a/docs/source/Instruction/GKD.md +++ b/docs/source/Instruction/GKD.md @@ -139,9 +139,11 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y)) 我们可以通过设置以下参数进行 GKD 训练: +### 基础参数 + | 参数 | 类型 | 默认值 | 取值范围 | 说明 | |------|------|--------|---------|------| -| `--teacher_model` | str | 必需 | - | 教师模型路径或模型 ID | +| `--teacher_model` | str | None | - | 教师模型路径或模型 ID | | `--beta` | float | 0.5 | [0.0, 1.0] | 散度插值系数
• 0.0: Forward KL
• 0.5: JSD (平衡)
• 1.0: Reverse KL | | `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率
• 0.0: 离线学习
• 0.5: 混合策略
• 1.0: 纯 On-Policy | | `--seq_kd` | bool | False | True/False | 是否使用教师生成序列
• False: 非 on-policy 时使用数据集
• True: 非 on-policy 时使用教师生成 | @@ -149,6 +151,79 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y)) | `--sft_alpha` | float | 0 | >= 0 | 混合一定比例的sft loss,对非student生成结果生效 | | `--max_completion_length` | int | 512 | > 0 | 生成时的最大 token 数 | +### Top-K KL 计算 + +默认情况下,GKD 使用完整词表计算 KL 散度,容易造成 OOM,这种情况下可以使用 **Top-K** 模式来减少显存占用和计算量。 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `--gkd_logits_topk` | int | None | Top-K logits 数量
• None: 使用完整词表(默认)
• 正整数: 仅使用教师模型概率最高的 K 个 token 计算 KL | + +**Top-K 模式原理**: + +在 Top-K 模式下,选取**教师模型**输出概率最高的 K 个 token,在这个子集上计算两个模型分布的 KL 散度。 +$$ +D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M}) +$$ + +其中 Top-K 索引来自教师模型:$\text{Top-K} = \text{argtop}_K(P_T)$,$\tilde{P}_T$ 和 $\tilde{P}_S$ 是在 Top-K 子集上**重新归一化**的概率分布: + +$$ +\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K} +$$ + +**使用示例**: + +```bash +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-7B-Instruct \ + --teacher_model Qwen/Qwen2.5-72B-Instruct \ + --gkd_logits_topk 64 \ + --dataset your_dataset \ + ... +``` + +> **注意**:Top-K 模式不能与 liger kernel 同时使用(`--use_liger_kernel`)。 + +### 外部教师模型 API + +当设置 `gkd_logits_topk` 时,可以使用外部教师模型 API 服务来获取 logprobs,这样可以避免在训练进程中加载教师模型。 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `--teacher_model_server` | str | None | 教师模型服务地址
如:`http://localhost:8000` | +| `--gkd_logits_topk` | int | **必需** | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 | + + +**步骤 1:部署教师模型服务** + +```bash +# 使用 vllm serve 部署教师模型 +CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \ + --port 8000 \ + --max-logprobs 64 \ + --gpu-memory-utilization 0.9 +``` + +**步骤 2:启动 GKD 训练** + +```bash +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-7B \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 64 \ + --dataset your_dataset \ + --lmbda 1.0 \ + --beta 1.0 \ + ... +``` + +> **vLLM max_logprobs 限制**: +> - vLLM 默认 `max_logprobs=20`,可通过 `--max-logprobs N` 参数调整 +> - `gkd_logits_topk` 不能超过服务端的 `max_logprobs` 设置 + ## 采样加速 在 GKD 训练中,涉及到两种在线采样的情况: @@ -168,6 +243,8 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y)) 训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh) +使用 Teacher Server 的训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh) + ### 方案 2:教师模型预采样 对于教师模型采样(`seq_kd=True`),推荐使用 **预采样** 方式:先用教师模型离线生成高质量数据,再进行训练。 diff --git a/docs/source/Megatron-SWIFT/GKD.md b/docs/source/Megatron-SWIFT/GKD.md index 9cc30004d7..9a023cb456 100644 --- a/docs/source/Megatron-SWIFT/GKD.md +++ b/docs/source/Megatron-SWIFT/GKD.md @@ -33,7 +33,9 @@ Megatron GKD 当前已支持以下功能: | 参数 | 类型 | 默认值 | 说明 | |------|------|--------|------| -| `--teacher_model` | str | 必需 | 教师模型路径或模型 ID | +| `--teacher_model` | str | - | 教师模型路径或模型 ID
*使用 `teacher_model_server` 时可省略 | +| `--teacher_model_server` | str | None | 教师模型服务地址(仅支持 `vllm serve`),如 `http://localhost:8000` | +| `--gkd_logits_topk` | int | None | Top-K logits 数量,使用外部教师 API 时必须设置 | | `--beta` | float | 0.5 | JSD 散度插值系数:
• 0.0: Forward KL
• 0.5: 对称 JSD
• 1.0: Reverse KL | | `--lmbda` | float | 0.5 | On-Policy 学习触发概率:
• 0.0: 纯 Off-Policy
• 1.0: 纯 On-Policy | | `--seq_kd` | bool | False | 是否使用教师生成的响应(当前暂不支持) | @@ -71,3 +73,5 @@ GKD 支持三种训练模式,通过 `lmbda` 和 `seq_kd` 参数控制: 更多参数请参考[命令行文档](./Command-line-parameters.md) 训练脚本请参考 [Megatron GKD 脚本](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd) + +使用 Teacher Server 的训练脚本请参考 [这里](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh) diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md index 9bff8a81eb..cad177cae8 100644 --- a/docs/source_en/Instruction/GKD.md +++ b/docs/source_en/Instruction/GKD.md @@ -139,9 +139,11 @@ Set parameter `seq_kd=True`, when on-policy is not triggered, use teacher model We can perform GKD training by setting the following parameters: +### Basic Parameters + | Parameter | Type | Default | Range | Description | |------|------|--------|---------|------| -| `--teacher_model` | str | Required | - | Teacher model path or model ID | +| `--teacher_model` | str | None | - | Teacher model path or model ID
*Can be omitted when using `teacher_model_server` | | `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient
• 0.0: Forward KL
• 0.5: JSD (balanced)
• 1.0: Reverse KL | | `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability
• 0.0: Pure Offline
• 0.5: Mixed strategy (**recommended**)
• 1.0: Pure On-Policy | | `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences
• False: Use dataset when not on-policy
• True: Use teacher generation when not on-policy | @@ -149,6 +151,84 @@ We can perform GKD training by setting the following parameters: | `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions | | `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation | +### Top-K KL Computation + +By default, GKD computes KL divergence over the full vocabulary. For models with large vocabularies, you can use **Top-K** mode to reduce memory usage and computation. + +| Parameter | Type | Default | Description | +|------|------|--------|------| +| `--gkd_logits_topk` | int | None | Number of Top-K logits
• None: Use full vocabulary (default)
• Positive integer: Only use the K tokens with highest teacher probability for KL computation | + +**Top-K Mode Principle**: + +In Top-K mode, the top-K token indices are selected from the **teacher model**. Logits from both models are gathered at these indices and renormalized over the Top-K subset, and the JSD is then computed on the resulting distributions. + +$$ +D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M}) +$$ + +Where the Top-K indices come from the teacher model: $\text{Top-K} = \text{argtop}_K(P_T)$, and $\tilde{P}_T$ and $\tilde{P}_S$ are the probability distributions **renormalized** over the Top-K subset: + +$$ +\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K} +$$ + +**Usage Example**: + +```bash +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-7B-Instruct \ + --teacher_model Qwen/Qwen2.5-14B-Instruct \ + --gkd_logits_topk 64 \ + --dataset your_dataset \ + ... +``` + +> **Note**: Top-K mode cannot be used with liger kernel (`--use_liger_kernel`). + +### External Teacher Model API + +When `gkd_logits_topk` is set, you can use an external teacher model API service to fetch logprobs, which avoids loading the teacher model in the training process. + +| Parameter | Type | Default | Description | +|------|------|--------|------| +| `--teacher_model_server` | str | None | Teacher model service URL
e.g., `http://localhost:8000` | +| `--gkd_logits_topk` | int | **Required** | Must be set when using external API; corresponds to the top_logprobs returned by the API | + +**Supported Backends**: +- `vllm serve` (recommended) + +> **Note**: Only `vllm serve` is supported as the teacher server backend. The training code sends raw token IDs via the `prompt` field and uses the `prompt_logprobs` parameter in the `/v1/completions` API to obtain input token log-probabilities. This is a vLLM-native feature. + +**Step 1: Deploy Teacher Model Service** + +```bash +# Deploy teacher model with vllm serve +CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \ + --port 8000 \ + --max-logprobs 64 \ + --gpu-memory-utilization 0.9 +``` + +**Step 2: Start GKD Training** + +```bash +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-7B \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 64 \ + --dataset your_dataset \ + --lmbda 1.0 \ + --beta 1.0 \ + ... +``` + +> **vLLM max_logprobs Limitation**: +> - vLLM default `max_logprobs=20`, adjustable via `--max-logprobs N` parameter +> - `gkd_logits_topk` cannot exceed the server's `max_logprobs` setting + ## Sampling Acceleration In GKD training, there are two types of online sampling scenarios: @@ -168,6 +248,8 @@ Use vLLM as the inference backend to accelerate student model sampling. Supports Training script reference [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh), for related parameters, please refer to [GRPO vLLM Parameters](./Command-line-parameters.md#vllm_mode). +Training script using Teacher Server reference [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh). 
+ ### Solution 2: Teacher Model Pre-sampling diff --git a/docs/source_en/Megatron-SWIFT/GKD.md b/docs/source_en/Megatron-SWIFT/GKD.md index 37c08eea5a..0502b8485b 100644 --- a/docs/source_en/Megatron-SWIFT/GKD.md +++ b/docs/source_en/Megatron-SWIFT/GKD.md @@ -33,7 +33,9 @@ Megatron GKD currently supports the following features: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `--teacher_model` | str | Required | Path or model ID of the teacher model | +| `--teacher_model` | str | - | Path or model ID of the teacher model
*Can be omitted when using `teacher_model_server` | +| `--teacher_model_server` | str | None | Teacher model service URL (`vllm serve` only), e.g. `http://localhost:8000` | +| `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using external API | | `--beta` | float | 0.5 | JSD divergence interpolation coefficient:
• 0.0: Forward KL
• 0.5: Symmetric JSD
• 1.0: Reverse KL | | `--lmbda` | float | 0.5 | On-Policy learning probability:
• 0.0: Pure Off-Policy
• 1.0: Pure On-Policy | | `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) | @@ -71,3 +73,5 @@ GKD supports three training modes, controlled by `lmbda` and `seq_kd` parameters For more parameters, please refer to [Command-line Parameters](./Command-line-parameters.md) For training scripts, please refer to [Megatron GKD Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd) + +Training script using Teacher Server reference [here](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh) diff --git a/examples/megatron/rlhf/gkd/teacher_server.sh b/examples/megatron/rlhf/gkd/teacher_server.sh new file mode 100644 index 0000000000..2ccb05b8a1 --- /dev/null +++ b/examples/megatron/rlhf/gkd/teacher_server.sh @@ -0,0 +1,43 @@ +# Teacher server must be running first: +# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000 --max-logprobs 64 + +CUDA_VISIBLE_DEVICES=1,2 \ +NPROC_PER_NODE=2 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +megatron rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-0.5B \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 64 \ + --dataset 'modelscope/gsm8k' \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --context_parallel_size 1 \ + --expert_model_parallel_size 1 \ + --lmbda 1 \ + --seq_kd false \ + --beta 0.5 \ + --torch_dtype bfloat16 \ + --micro_batch_size 2 \ + --global_batch_size 32 \ + --train_iters 500 \ + --lr 5e-5 \ + --lr_warmup_fraction 0.1 \ + --logging_steps 1 \ + --save_steps 100 \ + --save_total_limit 10 \ + --max_length 2048 \ + --max_completion_length 2048 \ + --attention_backend flash \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.5 \ + --vllm_tensor_parallel_size 1 \ + --vllm_max_model_len 4096 \ + --sleep_level 1 \ + --finetune \ + --no_save_optim \ + --no_save_rng \ + --temperature 1.0 \ + --padding_free true \ + 
--recompute_granularity selective diff --git a/examples/train/rlhf/gkd/teacher_server.sh b/examples/train/rlhf/gkd/teacher_server.sh new file mode 100644 index 0000000000..cbb56680f9 --- /dev/null +++ b/examples/train/rlhf/gkd/teacher_server.sh @@ -0,0 +1,44 @@ +# GKD Training with External Teacher Model Server (vLLM) +# ===================== Step 1: Start Teacher Server ===================== +# Run in a separate terminal / GPU: +# +# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \ +# --port 8000 \ +# --max-logprobs 64 \ +# --gpu-memory-utilization 0.9 + +# ======================================================================== + +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-0.5B \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 64 \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.5 \ + --vllm_tensor_parallel_size 1 \ + --vllm_max_model_len 4096 \ + --sleep_level 0 \ + --dataset 'modelscope/gsm8k' \ + --lmbda 1 \ + --seq_kd false \ + --beta 0.5 \ + --torch_dtype bfloat16 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --learning_rate 5e-5 \ + --logging_steps 1 \ + --save_steps 100 \ + --save_total_limit 2 \ + --max_length 2048 \ + --max_completion_length 2048 \ + --warmup_ratio 0.1 \ + --save_only_model true \ + --dataloader_num_workers 4 \ + --dataset_num_proc 4 \ + --attn_impl flash_attn \ + --report_to tensorboard swanlab diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index a3a63b7090..50017b92ad 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -39,7 +39,8 @@ class TeacherModelArguments: Args: teacher_model (Optional[str]): The model ID or a local path to the teacher model. This is required when - `rlhf_type` is 'gkd'. Analogous to the main `model` argument. Defaults to None. 
+ `rlhf_type` is 'gkd' and `teacher_model_server` is not set. Analogous to the main `model` argument. + Defaults to None. teacher_adapters (List[str]): A list of paths to LoRA weights. These weights, often produced by SFT, are loaded to form the teacher model. Defaults to an empty list (`[]`). teacher_model_type (Optional[str]): The model type of the teacher model. If not specified, it's often inferred. @@ -50,6 +51,10 @@ class TeacherModelArguments: one of the following values: 'zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'. If not provided, it defaults to using the same DeepSpeed configuration as the main training model. Analogous to the main `deepspeed` argument. + teacher_model_server (Optional[str]): The URL of the teacher model server (e.g., 'http://localhost:8000'). + When set, the teacher logprobs will be fetched from the external API service (e.g., swift deploy, vLLM) + instead of loading a local teacher model. This enables using larger teacher models or services hosted + remotely. When this is set, `teacher_model` is not required. Defaults to None. """ teacher_model: Optional[str] = None teacher_adapters: List[str] = field(default_factory=list) @@ -63,6 +68,13 @@ class TeacherModelArguments: 'DeepSpeed configuration for teacher model. ' 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' }) + teacher_model_server: Optional[str] = field( + default=None, + metadata={ + 'help': + 'URL of the teacher model server (e.g., http://localhost:8000). ' + 'When set, teacher logprobs are fetched via API instead of loading a local model.' + }) @dataclass @@ -196,6 +208,11 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo gkd_loss + sft_alpha * sft_loss`. Defaults to 0. lmbda (float): The lambda parameter for GKD, balancing policy and value losses. Defaults to 0.5. seq_kd (bool): Whether to use sequence-level knowledge distillation for GKD. Defaults to False. 
+ gkd_logits_topk (Optional[int]): The number of top-k logits to use for KL divergence computation in GKD. + If None, uses full vocabulary for KL computation (more accurate but memory-intensive). + If set to a positive integer, only top-k teacher logits are used (more efficient). + When using `teacher_model_server`, this is limited by the server's `max_logprobs` setting + (vLLM default is 20, can be increased with `--max-logprobs`). Defaults to None. offload_teacher_model (bool): Whether to offload the teacher model to CPU memory to save VRAM during GKD training. Defaults to False. max_new_tokens (Optional[int]): A backward-compatibility argument. Please use `max_completion_length` instead. @@ -233,6 +250,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo sft_alpha: float = 0 lmbda: float = 0.5 seq_kd: bool = False + gkd_logits_topk: Optional[int] = None offload_teacher_model: bool = False # compat max_new_tokens: Optional[int] = None # use max_completion_length instead @@ -545,3 +563,25 @@ def _check_gkd(self): if self.async_generate: raise NotImplementedError('Currently, async_generate is not supported for GKD.') + + # Validate teacher model configuration + if self.teacher_model is None and self.teacher_model_server is None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set.') + + if self.teacher_model is not None and self.teacher_model_server is not None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set, not both.') + + # When using teacher_model_server, gkd_logits_topk is required (API only returns top-k logprobs) + if self.teacher_model_server is not None: + if self.gkd_logits_topk is None: + raise ValueError('gkd_logits_topk is required when using teacher_model_server') + + # Validate gkd_logits_topk + if self.gkd_logits_topk is not None and self.gkd_logits_topk <= 0: + raise ValueError(f'gkd_logits_topk must be a positive integer, got 
{self.gkd_logits_topk}') + + if self.gkd_logits_topk is not None and self.use_liger_kernel: + raise ValueError('gkd_logits_topk is not supported when using liger kernel') + + if self.teacher_model_server and self.seq_kd: + raise NotImplementedError('Sequential KD is not supported when using teacher_model_server') diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 082ba37ee4..033352bbdd 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -45,6 +45,14 @@ class RLHFMegatronArgumentsMixin: teacher_model: Optional[str] = field(default=None) teacher_model_type: Optional[str] = field(default=None) teacher_model_revision: Optional[str] = field(default=None) + teacher_model_server: Optional[str] = field( + default=None, + metadata={ + 'help': + 'URL of the teacher model server (e.g., http://localhost:8000). ' + 'When set, teacher logprobs are fetched via API instead of loading a local model.' 
+ }) + gkd_logits_topk: Optional[int] = None lmbda: float = 0.5 # On-policy probability: with prob lmbda, use student-generated responses seq_kd: bool = False # Sequential KD: use teacher-generated responses when not on-policy offload_teacher_model: bool = False # Offload teacher model to CPU to save GPU memory @@ -194,6 +202,21 @@ def __post_init__(self): if self.vllm_limit_mm_per_prompt is not None: self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs) + if self.rlhf_type == 'gkd': + if self.teacher_model is None and self.teacher_model_server is None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set.') + + if self.teacher_model is not None and self.teacher_model_server is not None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set, not both.') + + # When using teacher_model_server, gkd_logits_topk is required (API only returns top-k logprobs) + if self.teacher_model_server is not None: + if self.gkd_logits_topk is None: + raise ValueError('gkd_logits_topk is required when using teacher_model_server') + + # Validate gkd_logits_topk + if self.gkd_logits_topk is not None and self.gkd_logits_topk <= 0: + raise ValueError(f'gkd_logits_topk must be a positive integer, got {self.gkd_logits_topk}') def _init_grpo(self): diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index add6aea39d..09caf85106 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -41,10 +41,25 @@ def __init__(self, args: MegatronArguments, template, **kwargs): self.lmbda = args.lmbda # On-policy probability self.seq_kd = args.seq_kd # Sequential KD: use teacher-generated responses self.offload_teacher_model = args.offload_teacher_model # Offload teacher to CPU - self.teacher_bridge = args.megatron_model_meta.bridge_cls(args, 
attr_prefix='teacher_') - self.teacher_config = self.teacher_bridge.processor.model_info.config + self.teacher_model_server = getattr(args, 'teacher_model_server', None) + self.use_teacher_api = self.teacher_model_server is not None + if args.teacher_model: + self.teacher_bridge = args.megatron_model_meta.bridge_cls(args, attr_prefix='teacher_') + self.teacher_config = self.teacher_bridge.processor.model_info.config self.sft_alpha = getattr(args, 'sft_alpha', 0.0) # Weight for SFT loss - assert args.teacher_model is not None, 'Teacher model path is required for GKD training' + + # GKD top-k logits configuration + self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) + # Check use_teacher_api based on args, not client existence + # (API client is only created on last rank, but all ranks need to know the mode) + + # Validate teacher configuration + if not self.use_teacher_api: + assert args.teacher_model is not None, \ + 'Teacher model path is required for GKD training (or set teacher_model_server for API mode)' + else: + logger.info(f'Using teacher model API for logprobs, top_logprobs={self.gkd_logits_topk}') + self.use_vllm = getattr(args, 'use_vllm', False) super().__init__(args, template) @@ -67,6 +82,9 @@ def train(self, train_dataset, val_dataset): def prepare_model(self): super().prepare_model() + if self.use_teacher_api: + logger.info('Skipping local teacher model loading - using external API for teacher logprobs') + return args = self.args vp_size = getattr(args, 'virtual_pipeline_model_parallel_size') assert vp_size is None or vp_size == 1, 'GKD currently does not support VPP.' 
@@ -83,12 +101,12 @@ def prepare_model(self): def _offload_teacher_models(self): """Offload teacher models to CPU to save GPU memory.""" - if self.teacher_models: + if self.teacher_models and not self.use_teacher_api: offload_megatron_model_to_cpu(self.teacher_models) def _load_teacher_models_to_gpu(self): """Load teacher models back to GPU.""" - if self.teacher_models: + if self.teacher_models and not self.use_teacher_api: load_megatron_model_to_gpu(self.teacher_models, load_grad=False) @contextmanager @@ -247,22 +265,49 @@ def resample_encode_failed_inputs(self, inputs: List[Dict], max_resample_rounds: return valid_samples[:required_count] def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: + if self.use_teacher_api: + self._compute_teacher_logits_from_api(encoded_batches) + else: + self._compute_teacher_logits_local(encoded_batches, vp_stage) + + def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: teacher_model = self.teacher_models[vp_stage or 0] + topk = self.gkd_logits_topk for encoded_batch in encoded_batches: - # Deep copy to avoid modifying original batch teacher_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in encoded_batch.items()} teacher_batch.pop('data_source', None) teacher_data = self._prepare_batch(teacher_batch) teacher_data.pop('loss_scale', None) - # Remove labels so returns logits instead of loss teacher_data.pop('labels', None) - # Teacher forward with args override for correct hidden_size with self.load_teacher_model_context(), torch.no_grad(): teacher_logits = forward_step_helper(self.args, teacher_model, teacher_data) if teacher_logits is not None: teacher_logits = teacher_logits.detach() - encoded_batch['teacher_logits'] = teacher_logits + + if topk is not None and teacher_logits is not None: + topk_logits, topk_indices = torch.topk(teacher_logits, k=topk, dim=-1) + encoded_batch['teacher_api_logprobs'] = 
topk_logits + encoded_batch['teacher_api_indices'] = topk_indices + encoded_batch['teacher_logits'] = None + else: + encoded_batch['teacher_logits'] = teacher_logits + + def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None: + """Fetch teacher logprobs from external API service.""" + from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs + topk = self.gkd_logits_topk + for encoded_batch in encoded_batches: + input_ids = encoded_batch['input_ids'] + teacher_logprobs, teacher_indices = fetch_teacher_logprobs( + self.teacher_model_server, input_ids.tolist(), topk=topk) + # fetch_teacher_logprobs returns [batch, seq_len-1, topk] (shifted). + # Pad last position with -inf to match student [batch, seq_len, topk]. + teacher_logprobs = F.pad(teacher_logprobs, (0, 0, 0, 1), value=float('-inf')) + teacher_indices = F.pad(teacher_indices, (0, 0, 0, 1), value=0) + encoded_batch['teacher_api_logprobs'] = teacher_logprobs.to(input_ids.device) + encoded_batch['teacher_api_indices'] = teacher_indices.to(input_ids.device) + encoded_batch['teacher_logits'] = None def _replace_data_iterator(self, data_iterator): num_microbatches = self.args.num_microbatches @@ -350,6 +395,8 @@ def generalized_jsd_loss( labels: torch.Tensor, beta: float = 0.5, chunk_size: int = 512, + teacher_topk_logprobs: torch.Tensor = None, + teacher_topk_indices: torch.Tensor = None, ) -> torch.Tensor: args = self.args mask = labels != -100 @@ -364,12 +411,20 @@ def generalized_jsd_loss( if num_valid == 0: return (student_logits.sum() * 0).reshape(()) + # Top-k mode: direct computation without vocab parallel + if teacher_topk_logprobs is not None and teacher_topk_indices is not None: + total_loss = self._jsd_topk(student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta) + if args.context_parallel_size > 1: + torch.distributed.all_reduce( + total_loss, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) + return total_loss / num_valid + + 
# Full vocabulary mode (original code) # Align vocab size between student and teacher student_logits, teacher_logits = self._align_vocab_size(student_logits, teacher_logits) # Apply temperature scaling and mask - student_logits_masked = (student_logits - / self.temperature)[mask] # [local_num_valid_tokens, partition_vocab_size] + student_logits_masked = (student_logits / self.temperature)[mask] teacher_logits_masked = (teacher_logits / self.temperature)[mask] del student_logits, teacher_logits @@ -389,36 +444,28 @@ def generalized_jsd_loss( s_chunk = student_logits_masked[start_idx:end_idx] t_chunk = teacher_logits_masked[start_idx:end_idx] - # Compute log_softmax with vocab-parallel support s_log_probs = vocab_parallel_log_softmax(s_chunk) t_log_probs = vocab_parallel_log_softmax(t_chunk) del s_chunk, t_chunk if beta == 0: - # JSD = KL(teacher || student) jsd_chunk = vocab_parallel_kl_div(s_log_probs, t_log_probs) elif beta == 1: - # JSD = KL(student || teacher) jsd_chunk = vocab_parallel_kl_div(t_log_probs, s_log_probs) else: - # Compute mixture log probabilities: m = beta * teacher + (1-beta) * student - # log(m) = logsumexp(log(student) + log(1-beta), log(teacher) + log(beta)) mixture_log_probs = torch.logsumexp( torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), dim=0, ) - kl_teacher = vocab_parallel_kl_div(mixture_log_probs, t_log_probs) kl_student = vocab_parallel_kl_div(mixture_log_probs, s_log_probs) del mixture_log_probs - jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student del kl_teacher, kl_student total_loss = total_loss + jsd_chunk.sum() del jsd_chunk, s_log_probs, t_log_probs - # Clean up masked logits del student_logits_masked, teacher_logits_masked # All-reduce total_loss across CP group for correct sum @@ -428,18 +475,57 @@ def generalized_jsd_loss( return total_loss / num_valid + def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta): + """Compute JSD on teacher's top-k 
distribution. + + Both local and API teacher are handled uniformly: gather student logits at + teacher's top-k indices, scale by 1/T, and log_softmax over top-k subset. + By shift-invariance of log_softmax, this gives identical results whether + teacher_topk_logprobs contains raw logits (local) or raw logprobs (API). + + """ + s_scaled = student_logits / self.temperature + s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices) + t_topk = teacher_topk_logprobs / self.temperature + + s_topk_masked = s_topk[mask] + t_topk_masked = t_topk[mask] + + if s_topk_masked.numel() == 0: + return student_logits.new_zeros(()) + + t_log_p = F.log_softmax(t_topk_masked, dim=-1) + s_log_p = F.log_softmax(s_topk_masked, dim=-1) + t_p = torch.exp(t_log_p) + + if beta == 0: + jsd = (t_p * (t_log_p - s_log_p)).sum(dim=-1) + elif beta == 1: + s_p = torch.exp(s_log_p) + jsd = (s_p * (s_log_p - t_log_p)).sum(dim=-1) + else: + s_p = torch.exp(s_log_p) + m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10) + jsd = beta * (t_p * (t_log_p - m_log_p)).sum(-1) + (1 - beta) * (s_p * (s_log_p - m_log_p)).sum(-1) + + return jsd.sum() + def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, - teacher_logits: torch.Tensor, + teacher_logits: torch.Tensor = None, + teacher_api_logprobs: torch.Tensor = None, + teacher_api_indices: torch.Tensor = None, data_source: DataSource = DataSource.DATASET): """Compute GKD loss (JSD + optional SFT loss). 
Args: output_tensor: Student model logits [batch, seq_len, vocab_size] labels: Token labels for masking [batch, seq_len] - teacher_logits: Teacher model logits [batch, seq_len, vocab_size] + teacher_logits: Teacher model logits [batch, seq_len, vocab_size] (for local teacher) + teacher_api_logprobs: Teacher log probabilities [batch, seq_len, topk] (for API mode) + teacher_api_indices: Teacher token indices [batch, seq_len, topk] (for API mode) data_source: Data source (STUDENT/TEACHER/DATASET) for conditional SFT loss """ student_logits = output_tensor @@ -449,6 +535,8 @@ def loss_func(self, teacher_logits=teacher_logits, labels=labels, beta=self.beta, + teacher_topk_logprobs=teacher_api_logprobs, + teacher_topk_indices=teacher_api_indices, ) loss = jsd_loss @@ -493,6 +581,8 @@ def forward_step(self, data_iterator, model): data = next(data_iterator) data_source = data.pop('data_source', DataSource.DATASET) teacher_logits = data.pop('teacher_logits', None) + teacher_api_logprobs = data.pop('teacher_api_logprobs', None) + teacher_api_indices = data.pop('teacher_api_indices', None) data = self._prepare_batch(data, vp_stage) data.pop('loss_scale', None) @@ -503,4 +593,10 @@ def forward_step(self, data_iterator, model): student_output = model(**data) return student_output, partial( - self.loss_func, labels=labels, teacher_logits=teacher_logits, data_source=data_source) + self.loss_func, + labels=labels, + teacher_logits=teacher_logits, + teacher_api_logprobs=teacher_api_logprobs, + teacher_api_indices=teacher_api_indices, + data_source=data_source, + ) diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index b22f3ab3c9..fb020f5df8 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -230,8 +230,12 @@ def _get_trainer_kwargs(self): trainer_kwargs['reward_funcs'] = self.args.reward_funcs if self.args.chord_sft_dataset: trainer_kwargs['chord_sft_dataset'], _ = self._prepare_chord_sft_dataset() - if self.args.rlhf_type == 
'gkd' and self.args.teacher_deepspeed: - trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + if self.args.rlhf_type == 'gkd': + if self.args.teacher_deepspeed: + trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk + if self.args.teacher_model_server: + trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server return trainer_kwargs diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 205583f7d9..f02672d57b 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -51,12 +51,17 @@ class DataSource(str, Enum): DATASET = 'dataset' # Off-policy: use dataset responses +teacher_model_server_model_name = None + + class GKDTrainer(RolloutTrainerMixin, SwiftMixin, HFGKDTrainer): def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): - teacher_model = kwargs.pop('teacher_model') + teacher_model = kwargs.pop('teacher_model', None) teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) + self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) + teacher_model_server = kwargs.pop('teacher_model_server', None) super().__init__(model, None, *_args, **kwargs) args = kwargs['args'] self.lmbda = args.lmbda @@ -66,32 +71,36 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self._total_train_tokens = 0 + self.teacher_model_server = teacher_model_server + self.use_teacher_api = teacher_model_server is not None + # Initialize logging components self._prepare_logging() - # Initialize liger loss + # Initialize liger loss if enabled self._prepare_liger_loss() self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation self.is_teacher_ds3 = None # Initialize teacher 
model - if self.is_deepspeed_enabled: - if teacher_deepspeed_config is not None: - self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3 - if not self.is_teacher_ds3: - self.teacher_ds3_gather_for_generation = False - self.teacher_model = prepare_deepspeed( - teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args) + if teacher_model is not None: + if self.is_deepspeed_enabled: + if teacher_deepspeed_config is not None: + self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3 + if not self.is_teacher_ds3: + self.teacher_ds3_gather_for_generation = False + self.teacher_model = prepare_deepspeed( + teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args) + else: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + from .utils import prepare_fsdp + self.teacher_model = prepare_fsdp(teacher_model, self.accelerator) else: - self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - from .utils import prepare_fsdp - self.teacher_model = prepare_fsdp(teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) - self.teacher_model.eval() - if self.args.offload_teacher_model: - self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + self.teacher_model.eval() + if self.args.offload_teacher_model: + self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) # Initialize rollout infrastructure for vLLM support self.prepare_rollout() @@ -103,7 +112,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) else: 
self.maybe_activation_offload_context = nullcontext() - self._trl_version_gte_0_24 = version.parse(trl.__version__) >= version.parse('0.24') # Initialize resample data iterator for truncation_strategy 'raise'('delete') if self.template.truncation_strategy == 'raise': @@ -156,6 +164,10 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Get data source: DataSource.STUDENT, DataSource.TEACHER, or DataSource.DATASET data_source = inputs.pop('_data_source', DataSource.DATASET) + # Get teacher logprobs from API if available (set in training_step) + teacher_api_logprobs = inputs.pop('_teacher_api_logprobs', None) + teacher_api_indices = inputs.pop('_teacher_api_indices', None) + model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}} # If generate is used, then use_logits_to_keep must be set to False. use_logits_to_keep = self.get_use_logits_to_keep(True) @@ -222,11 +234,49 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) # loss / grad norm is unexpectedly large, normalize by sequence length # https://github.com/linkedin/Liger-Kernel/blob/v0.6.3/src/liger_kernel/chunked_loss/jsd_loss.py#L9-L39 - loss /= student_hidden.shape[1] + # loss /= student_hidden.shape[1] # Release hidden states after loss computation del student_hidden, teacher_hidden, true_labels + outputs_student = None + elif self.use_teacher_api: + assert teacher_api_logprobs is not None + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + # teacher_api shape: [batch, seq_len-1, topk] + # teacher[i] = P(token[i+1] | token[0..i]), matching logits[i]. + # But teacher has seq_len-1 positions (no logits[seq_len-1] equivalent). + # Pad a -inf row at the end so teacher becomes [batch, seq_len, topk] + # (the last position will be masked out by shifted_labels = -100). 
+ teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf')) + teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0) + # Now teacher is [batch, seq_len, topk], same as full logits. + # Apply logits_to_keep to truncate teacher the same way model truncates logits. + logits_to_keep = inputs.get('logits_to_keep') + if logits_to_keep is not None: + if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool: + teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] + teacher_api_indices = teacher_api_indices[:, logits_to_keep] + else: + n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep) + teacher_api_logprobs = teacher_api_logprobs[:, -n:] + teacher_api_indices = teacher_api_indices[:, -n:] + shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + + loss = self.generalized_jsd_loss( + student_logits=outputs_student.logits, + labels=shifted_labels, + beta=self.beta, + temperature=self.temperature, + teacher_topk_logprobs=teacher_api_logprobs, + teacher_topk_indices=teacher_api_indices, + ) + + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss else: - # Standard loss computation + # Standard loss computation (local teacher model) if self.args.sft_alpha > 0: model_inputs['labels'] = inputs['labels'] # compute student output @@ -239,35 +289,47 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N outputs_teacher = self.teacher_model(**model_inputs) shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) - mask = shifted_labels != -100 - shifted_student_logits = outputs_student.logits[mask][None] - shifted_teacher_logits = outputs_teacher.logits[mask][None] - - # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct. 
- stu_dim = shifted_student_logits.shape[-1] - tea_dim = shifted_teacher_logits.shape[-1] - if stu_dim < tea_dim: - shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0) - shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:] - elif stu_dim > tea_dim: - shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0) - shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:] - - # compute loss - loss = self.generalized_jsd_loss( - student_logits=shifted_student_logits, - teacher_logits=shifted_teacher_logits, - beta=self.beta, - ) + + if self.gkd_logits_topk is not None: + # Top-k mode with local teacher + loss = self.generalized_jsd_loss( + student_logits=outputs_student.logits, + teacher_logits=outputs_teacher.logits, + labels=shifted_labels, + beta=self.beta, + temperature=self.temperature, + topk=self.gkd_logits_topk, + ) + else: + # Full vocabulary mode + mask = shifted_labels != -100 + shifted_student_logits = outputs_student.logits[mask][None] + shifted_teacher_logits = outputs_teacher.logits[mask][None] + + # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct. 
+ stu_dim = shifted_student_logits.shape[-1] + tea_dim = shifted_teacher_logits.shape[-1] + if stu_dim < tea_dim: + shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0) + shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:] + elif stu_dim > tea_dim: + shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0) + shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:] + + # compute loss + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + beta=self.beta, + temperature=self.temperature, + ) + # Add SFT loss if enabled (skip for student-generated responses) if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: loss = loss + self.args.sft_alpha * outputs_student.loss # Return loss if return_outputs: - if self.use_liger_gkd_loss: - # outputs has been released in liger loss computation to reduce peak memory - outputs_student = None return (loss, outputs_student) else: return loss @@ -388,13 +450,37 @@ def training_step(self, # Mark data source for downstream processing (e.g., conditional SFT loss) encoded_inputs['_data_source'] = data_source + # Fetch teacher logprobs from API if using external teacher service + if self.use_teacher_api: + teacher_logprobs, teacher_indices = self._fetch_teacher_logprobs_from_api(encoded_inputs) + encoded_inputs['_teacher_api_logprobs'] = teacher_logprobs + encoded_inputs['_teacher_api_indices'] = teacher_indices + with self.template.forward_context(self.model, encoded_inputs): loss = HFSFTTrainer.training_step(self, model, encoded_inputs, num_items_in_batch) return loss + def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tensor]): + """Fetch teacher logprobs from external API service. 
+
+        Returns:
+            Tuple of (teacher_logprobs, teacher_indices) tensors with shape
+            [batch, seq_len - 1, topk]; position i holds the teacher's top-k
+            logprobs for P(token_{i+1} | token_0..token_i), aligning with logits[i].
+        """
+        input_ids = encoded_inputs['input_ids']
+        teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
+            self.teacher_model_server, input_ids.tolist(), topk=self.gkd_logits_topk)
+        return teacher_logprobs.to(input_ids.device), teacher_indices.to(input_ids.device)
+
     def prediction_step(self, model, inputs, *args, **kwargs):
         # Prediction uses full messages
         encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=False)
+
+        # Fetch teacher logprobs from API if using external teacher service (for eval)
+        if self.use_teacher_api:
+            teacher_logprobs, teacher_indices = self._fetch_teacher_logprobs_from_api(encoded_inputs)
+            encoded_inputs['_teacher_api_logprobs'] = teacher_logprobs
+            encoded_inputs['_teacher_api_indices'] = teacher_indices
+
         with self.template.forward_context(self.model, encoded_inputs):
             return super().prediction_step(model, encoded_inputs, *args, **kwargs)
 
@@ -459,7 +545,7 @@ def _prepare_liger_loss(self):
             raise ImportError(
                 'Liger kernel is not installed. 
Please install liger-kernel by running: pip install liger-kernel') assert self.args.sft_alpha == 0, 'SFT loss is not supported with liger loss' - + assert self.gkd_logits_topk is None, 'Top-k mode is not supported with liger loss' self.liger_jsd_loss = LigerFusedLinearJSDLoss( beta=self.beta, ignore_index=-100, @@ -471,24 +557,45 @@ def _prepare_liger_loss(self): @staticmethod def generalized_jsd_loss( student_logits, - teacher_logits, + teacher_logits=None, labels=None, beta=0.5, temperature=1.0, chunk_size=512, + topk=None, + teacher_topk_logprobs=None, + teacher_topk_indices=None, ): - # Apply temperature scaling + # Top-k mode: reduce logits to [*, k] before the standard JSD pipeline + if teacher_topk_logprobs is not None and teacher_topk_indices is not None: + # API teacher: gather student logits at teacher's top-k indices, then both + # get scaled by 1/T and re-normalized over top-k via downstream log_softmax. + # vLLM logprobs = log_softmax(logits) at T=1; treating them as logit-like + # scores and dividing by T then re-normalizing is equivalent to + # log_softmax(logits/T) over top-k (by shift-invariance of softmax). 
+ s_scaled = student_logits / temperature + student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices) + teacher_logits = teacher_topk_logprobs / temperature + del s_scaled + temperature = 1.0 + elif topk is not None and teacher_logits is not None: + # Local teacher: select top-k from teacher, gather corresponding student logits + t_scaled = teacher_logits / temperature + s_scaled = student_logits / temperature + teacher_logits, topk_idx = torch.topk(t_scaled, k=topk, dim=-1) + student_logits = torch.gather(s_scaled, dim=-1, index=topk_idx) + del t_scaled, s_scaled, topk_idx + temperature = 1.0 + student_logits = student_logits / temperature teacher_logits = teacher_logits / temperature - # Apply masking if labels provided if labels is not None: mask = labels != -100 student_logits = student_logits[mask] teacher_logits = teacher_logits[mask] num_valid = mask.sum() else: - # Flatten to [num_tokens, vocab_size] student_logits = student_logits.view(-1, student_logits.size(-1)) teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) num_valid = student_logits.size(0) @@ -499,7 +606,6 @@ def generalized_jsd_loss( num_valid_int = num_valid if isinstance(num_valid, int) else num_valid.item() total_loss = student_logits.new_zeros(()) - # Precompute beta tensor once if needed if beta != 0 and beta != 1: beta_t = torch.tensor(beta, dtype=student_logits.dtype, device=student_logits.device) log_beta = torch.log(beta_t) @@ -507,7 +613,6 @@ def generalized_jsd_loss( else: beta_t = log_beta = log_1_minus_beta = None - # Process in chunks to reduce peak memory for start_idx in range(0, num_valid_int, chunk_size): end_idx = min(start_idx + chunk_size, num_valid_int) s_chunk = student_logits[start_idx:end_idx] @@ -526,11 +631,9 @@ def generalized_jsd_loss( torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), dim=0, ) - kl_teacher = F.kl_div(mixture_log_probs, t_log_probs, reduction='none', log_target=True) kl_student = 
F.kl_div(mixture_log_probs, s_log_probs, reduction='none', log_target=True) del mixture_log_probs - jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student del kl_teacher, kl_student @@ -606,3 +709,83 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non row = [table[header][i] for header in headers] rows.append(row) swanlab.log({'completions': swanlab.echarts.Table().add(headers, rows)}) + + +def fetch_teacher_logprobs(base_url, input_ids, topk=20, timeout=300.0): + """Fetch top-k prompt logprobs from a vLLM-compatible /v1/completions endpoint. + + Uses prompt_logprobs to get logprobs for input tokens without generating. + vLLM prompt_logprobs are always raw (temperature=1) log-probabilities from the model; + the temperature parameter in the API only affects token sampling, not prompt_logprobs. + + Args: + base_url: vLLM server URL (e.g., 'http://localhost:8000'). + input_ids: List of token ID sequences. + topk: Number of top log probabilities per token. + timeout: Request timeout in seconds. + + Returns: + (logprobs, indices) tensors of shape [batch, max_seq_len - 1, topk]. + The shift is because prompt_logprobs[0] is always None (first token has no + conditional probability), so position i in the output corresponds to + P(token_{i+1} | token_0..token_i), aligning with model logits[i]. 
+ """ + import logging + import requests + from concurrent.futures import ThreadPoolExecutor + + _logger = logging.getLogger(__name__) + base_url = base_url.rstrip('/') + batch_size = len(input_ids) + max_seq_len = max(len(ids) for ids in input_ids) + url = f'{base_url}/v1/completions' + global teacher_model_server_model_name + if teacher_model_server_model_name is None: + try: + resp = requests.get(f'{base_url}/v1/models', timeout=10) + model = resp.json()['data'][0]['id'] if resp.ok else 'default' + except Exception: + model = 'default' + teacher_model_server_model_name = model + else: + model = teacher_model_server_model_name + + # prompt_logprobs[0] is always None (no conditional prob for the first token), + # prompt_logprobs[i] = P(token_i | token_0..token_{i-1}) which aligns with logits[i-1]. + # So we skip position 0 and the result has shape [batch, max_seq_len-1, topk], + # aligning with student logits which predict the next token at each position. + out_len = max_seq_len - 1 + logprobs_out = torch.full((batch_size, out_len, topk), float('-inf'), dtype=torch.float32) + indices_out = torch.zeros((batch_size, out_len, topk), dtype=torch.long) + + def _fetch_one(batch_idx): + payload = { + 'model': model, + 'prompt': input_ids[batch_idx], + 'max_tokens': 1, + 'temperature': 0, + 'prompt_logprobs': topk, + } + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + prompt_logprobs_list = resp.json()['choices'][0].get('prompt_logprobs', []) + # Skip position 0 (always None), shift left so pos 1 -> output pos 0 + for raw_pos in range(1, len(prompt_logprobs_list)): + pos_lp = prompt_logprobs_list[raw_pos] + if pos_lp is None: + continue + out_pos = raw_pos - 1 + if out_pos >= out_len: + break + sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1]['logprob'])[:topk] + for k, (tid_str, info) in enumerate(sorted_items): + indices_out[batch_idx, out_pos, k] = int(tid_str) + logprobs_out[batch_idx, out_pos, k] = info['logprob'] 
+ except Exception as e: + _logger.error(f'Failed to get teacher logprobs for sequence {batch_idx}: {e}') + + with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool: + list(pool.map(_fetch_one, range(batch_size))) + + return logprobs_out, indices_out diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py index 2484df4e77..402bf7dae6 100644 --- a/swift/rlhf_trainers/utils.py +++ b/swift/rlhf_trainers/utils.py @@ -566,10 +566,16 @@ def profiling_context(trainer, name: str): profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration} - if 'wandb' in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process: + is_main_process = False + if hasattr(trainer, 'accelerator'): + is_main_process = trainer.accelerator.is_main_process + elif hasattr(trainer, 'is_main_process'): + is_main_process = trainer.is_main_process + + if 'wandb' in trainer.args.report_to and wandb.run is not None and is_main_process: wandb.log(profiling_metrics) - if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and trainer.accelerator.is_main_process: + if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and is_main_process: swanlab.log(profiling_metrics)
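
The top-k path in `generalized_jsd_loss` relies on two properties worth checking in isolation: applying `log_softmax` to the gathered top-k logits re-normalizes each distribution over the subset (the shift-invariance argument in the patch's comments), and the resulting generalized JSD vanishes when student and teacher agree. The following standalone sketch mirrors the local-teacher top-k branch; the helper name `topk_generalized_jsd` is hypothetical (not part of the patch) and it assumes `beta` strictly inside (0, 1):

```python
import math

import torch
import torch.nn.functional as F


def topk_generalized_jsd(student_logits, teacher_logits, k=4, beta=0.5):
    """Generalized JSD over the teacher's top-k tokens (standalone sketch, beta in (0, 1))."""
    # Select the teacher's top-k logits and gather the student's at the same indices.
    t_top, idx = torch.topk(teacher_logits, k=k, dim=-1)
    s_top = torch.gather(student_logits, dim=-1, index=idx)
    # log_softmax over the k entries re-normalizes both distributions on the subset.
    t_lp = F.log_softmax(t_top, dim=-1)
    s_lp = F.log_softmax(s_top, dim=-1)
    # Mixture M = beta * P_T + (1 - beta) * P_S, computed in log space.
    m_lp = torch.logsumexp(
        torch.stack([s_lp + math.log(1 - beta), t_lp + math.log(beta)]), dim=0)
    # F.kl_div(input, target, log_target=True) computes KL(target || input).
    kl_t = F.kl_div(m_lp, t_lp, reduction='sum', log_target=True)
    kl_s = F.kl_div(m_lp, s_lp, reduction='sum', log_target=True)
    return beta * kl_t + (1 - beta) * kl_s


# Shift-invariance: log_softmax over the gathered top-k logits equals the
# full-vocabulary probabilities re-normalized on the top-k subset.
logits = torch.randn(6)
top_vals, idx = torch.topk(logits, k=3)
p = logits.softmax(-1)
assert torch.allclose(top_vals.log_softmax(-1).exp(), p[idx] / p[idx].sum())

# Identical student and teacher give (numerically) zero divergence.
s = torch.randn(2, 32)
assert topk_generalized_jsd(s, s.clone(), k=8).abs().item() < 1e-5
```

The same invariance is what justifies the API branch: vLLM's `prompt_logprobs` are already `log_softmax` outputs at T=1, and feeding them through another top-k `log_softmax` only re-normalizes them over the subset.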