From 176ef74733aaae0a247221aeed255e27e71290ce Mon Sep 17 00:00:00 2001 From: hjh Date: Tue, 20 Jan 2026 15:53:58 +0800 Subject: [PATCH 01/10] init support --- docs/source/Instruction/GKD.md | 84 ++++++- docs/source_en/Instruction/GKD.md | 85 ++++++- swift/arguments/rlhf_args.py | 46 +++- swift/megatron/arguments/megatron_args.py | 30 +++ swift/megatron/pipelines/train/rlhf.py | 7 + swift/megatron/trainers/gkd_trainer.py | 214 +++++++++++++--- swift/pipelines/train/rlhf.py | 11 +- swift/rlhf_trainers/__init__.py | 2 + swift/rlhf_trainers/gkd_trainer.py | 283 ++++++++++++++++++---- swift/rlhf_trainers/teacher_api_client.py | 263 ++++++++++++++++++++ swift/rlhf_trainers/utils.py | 41 ++++ tests/train/test_teacher_api_client.py | 231 ++++++++++++++++++ 12 files changed, 1220 insertions(+), 77 deletions(-) create mode 100644 swift/rlhf_trainers/teacher_api_client.py create mode 100644 tests/train/test_teacher_api_client.py diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md index 25e4545802..1f18c5ef6e 100644 --- a/docs/source/Instruction/GKD.md +++ b/docs/source/Instruction/GKD.md @@ -139,9 +139,11 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y)) 我们可以通过设置以下参数进行 GKD 训练: +### 基础参数 + | 参数 | 类型 | 默认值 | 取值范围 | 说明 | |------|------|--------|---------|------| -| `--teacher_model` | str | 必需 | - | 教师模型路径或模型 ID | +| `--teacher_model` | str | None | - | 教师模型路径或模型 ID
*使用 `teacher_model_server` 时可省略 |
| `--beta` | float | 0.5 | [0.0, 1.0] | 散度插值系数<br>• 0.0: Forward KL<br>• 0.5: JSD (平衡)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率<br>• 0.0: 离线学习<br>• 0.5: 混合策略(**推荐**)<br>• 1.0: 纯 On-Policy |
| `--seq_kd` | bool | False | True/False | 是否使用教师生成序列<br>• False: 非 on-policy 时使用数据集<br>• True: 非 on-policy 时使用教师生成 |
| `--sft_alpha` | float | 0 | >= 0 | 混合一定比例的 SFT loss,对非 student 生成的结果生效 |
| `--max_completion_length` | int | 512 | > 0 | 生成时的最大 token 数 |

### Top-K KL 计算

默认情况下,GKD 使用完整词表计算 KL 散度,词表较大时容易造成 OOM。此时可以使用 **Top-K** 模式来减少显存占用和计算量。

| 参数 | 类型 | 默认值 | 说明 |
|------|------|--------|------|
| `--gkd_logits_topk` | int | None | Top-K logits 数量<br>• None: 使用完整词表(默认)<br>• 正整数: 仅使用教师模型概率最高的 K 个 token 计算 KL |

**Top-K 模式原理**:

在 Top-K 模式下,选取**教师模型**输出概率最高的 K 个 token,在这个子集上计算两个模型分布的 KL 散度:

$$
D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
$$

其中 Top-K 索引来自教师模型:$\text{Top-K} = \text{argtop}_K(P_T)$;$\tilde{M} = \beta \tilde{P}_T + (1-\beta) \tilde{P}_S$ 为混合分布;$\tilde{P}_T$ 和 $\tilde{P}_S$ 是在 Top-K 子集上**重新归一化**的概率分布:

$$
\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
$$

**使用示例**:

```bash
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B-Instruct \
    --teacher_model Qwen/Qwen2.5-72B-Instruct \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    ...
```

> **注意**:Top-K 模式不能与 liger kernel 同时使用(`--use_liger_kernel`)。

### 外部教师模型 API

当设置 `gkd_logits_topk` 时,可以使用外部教师模型 API 服务来获取 logprobs,从而避免在训练进程中加载教师模型。

| 参数 | 类型 | 默认值 | 说明 |
|------|------|--------|------|
| `--teacher_model_server` | str | None | 教师模型服务地址<br>
如:`http://localhost:8000` | +| `--gkd_logits_topk` | int | **必需** | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 | + +**支持的后端**: +- `swift deploy`(vLLM backend) +- 独立 vLLM 服务(`vllm serve`) + +**步骤 1:部署教师模型服务** + +```bash +# 使用 swift deploy 部署教师模型 +CUDA_VISIBLE_DEVICES=0,1 swift deploy \ + --model Qwen/Qwen2-72B-Instruct \ + --infer_backend vllm \ + --port 8000 \ + --vllm_engine_kwargs '{"max_logprobs": 64}' + +# 或使用独立 vLLM 服务 +vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 +``` + +**步骤 2:启动 GKD 训练** + +```bash +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2-7B-Instruct \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 20 \ + --dataset your_dataset \ + --lmbda 1.0 \ + --beta 0.5 \ + ... +``` + +> **vLLM max_logprobs 限制**: +> - vLLM 默认 `max_logprobs=20`,可通过 `--vllm_engine_kwargs '{"max_logprobs": N}'` 参数调整 +> - `gkd_logits_topk` 不能超过服务端的 `max_logprobs` 设置 + ## 采样加速 在 GKD 训练中,涉及到两种在线采样的情况: diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md index 9bff8a81eb..e72c6a6698 100644 --- a/docs/source_en/Instruction/GKD.md +++ b/docs/source_en/Instruction/GKD.md @@ -139,9 +139,11 @@ Set parameter `seq_kd=True`, when on-policy is not triggered, use teacher model We can perform GKD training by setting the following parameters: +### Basic Parameters + | Parameter | Type | Default | Range | Description | |------|------|--------|---------|------| -| `--teacher_model` | str | Required | - | Teacher model path or model ID | +| `--teacher_model` | str | None | - | Teacher model path or model ID
*Can be omitted when using `teacher_model_server` |
| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Offline<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |

### Top-K KL Computation

By default, GKD computes the KL divergence over the full vocabulary. For models with large vocabularies, you can use **Top-K** mode to reduce memory usage and computation.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--gkd_logits_topk` | int | None | Number of Top-K logits<br>• None: Use full vocabulary (default)<br>• Positive integer: Only use the K tokens with the highest teacher probability for KL computation |

**Top-K Mode Principle**:

In Top-K mode, the top-K token indices are selected from the **teacher model**; logits from both models are gathered at these indices and renormalized over the top-K subset before computing the JSD:

$$
D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
$$

where the Top-K indices come from the teacher model: $\text{Top-K} = \text{argtop}_K(P_T)$, $\tilde{M} = \beta \tilde{P}_T + (1-\beta) \tilde{P}_S$ is the mixture distribution, and $\tilde{P}_T$ and $\tilde{P}_S$ are the probability distributions **renormalized** over the Top-K subset:

$$
\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
$$

**Usage Example**:

```bash
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2-7B-Instruct \
    --teacher_model Qwen/Qwen2-72B-Instruct \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    ...
```

> **Note**: Top-K mode cannot be used with the liger kernel (`--use_liger_kernel`).

### External Teacher Model API

When `gkd_logits_topk` is set, you can use an external teacher model API service to fetch logprobs, which avoids loading the teacher model in the training process.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--teacher_model_server` | str | None | Teacher model service URL<br>
e.g., `http://localhost:8000` | +| `--gkd_logits_topk` | int | **Required** | Must be set when using external API; corresponds to the top_logprobs returned by the API | + +**Supported Backends**: +- `swift deploy` (vLLM backend) +- Standalone vLLM server (`vllm serve`) + +**Step 1: Deploy Teacher Model Service** + +```bash +# Deploy teacher model with swift deploy (recommended) +CUDA_VISIBLE_DEVICES=0,1 swift deploy \ + --model Qwen/Qwen2-72B-Instruct \ + --infer_backend vllm \ + --port 8000 \ + --vllm_engine_kwargs '{"max_logprobs": 64}' + +# Or use standalone vLLM server +vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 +``` + +**Step 2: Start GKD Training** + +```bash +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2-7B-Instruct \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 20 \ + --dataset your_dataset \ + --lmbda 1.0 \ + --beta 0.5 \ + ... +``` + +> **vLLM max_logprobs Limitation**: +> - vLLM default `max_logprobs=20`, adjustable via `--vllm_engine_kwargs '{"max_logprobs": N}'` parameter +> - `gkd_logits_topk` cannot exceed the server's `max_logprobs` setting + ## Sampling Acceleration In GKD training, there are two types of online sampling scenarios: diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index 8bbe67469b..63f39dbcdf 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -39,7 +39,8 @@ class TeacherModelArguments: Args: teacher_model (Optional[str]): The model ID or a local path to the teacher model. This is required when - `rlhf_type` is 'gkd'. Analogous to the main `model` argument. Defaults to None. + `rlhf_type` is 'gkd' and `teacher_model_server` is not set. Analogous to the main `model` argument. + Defaults to None. teacher_adapters (List[str]): A list of paths to LoRA weights. These weights, often produced by SFT, are loaded to form the teacher model. Defaults to an empty list (`[]`). teacher_model_type (Optional[str]): The model type of the teacher model. If not specified, it's often inferred. @@ -50,6 +51,15 @@ class TeacherModelArguments: one of the following values: 'zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'. If not provided, it defaults to using the same DeepSpeed configuration as the main training model. Analogous to the main `deepspeed` argument. + teacher_model_server (Optional[str]): The URL of the teacher model server (e.g., 'http://localhost:8000'). + When set, the teacher logprobs will be fetched from the external API service (e.g., swift deploy, vLLM) + instead of loading a local teacher model. This enables using larger teacher models or services hosted + remotely. When this is set, `teacher_model` is not required. Defaults to None. + gkd_logits_topk (Optional[int]): The number of top-k logits to use for KL divergence computation in GKD. + If None, uses full vocabulary for KL computation (more accurate but memory-intensive). + If set to a positive integer, only top-k teacher logits are used (more efficient). + When using `teacher_model_server`, this is limited by the server's `max_logprobs` setting + (vLLM default is 20, can be increased with `--max-logprobs`). Defaults to None. """ teacher_model: Optional[str] = None teacher_adapters: List[str] = field(default_factory=list) @@ -63,6 +73,21 @@ class TeacherModelArguments: 'DeepSpeed configuration for teacher model. 
' 'Can be a path to a json file or one of: zero0, zero1, zero2, zero3, zero2_offload, zero3_offload' }) + teacher_model_server: Optional[str] = field( + default=None, + metadata={ + 'help': + 'URL of the teacher model server (e.g., http://localhost:8000). ' + 'When set, teacher logprobs are fetched via API instead of loading a local model.' + }) + gkd_logits_topk: Optional[int] = field( + default=None, + metadata={ + 'help': + 'Number of top-k logits for KL computation in GKD. ' + 'None = full vocabulary, positive integer = top-k only. ' + 'When using teacher_model_server, limited by server max_logprobs (vLLM default: 20).' + }) @dataclass @@ -554,3 +579,22 @@ def _check_gkd(self): if self.async_generate: raise NotImplementedError('Currently, async_generate is not supported for GKD.') + + # Validate teacher model configuration + if self.teacher_model is None and self.teacher_model_server is None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set.') + + if self.teacher_model is not None and self.teacher_model_server is not None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set, not both.') + + # When using teacher_model_server, gkd_logits_topk is required (API only returns top-k logprobs) + if self.teacher_model_server is not None: + if self.gkd_logits_topk is None: + raise ValueError('gkd_logits_topk is required when using teacher_model_server') + + # Validate gkd_logits_topk + if self.gkd_logits_topk is not None and self.gkd_logits_topk <= 0: + raise ValueError(f'gkd_logits_topk must be a positive integer, got {self.gkd_logits_topk}') + + if self.gkd_logits_topk is not None and self.use_liger_kernel: + raise ValueError('gkd_logits_topk is not supported when using liger kernel') diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 576e37fed6..81be8e65bd 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -45,6 +45,21 @@ class RLHFMegatronArgumentsMixin: teacher_model: Optional[str] = field(default=None) teacher_model_type: Optional[str] = field(default=None) teacher_model_revision: Optional[str] = field(default=None) + teacher_model_server: Optional[str] = field( + default=None, + metadata={ + 'help': + 'URL of the teacher model server (e.g., http://localhost:8000). ' + 'When set, teacher logprobs are fetched via API instead of loading a local model.' + }) + gkd_logits_topk: Optional[int] = field( + default=None, + metadata={ + 'help': + 'Number of top-k logits for KL computation in GKD. ' + 'None = full vocabulary, positive integer = top-k only. ' + 'When using teacher_model_server, limited by server max_logprobs (vLLM default: 20).' 
+ }) lmbda: float = 0.5 # On-policy probability: with prob lmbda, use student-generated responses seq_kd: bool = False # Sequential KD: use teacher-generated responses when not on-policy offload_teacher_model: bool = False # Offload teacher model to CPU to save GPU memory @@ -194,6 +209,21 @@ def __post_init__(self): if self.vllm_limit_mm_per_prompt is not None: self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs) + if self.rlhf_type == 'gkd': + if self.teacher_model is None and self.teacher_model_server is None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set.') + + if self.teacher_model is not None and self.teacher_model_server is not None: + raise ValueError('GKD requires either `teacher_model` or `teacher_model_server` to be set, not both.') + + # When using teacher_model_server, gkd_logits_topk is required (API only returns top-k logprobs) + if self.teacher_model_server is not None: + if self.gkd_logits_topk is None: + raise ValueError('gkd_logits_topk is required when using teacher_model_server') + + # Validate gkd_logits_topk + if self.gkd_logits_topk is not None and self.gkd_logits_topk <= 0: + raise ValueError(f'gkd_logits_topk must be a positive integer, got {self.gkd_logits_topk}') def _init_grpo(self): diff --git a/swift/megatron/pipelines/train/rlhf.py b/swift/megatron/pipelines/train/rlhf.py index 5ac25ef472..1d1b9faad4 100644 --- a/swift/megatron/pipelines/train/rlhf.py +++ b/swift/megatron/pipelines/train/rlhf.py @@ -31,6 +31,8 @@ def prepare_trainer(self): kwargs = {} if args.rlhf_type in ('grpo', 'gkd'): kwargs['vllm_client'] = self._prepare_vllm_client() + if args.rlhf_type == 'gkd': + kwargs['teacher_api_client'] = self._prepare_teacher_api_client() return trainer_cls(args, self.template, **kwargs) def _prepare_template(self) -> None: @@ -73,6 +75,11 @@ def _prepare_vllm_client(self): logger.info('Connected to vLLM server') return vllm_client + def _prepare_teacher_api_client(self): + """Prepare teacher API client for external teacher model service.""" + from swift.rlhf_trainers.utils import create_teacher_api_client + return create_teacher_api_client(self.args, check_health=True, timeout=60, use_last_rank=True) + def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index a5a225d8a9..cf040a042e 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -38,6 +38,7 @@ class MegatronGKDTrainer(MegatronRolloutMixin, MegatronRLHFTrainer): def __init__(self, args: MegatronArguments, template, **kwargs): self.vllm_client = kwargs.pop('vllm_client', None) + self.teacher_api_client = kwargs.pop('teacher_api_client', None) super().__init__(args, template) # GKD-specific parameters @@ -47,7 +48,18 @@ def __init__(self, args: MegatronArguments, template, **kwargs): self.seq_kd = args.seq_kd # Sequential KD: use teacher-generated responses self.offload_teacher_model = args.offload_teacher_model # Offload teacher to CPU self.sft_alpha = getattr(args, 'sft_alpha', 0.0) # Weight for SFT loss - assert args.teacher_model is not None, 'Teacher model path is required for GKD training' + + # GKD top-k logits configuration + self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) + self.use_teacher_api = self.teacher_api_client is not None + + # 
Validate teacher configuration + if not self.use_teacher_api: + assert args.teacher_model is not None, \ + 'Teacher model path is required for GKD training (or set teacher_model_server for API mode)' + else: + logger.info(f'Using teacher model API for logprobs, top_logprobs={self.gkd_logits_topk}') + self.use_vllm = getattr(args, 'use_vllm', False) # Get device for data processing @@ -86,12 +98,16 @@ def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **k Teacher model uses the same parallel parameters (PP/TP/CP/EP) as student model, """ - # Get teacher model path from Swift args - teacher_model_path = self.args.teacher_model - logger.info(f'Loading teacher model from: {teacher_model_path}') - - # Load teacher model with same parallel config as student - self._load_teacher_model(teacher_model_path, model_type) + # Skip teacher model loading if using API + if not self.use_teacher_api: + # Get teacher model path from Swift args + teacher_model_path = self.args.teacher_model + logger.info(f'Loading teacher model from: {teacher_model_path}') + + # Load teacher model with same parallel config as student + self._load_teacher_model(teacher_model_path, model_type) + else: + logger.info('Skipping local teacher model loading - using external API for teacher logprobs') return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) @@ -430,20 +446,76 @@ def _get_num_microbatches(self) -> int: def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: teacher_model = self.teacher_models[vp_stage or 0] + if self.use_teacher_api: + # API mode: fetch teacher logprobs from external service + self._compute_teacher_logits_from_api(encoded_batches) + else: + # Local teacher model mode + for encoded_batch in encoded_batches: + # Deep copy to avoid modifying original batch + teacher_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in encoded_batch.items()} + teacher_batch.pop('data_source', None) + teacher_data = self._prepare_batch(teacher_batch) + teacher_data.pop('loss_scale', None) + # Remove labels so returns logits instead of loss + teacher_data.pop('labels', None) + # Teacher forward with args override for correct hidden_size + with self.load_teacher_model_context(), self._teacher_args_context(), torch.no_grad(): + teacher_logits = forward_step_helper(teacher_model, teacher_data) + if teacher_logits is not None: + teacher_logits = teacher_logits.detach() + encoded_batch['teacher_logits'] = teacher_logits + + def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: + """Fetch teacher logprobs from external API service. 
+ + Args: + encoded_batches: List of encoded batch dictionaries + vp_stage: Virtual pipeline stage (unused in API mode) + """ + import asyncio + + topk = self.gkd_logits_topk + for encoded_batch in encoded_batches: - # Deep copy to avoid modifying original batch - teacher_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in encoded_batch.items()} - teacher_batch.pop('data_source', None) - teacher_data = self._prepare_batch(teacher_batch) - teacher_data.pop('loss_scale', None) - # Remove labels so returns logits instead of loss - teacher_data.pop('labels', None) - # Teacher forward with args override for correct hidden_size - with self.load_teacher_model_context(), self._teacher_args_context(), torch.no_grad(): - teacher_logits = forward_step_helper(teacher_model, teacher_data) - if teacher_logits is not None: - teacher_logits = teacher_logits.detach() - encoded_batch['teacher_logits'] = teacher_logits + input_ids = encoded_batch['input_ids'] + attention_mask = encoded_batch.get('attention_mask', None) + batch_size, seq_len = input_ids.shape + + # Prepare requests for API + async def fetch_batch(): + results = await self.teacher_api_client.get_logprobs_batch( + input_ids=input_ids.tolist(), + attention_mask=attention_mask.tolist() if attention_mask is not None else None, + top_logprobs=topk, + ) + return results + + # Run async function + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, fetch_batch()) + api_results = future.result() + else: + api_results = loop.run_until_complete(fetch_batch()) + + # Parse API results into tensors + teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.float32) + teacher_indices = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.long) + + for batch_idx, result in enumerate(api_results): + for pos_idx, pos_logprobs in enumerate(result.get('logprobs', [])): + if pos_idx >= seq_len: + break + for k_idx, (token_id, logprob) in enumerate(pos_logprobs[:topk]): + teacher_logprobs[batch_idx, pos_idx, k_idx] = logprob + teacher_indices[batch_idx, pos_idx, k_idx] = token_id + + encoded_batch['teacher_api_logprobs'] = teacher_logprobs + encoded_batch['teacher_api_indices'] = teacher_indices + encoded_batch['teacher_logits'] = None # Not used in API mode def _replace_data_iterator(self, data_iterator, model): num_microbatches = self._get_num_microbatches() @@ -531,7 +603,30 @@ def generalized_jsd_loss( labels: torch.Tensor, beta: float = 0.5, chunk_size: int = 512, + topk: int = None, + teacher_topk_logprobs: torch.Tensor = None, + teacher_topk_indices: torch.Tensor = None, ) -> torch.Tensor: + """Compute generalized JSD loss with optional top-k support. + + This method supports three modes: + 1. Full vocabulary mode (default): Uses complete logits with vocab-parallel + 2. Top-k mode with local teacher: Extracts top-k from teacher_logits + 3. Top-k mode with API logprobs: Uses pre-computed teacher_topk_logprobs and indices + + Args: + student_logits: Student model logits [batch, seq_len, vocab_size] + teacher_logits: Teacher model logits, can be None for API mode + labels: Token labels for masking [batch, seq_len] + beta: JSD interpolation coefficient + chunk_size: Chunk size for memory-efficient processing (full vocab mode only) + topk: Number of top-k logits to use (teacher's top-k). None for full vocabulary mode. 
+ teacher_topk_logprobs: Pre-computed teacher log probs [batch, seq_len, topk] (API mode) + teacher_topk_indices: Pre-computed teacher token indices [batch, seq_len, topk] (API mode) + + Returns: + Scalar loss value + """ args = get_args() mask = labels != -100 local_num_valid = mask.sum() @@ -545,6 +640,57 @@ def generalized_jsd_loss( if num_valid == 0: return (student_logits.sum() * 0).reshape(()) + # Determine mode + use_api_mode = teacher_topk_logprobs is not None and teacher_topk_indices is not None + use_topk = topk is not None or use_api_mode + + # ============== Top-K Mode ============== + if use_topk: + # Apply temperature scaling to student logits + student_logits_scaled = student_logits / self.temperature + + if use_api_mode: + # API mode: teacher logprobs already computed + teacher_topk_probs = torch.exp(teacher_topk_logprobs) + teacher_topk_log_probs = teacher_topk_logprobs + topk_indices = teacher_topk_indices + else: + # Local mode: extract top-k from teacher logits + teacher_logits_scaled = teacher_logits / self.temperature + teacher_topk_logits, topk_indices = torch.topk(teacher_logits_scaled, k=topk, dim=-1) + teacher_topk_probs = F.softmax(teacher_topk_logits, dim=-1) + teacher_topk_log_probs = F.log_softmax(teacher_topk_logits, dim=-1) + + # Gather student logits at teacher's top-k indices and renormalize + student_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=topk_indices) + student_topk_log_probs = F.log_softmax(student_topk_logits, dim=-1) + + # Compute JSD on top-k distribution + if beta == 0: + jsd = (teacher_topk_probs * (teacher_topk_log_probs - student_topk_log_probs)).sum(dim=-1) + elif beta == 1: + student_topk_probs = F.softmax(student_topk_logits, dim=-1) + jsd = (student_topk_probs * (student_topk_log_probs - teacher_topk_log_probs)).sum(dim=-1) + else: + student_topk_probs = F.softmax(student_topk_logits, dim=-1) + mixture_probs = beta * teacher_topk_probs + (1 - beta) * student_topk_probs + mixture_log_probs = torch.log(mixture_probs + 1e-10) + kl_teacher = (teacher_topk_probs * (teacher_topk_log_probs - mixture_log_probs)).sum(dim=-1) + kl_student = (student_topk_probs * (student_topk_log_probs - mixture_log_probs)).sum(dim=-1) + jsd = beta * kl_teacher + (1 - beta) * kl_student + + # Apply mask and compute sum + jsd_masked = jsd * mask.float() + total_loss = jsd_masked.sum() + + # All-reduce total_loss across CP group for correct sum + if args.context_parallel_size > 1: + torch.distributed.all_reduce( + total_loss, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) + + return total_loss / num_valid + + # ============== Full Vocabulary Mode (with vocab parallel) ============== # Align vocab size between student and teacher student_logits, teacher_logits = self._align_vocab_size(student_logits, teacher_logits) @@ -576,23 +722,17 @@ def generalized_jsd_loss( del s_chunk, t_chunk if beta == 0: - # JSD = KL(teacher || student) jsd_chunk = vocab_parallel_kl_div(s_log_probs, t_log_probs) elif beta == 1: - # JSD = KL(student || teacher) jsd_chunk = vocab_parallel_kl_div(t_log_probs, s_log_probs) else: - # Compute mixture log probabilities: m = beta * teacher + (1-beta) * student - # log(m) = logsumexp(log(student) + log(1-beta), log(teacher) + log(beta)) mixture_log_probs = torch.logsumexp( torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), dim=0, ) - kl_teacher = vocab_parallel_kl_div(mixture_log_probs, t_log_probs) kl_student = vocab_parallel_kl_div(mixture_log_probs, s_log_probs) del 
mixture_log_probs - jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student del kl_teacher, kl_student @@ -613,23 +753,31 @@ def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, - teacher_logits: torch.Tensor, + teacher_logits: torch.Tensor = None, + teacher_api_logprobs: torch.Tensor = None, + teacher_api_indices: torch.Tensor = None, data_source: DataSource = DataSource.DATASET): """Compute GKD loss (JSD + optional SFT loss). Args: output_tensor: Student model logits [batch, seq_len, vocab_size] labels: Token labels for masking [batch, seq_len] - teacher_logits: Teacher model logits [batch, seq_len, vocab_size] + teacher_logits: Teacher model logits [batch, seq_len, vocab_size] (for local teacher) + teacher_api_logprobs: Teacher log probabilities [batch, seq_len, topk] (for API mode) + teacher_api_indices: Teacher token indices [batch, seq_len, topk] (for API mode) data_source: Data source (STUDENT/TEACHER/DATASET) for conditional SFT loss """ student_logits = output_tensor + # Compute JSD loss using unified generalized_jsd_loss jsd_loss = self.generalized_jsd_loss( student_logits=student_logits, teacher_logits=teacher_logits, labels=labels, beta=self.beta, + topk=self.gkd_logits_topk, + teacher_topk_logprobs=teacher_api_logprobs, + teacher_topk_indices=teacher_api_indices, ) loss = jsd_loss @@ -680,6 +828,8 @@ def forward_step(self, data_iterator, model): data = next(data_iterator) data_source = data.pop('data_source', DataSource.DATASET) teacher_logits = data.pop('teacher_logits', None) + teacher_api_logprobs = data.pop('teacher_api_logprobs', None) + teacher_api_indices = data.pop('teacher_api_indices', None) data = self._prepare_batch(data, vp_stage) timers('batch-generator').stop() @@ -692,7 +842,13 @@ def forward_step(self, data_iterator, model): student_output = model(**data) return student_output, partial( - self.loss_func, labels=labels, teacher_logits=teacher_logits, data_source=data_source) + self.loss_func, + labels=labels, + teacher_logits=teacher_logits, + teacher_api_logprobs=teacher_api_logprobs, + teacher_api_indices=teacher_api_indices, + data_source=data_source, + ) def patched_validate_args(self, args, *_args, **kwargs): """Override patched_validate_args to adjust EP parameters for Dense student. 
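
To see the top-k branch of `generalized_jsd_loss` in isolation, here is a minimal, self-contained sketch of the renormalized top-k JSD. It is illustrative only (the `topk_generalized_jsd` name and the toy tensor shapes are not part of this patch); the real implementations above additionally apply temperature scaling, label masking, and (in the Megatron path) a context-parallel all-reduce.

```python
# Minimal sketch of the renormalized top-k generalized JSD (illustrative only).
import torch
import torch.nn.functional as F


def topk_generalized_jsd(student_logits, teacher_logits, k=4, beta=0.5):
    # Take the teacher's top-k token ids at every position.
    teacher_topk_logits, topk_indices = torch.topk(teacher_logits, k=k, dim=-1)
    # Gather the student's logits at those ids, then renormalize both
    # distributions over the shared k-token subset via log_softmax.
    student_topk_logits = torch.gather(student_logits, -1, topk_indices)
    t_log_p = F.log_softmax(teacher_topk_logits, dim=-1)
    s_log_p = F.log_softmax(student_topk_logits, dim=-1)
    t_p, s_p = t_log_p.exp(), s_log_p.exp()
    # JSD(beta) against the mixture M = beta * P_T + (1 - beta) * P_S.
    m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10)
    kl_t = (t_p * (t_log_p - m_log_p)).sum(-1)
    kl_s = (s_p * (s_log_p - m_log_p)).sum(-1)
    return (beta * kl_t + (1 - beta) * kl_s).mean()


# Toy usage: [batch=2, seq_len=5, vocab=32] random logits.
torch.manual_seed(0)
print(topk_generalized_jsd(torch.randn(2, 5, 32), torch.randn(2, 5, 32)))
```
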
diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index 25f0ba359d..543f132591 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -227,8 +227,15 @@ def _get_trainer_kwargs(self): trainer_kwargs['reward_funcs'] = self.args.reward_funcs if self.args.chord_sft_dataset: trainer_kwargs['chord_sft_dataset'], _ = self._prepare_chord_sft_dataset() - if self.args.rlhf_type == 'gkd' and self.args.teacher_deepspeed: - trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + if self.args.rlhf_type == 'gkd': + if self.args.teacher_deepspeed: + trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + # Initialize teacher API client if using external teacher service + if self.args.teacher_model_server: + from swift.rlhf_trainers.utils import create_teacher_api_client + trainer_kwargs['teacher_api_client'] = create_teacher_api_client( + self.args, check_health=False, timeout=60, use_last_rank=False + ) return trainer_kwargs diff --git a/swift/rlhf_trainers/__init__.py b/swift/rlhf_trainers/__init__.py index 5244f1d7e2..fa3e8f13b5 100644 --- a/swift/rlhf_trainers/__init__.py +++ b/swift/rlhf_trainers/__init__.py @@ -16,6 +16,7 @@ from .args_mixin import VllmArguments, GRPOArgumentsMixin from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection from .vllm_client import VLLMClient + from .teacher_api_client import TeacherAPIClient from .arguments import DPOConfig, CPOConfig, KTOConfig, ORPOConfig, PPOConfig, RewardConfig, GRPOConfig, GKDConfig else: _import_structure = { @@ -31,6 +32,7 @@ 'args_mixin': ['VllmArguments', 'GRPOArgumentsMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], 'vllm_client': ['VLLMClient'], + 'teacher_api_client': ['TeacherAPIClient'], 'arguments': ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig', 'GKDConfig'] } diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 6935615b78..83cc95f9e8 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -51,9 +51,10 @@ class DataSource(str, Enum): class GKDTrainer(RolloutTrainerMixin, SwiftMixin, HFGKDTrainer): def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): - teacher_model = kwargs.pop('teacher_model') + teacher_model = kwargs.pop('teacher_model', None) teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) + self.teacher_api_client = kwargs.pop('teacher_api_client', None) kwargs['data_collator'] = identity_data_collator super().__init__(model, None, *_args, **kwargs) args = kwargs['args'] @@ -64,32 +65,47 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self._total_train_tokens = 0 + # GKD top-k logits configuration + self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) + self.use_teacher_api = self.teacher_api_client is not None + # Initialize logging components self._prepare_logging() - # Initialize liger loss - self._prepare_liger_loss() + # Initialize liger loss (only when not using top-k mode) + if self.gkd_logits_topk is None: + self._prepare_liger_loss() + else: + self.use_liger_gkd_loss = False + logger.info(f'Using top-k logits (k={self.gkd_logits_topk}) for KL computation, liger loss disabled.') 
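        # Teacher-side ZeRO-3 bookkeeping: if the teacher runs under DeepSpeed
        # ZeRO-3, its sharded parameters may need to be gathered for generation;
        # this flag is cleared below when the teacher config is not ZeRO-3.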
self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation self.is_teacher_ds3 = None - # Initialize teacher model - if self.is_deepspeed_enabled: - if teacher_deepspeed_config is not None: - self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3 - if not self.is_teacher_ds3: - self.teacher_ds3_gather_for_generation = False - self.teacher_model = prepare_deepspeed( - teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args) + self.teacher_model = None + + # Initialize teacher model (skip if using API) + if not self.use_teacher_api: + if teacher_model is None: + raise ValueError('teacher_model is required when not using teacher_model_server') + if self.is_deepspeed_enabled: + if teacher_deepspeed_config is not None: + self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3 + if not self.is_teacher_ds3: + self.teacher_ds3_gather_for_generation = False + self.teacher_model = prepare_deepspeed( + teacher_model, self.accelerator, deepspeed_config=teacher_deepspeed_config, training_args=args) + else: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + elif self.is_fsdp_enabled: + from .utils import prepare_fsdp + self.teacher_model = prepare_fsdp(teacher_model, self.accelerator) else: - self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) - elif self.is_fsdp_enabled: - from .utils import prepare_fsdp - self.teacher_model = prepare_fsdp(teacher_model, self.accelerator) + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + self.teacher_model.eval() + if self.args.offload_teacher_model: + self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) else: - self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) - self.teacher_model.eval() - if self.args.offload_teacher_model: - self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) + logger.info(f'Using teacher model API for logprobs, top_logprobs={self.gkd_logits_topk}') # Initialize rollout infrastructure for vLLM support self.prepare_rollout() @@ -151,6 +167,10 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Get data source: DataSource.STUDENT, DataSource.TEACHER, or DataSource.DATASET data_source = inputs.pop('_data_source', DataSource.DATASET) + # Get teacher logprobs from API if available (set in training_step) + teacher_api_logprobs = inputs.pop('_teacher_api_logprobs', None) + teacher_api_indices = inputs.pop('_teacher_api_indices', None) + model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}} # If generate is used, then use_logits_to_keep must be set to False. 
use_logits_to_keep = self.get_use_logits_to_keep(True) @@ -220,8 +240,32 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss /= student_hidden.shape[1] # Release hidden states after loss computation del student_hidden, teacher_hidden, true_labels + outputs_student = None + elif self.use_teacher_api: + assert teacher_api_logprobs is not None + # API mode: use teacher logprobs from external service + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + outputs_student = model(**model_inputs) + + shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + + # Compute top-k JSD loss with API logprobs + loss = self.generalized_jsd_loss( + student_logits=outputs_student.logits, + teacher_logits=None, # Not used in API mode + labels=shifted_labels, + beta=self.beta, + temperature=self.temperature, + teacher_topk_logprobs=teacher_api_logprobs, + teacher_topk_indices=teacher_api_indices, + ) + + # Add SFT loss if enabled (skip for student-generated responses) + if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: + loss = loss + self.args.sft_alpha * outputs_student.loss else: - # Standard loss computation + # Standard loss computation (local teacher model) if self.args.sft_alpha > 0: model_inputs['labels'] = inputs['labels'] # compute student output @@ -234,35 +278,47 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N outputs_teacher = self.teacher_model(**model_inputs) shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) - mask = shifted_labels != -100 - shifted_student_logits = outputs_student.logits[mask][None] - shifted_teacher_logits = outputs_teacher.logits[mask][None] - - # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct. - stu_dim = shifted_student_logits.shape[-1] - tea_dim = shifted_teacher_logits.shape[-1] - if stu_dim < tea_dim: - shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0) - shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:] - elif stu_dim > tea_dim: - shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0) - shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:] - - # compute loss - loss = self.generalized_jsd_loss( - student_logits=shifted_student_logits, - teacher_logits=shifted_teacher_logits, - beta=self.beta, - ) + + if self.gkd_logits_topk is not None: + # Top-k mode with local teacher + loss = self.generalized_jsd_loss( + student_logits=outputs_student.logits, + teacher_logits=outputs_teacher.logits, + labels=shifted_labels, + beta=self.beta, + temperature=self.temperature, + topk=self.gkd_logits_topk, + ) + else: + # Full vocabulary mode + mask = shifted_labels != -100 + shifted_student_logits = outputs_student.logits[mask][None] + shifted_teacher_logits = outputs_teacher.logits[mask][None] + + # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct. 
+ stu_dim = shifted_student_logits.shape[-1] + tea_dim = shifted_teacher_logits.shape[-1] + if stu_dim < tea_dim: + shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0) + shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:] + elif stu_dim > tea_dim: + shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0) + shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:] + + # compute loss + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + beta=self.beta, + temperature=self.temperature, + ) + # Add SFT loss if enabled (skip for student-generated responses) if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: loss = loss + self.args.sft_alpha * outputs_student.loss # Return loss if return_outputs: - if self.use_liger_gkd_loss: - # outputs has been released in liger loss computation to reduce peak memory - outputs_student = None return (loss, outputs_student) else: return loss @@ -383,10 +439,69 @@ def training_step(self, # Mark data source for downstream processing (e.g., conditional SFT loss) encoded_inputs['_data_source'] = data_source + # Fetch teacher logprobs from API if using external teacher service + if self.use_teacher_api: + teacher_logprobs, teacher_indices = self._fetch_teacher_logprobs_from_api(encoded_inputs) + encoded_inputs['_teacher_api_logprobs'] = teacher_logprobs + encoded_inputs['_teacher_api_indices'] = teacher_indices + with self.template.forward_context(self.model, encoded_inputs): loss = HFSFTTrainer.training_step(self, model, encoded_inputs, num_items_in_batch) return loss + def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tensor]): + """Fetch teacher logprobs from external API service. + + Args: + encoded_inputs: Dictionary containing input_ids, attention_mask, labels, etc. 
+ + Returns: + Tuple of (teacher_logprobs, teacher_indices) tensors + """ + import asyncio + + input_ids = encoded_inputs['input_ids'] + attention_mask = encoded_inputs['attention_mask'] + batch_size, seq_len = input_ids.shape + topk = self.gkd_logits_topk + + # Prepare requests for API + # We need logprobs for each position, so we send the full sequence + # and request prompt_logprobs + async def fetch_batch(): + results = await self.teacher_api_client.get_logprobs_batch( + input_ids=input_ids.tolist(), + attention_mask=attention_mask.tolist(), + top_logprobs=topk, + ) + return results + + # Run async function + loop = asyncio.get_event_loop() + if loop.is_running(): + # If already in async context, create a new thread + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, fetch_batch()) + api_results = future.result() + else: + api_results = loop.run_until_complete(fetch_batch()) + + # Parse API results into tensors + # api_results should be list of dicts with 'logprobs' and 'indices' for each sample + teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.float32) + teacher_indices = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.long) + + for batch_idx, result in enumerate(api_results): + for pos_idx, pos_logprobs in enumerate(result.get('logprobs', [])): + if pos_idx >= seq_len: + break + for k_idx, (token_id, logprob) in enumerate(pos_logprobs[:topk]): + teacher_logprobs[batch_idx, pos_idx, k_idx] = logprob + teacher_indices[batch_idx, pos_idx, k_idx] = token_id + + return teacher_logprobs, teacher_indices + def prediction_step(self, model, inputs, *args, **kwargs): # Prediction uses full messages encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=False) @@ -471,7 +586,89 @@ def generalized_jsd_loss( beta=0.5, temperature=1.0, chunk_size=512, + topk=None, + teacher_topk_logprobs=None, + teacher_topk_indices=None, ): + """Compute generalized JSD loss with optional top-k support. + + This method supports three modes: + 1. Full vocabulary mode (default): Uses complete logits from both models + 2. Top-k mode with local teacher: Extracts top-k from teacher_logits + 3. Top-k mode with API logprobs: Uses pre-computed teacher_topk_logprobs and indices + + For top-k mode, uses the teacher model's top-k tokens (following ROLL framework). + This reduces memory usage while maintaining training effectiveness. + + Args: + student_logits: Student model logits [batch, seq_len, vocab_size] or [num_tokens, vocab_size] + teacher_logits: Teacher model logits (same shape as student_logits), can be None for API mode + labels: Token labels for masking, shape [batch, seq_len] + beta: JSD interpolation coefficient (0=Forward KL, 0.5=JSD, 1=Reverse KL) + temperature: Temperature for softmax scaling + chunk_size: Chunk size for memory-efficient processing (full vocab mode only) + topk: Number of top-k logits to use (teacher's top-k). None for full vocabulary mode. 
+ teacher_topk_logprobs: Pre-computed teacher log probs [batch, seq_len, topk] (API mode) + teacher_topk_indices: Pre-computed teacher token indices [batch, seq_len, topk] (API mode) + + Returns: + Scalar loss value + """ + # Determine mode + use_api_mode = teacher_topk_logprobs is not None and teacher_topk_indices is not None + use_topk = topk is not None or use_api_mode + + # ============== Top-K Mode ============== + if use_topk: + # Apply temperature scaling to student logits + student_logits_scaled = student_logits / temperature + + if use_api_mode: + # API mode: teacher logprobs already computed (with temperature on server) + teacher_topk_log_probs = teacher_topk_logprobs + teacher_topk_probs = torch.exp(teacher_topk_logprobs) + topk_indices = teacher_topk_indices + else: + # Local mode: extract top-k from teacher logits + teacher_logits_scaled = teacher_logits / temperature + teacher_topk_logits, topk_indices = torch.topk(teacher_logits_scaled, k=topk, dim=-1) + teacher_topk_probs = F.softmax(teacher_topk_logits, dim=-1) + teacher_topk_log_probs = F.log_softmax(teacher_topk_logits, dim=-1) + + # Gather student logits at teacher's top-k indices and renormalize + student_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=topk_indices) + student_topk_log_probs = F.log_softmax(student_topk_logits, dim=-1) + + # Compute JSD on top-k distribution + if beta == 0: + # Forward KL: KL(teacher || student) + jsd = (teacher_topk_probs * (teacher_topk_log_probs - student_topk_log_probs)).sum(dim=-1) + elif beta == 1: + # Reverse KL: KL(student || teacher) + student_topk_probs = F.softmax(student_topk_logits, dim=-1) + jsd = (student_topk_probs * (student_topk_log_probs - teacher_topk_log_probs)).sum(dim=-1) + else: + # Full JSD with mixture distribution + student_topk_probs = F.softmax(student_topk_logits, dim=-1) + mixture_probs = beta * teacher_topk_probs + (1 - beta) * student_topk_probs + mixture_log_probs = torch.log(mixture_probs + 1e-10) + kl_teacher = (teacher_topk_probs * (teacher_topk_log_probs - mixture_log_probs)).sum(dim=-1) + kl_student = (student_topk_probs * (student_topk_log_probs - mixture_log_probs)).sum(dim=-1) + jsd = beta * kl_teacher + (1 - beta) * kl_student + + # Apply mask and compute mean + if labels is not None: + mask = labels != -100 + jsd = jsd * mask.float() + num_valid = mask.sum() + else: + num_valid = jsd.numel() + + if num_valid == 0: + return student_logits.new_zeros(()) + return jsd.sum() / num_valid + + # ============== Full Vocabulary Mode ============== # Apply temperature scaling student_logits = student_logits / temperature teacher_logits = teacher_logits / temperature diff --git a/swift/rlhf_trainers/teacher_api_client.py b/swift/rlhf_trainers/teacher_api_client.py new file mode 100644 index 0000000000..70ce47a31f --- /dev/null +++ b/swift/rlhf_trainers/teacher_api_client.py @@ -0,0 +1,263 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Client for fetching teacher model logprobs from swift deploy or vLLM server. + +This module provides a client for communicating with OpenAI-compatible endpoints +(e.g., swift deploy with vLLM backend, standalone vLLM server) to obtain teacher +model logprobs for knowledge distillation (GKD) training. +""" +import asyncio +import logging +from typing import Any, Dict, List, Optional, Tuple + +import aiohttp +import torch + +logger = logging.getLogger(__name__) + + +class TeacherAPIClient: + """Client for fetching teacher logprobs from swift deploy or vLLM server. 
+ + This client is designed to work with OpenAI-compatible API endpoints: + - swift deploy (with vLLM backend) + - Standalone vLLM server (vllm serve) + + The client fetches top-k log probabilities for each token position, + which are then used for knowledge distillation (GKD) training. + + Args: + base_url: The base URL of the teacher model server (e.g., 'http://localhost:8000'). + top_logprobs: Number of top log probabilities to request per token. + timeout: Request timeout in seconds. + api_key: Optional API key for authentication. + model_name: Optional model name for the API request. If None, auto-detects. + """ + + def __init__( + self, + base_url: str, + top_logprobs: int = 20, + timeout: float = 300.0, + api_key: Optional[str] = None, + model_name: Optional[str] = None, + ): + self.base_url = base_url.rstrip('/') + self.top_logprobs = top_logprobs + self.timeout = aiohttp.ClientTimeout(total=timeout) + self.api_key = api_key + self.model_name = model_name + + if top_logprobs <= 0: + raise ValueError(f'top_logprobs must be positive, got {top_logprobs}') + + def _get_headers(self) -> Dict[str, str]: + """Get HTTP headers for API requests.""" + headers = {'Content-Type': 'application/json'} + if self.api_key: + headers['Authorization'] = f'Bearer {self.api_key}' + return headers + + async def _get_model_name(self, session: aiohttp.ClientSession) -> str: + """Get model name from server if not provided.""" + if self.model_name: + return self.model_name + + try: + async with session.get( + f'{self.base_url}/v1/models', headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status == 200: + data = await resp.json() + if data.get('data') and len(data['data']) > 0: + self.model_name = data['data'][0]['id'] + return self.model_name + except Exception as e: + logger.warning(f'Failed to get model name: {e}') + + self.model_name = 'default' + return self.model_name + + async def get_logprobs_batch( + self, + input_ids: List[List[int]], + top_logprobs: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """Fetch logprobs for a batch of sequences using OpenAI-compatible API. + + Args: + input_ids: List of token ID sequences. + top_logprobs: Override the default top_logprobs if provided. 
+ + Returns: + List of dictionaries, each containing: + - 'indices': List of token indices per position [seq_len, topk] + - 'values': List of logprob values per position [seq_len, topk] + """ + topk = top_logprobs or self.top_logprobs + + async with aiohttp.ClientSession(timeout=self.timeout) as session: + model_name = await self._get_model_name(session) + url = f'{self.base_url}/v1/completions' + + results = [] + for i, ids in enumerate(input_ids): + # Use prompt tokens and request logprobs with echo + payload = { + 'model': model_name, + 'prompt': ids, + 'max_tokens': 0, + 'temperature': 0, + 'logprobs': topk, + 'echo': True, + } + + try: + async with session.post(url, json=payload, headers=self._get_headers()) as resp: + if resp.status != 200: + error_text = await resp.text() + logger.error(f'API error: {resp.status} - {error_text}') + results.append(self._empty_result(len(ids), topk)) + continue + + data = await resp.json() + parsed = self._parse_response(data, len(ids), topk) + results.append(parsed) + except Exception as e: + logger.error(f'Failed to get logprobs for sequence {i}: {e}') + results.append(self._empty_result(len(ids), topk)) + + return results + + def _parse_response(self, response: Dict[str, Any], seq_len: int, topk: int) -> Dict[str, Any]: + """Parse OpenAI-compatible completions API response to extract logprobs.""" + result = {'indices': [], 'values': []} + + try: + if 'choices' not in response or len(response['choices']) == 0: + return self._empty_result(seq_len, topk) + + choice = response['choices'][0] + logprobs_data = choice.get('logprobs', {}) + + if logprobs_data is None: + return self._empty_result(seq_len, topk) + + # vLLM returns top_logprobs as list of dicts: [{token_id: Logprob, ...}, ...] + top_logprobs_list = logprobs_data.get('top_logprobs', []) + token_logprobs = logprobs_data.get('token_logprobs', []) + tokens = logprobs_data.get('tokens', []) + + for pos_idx, pos_logprobs in enumerate(top_logprobs_list): + pos_indices = [] + pos_values = [] + + if pos_logprobs is not None: + # vLLM format: {token_id: Logprob object or float, ...} + sorted_items = sorted(pos_logprobs.items(), key=lambda x: -self._get_logprob_value(x[1]))[:topk] + + for token_id_str, logprob in sorted_items: + try: + token_id = int(token_id_str) + pos_indices.append(token_id) + pos_values.append(self._get_logprob_value(logprob)) + except (ValueError, TypeError): + continue + + # Pad if needed + while len(pos_indices) < topk: + pos_indices.append(0) + pos_values.append(float('-inf')) + + result['indices'].append(pos_indices) + result['values'].append(pos_values) + + # Pad to seq_len if needed + while len(result['indices']) < seq_len: + result['indices'].append([0] * topk) + result['values'].append([float('-inf')] * topk) + + except Exception as e: + logger.error(f'Failed to parse response: {e}') + return self._empty_result(seq_len, topk) + + return result + + @staticmethod + def _get_logprob_value(logprob) -> float: + """Extract logprob value from vLLM response (handles both float and Logprob object).""" + if isinstance(logprob, (int, float)): + return float(logprob) + elif hasattr(logprob, 'logprob'): + return float(logprob.logprob) + elif isinstance(logprob, dict) and 'logprob' in logprob: + return float(logprob['logprob']) + return float('-inf') + + def _empty_result(self, seq_len: int, topk: int) -> Dict[str, Any]: + """Return empty result for failed requests.""" + return { + 'indices': [[0] * topk for _ in range(seq_len)], + 'values': [[float('-inf')] * topk for _ in 
range(seq_len)], + } + + def check_server_health(self, timeout: float = 5.0) -> bool: + """Check if the teacher model server is healthy.""" + import requests + try: + for endpoint in ['/health', '/v1/models']: + try: + response = requests.get(f'{self.base_url}{endpoint}', timeout=timeout) + if response.status_code == 200: + return True + except requests.RequestException: + continue + return False + except Exception as e: + logger.warning(f'Health check failed: {e}') + return False + + def get_logprobs_sync( + self, + input_ids: List[List[int]], + top_logprobs: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Synchronous wrapper for get_logprobs_batch. + + Args: + input_ids: List of token ID sequences + top_logprobs: Number of top logprobs to fetch + + Returns: + Tuple of (logprobs_tensor, indices_tensor) with shapes [batch, seq_len, topk] + """ + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.get_logprobs_batch(input_ids, top_logprobs)) + results = future.result() + else: + results = loop.run_until_complete(self.get_logprobs_batch(input_ids, top_logprobs)) + except RuntimeError: + results = asyncio.run(self.get_logprobs_batch(input_ids, top_logprobs)) + + # Convert to tensors + topk = top_logprobs or self.top_logprobs + batch_size = len(input_ids) + max_seq_len = max(len(ids) for ids in input_ids) + + logprobs_tensor = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32) + indices_tensor = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long) + + for batch_idx, result in enumerate(results): + indices = result.get('indices', []) + values = result.get('values', []) + for pos_idx, (pos_indices, pos_values) in enumerate(zip(indices, values)): + if pos_idx >= max_seq_len: + break + for k_idx in range(min(len(pos_indices), topk)): + indices_tensor[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] + logprobs_tensor[batch_idx, pos_idx, k_idx] = pos_values[k_idx] + + return logprobs_tensor, indices_tensor diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py index f20aa57e56..ea5a96576b 100644 --- a/swift/rlhf_trainers/utils.py +++ b/swift/rlhf_trainers/utils.py @@ -1453,6 +1453,47 @@ def check_vllm_version_ge(min_version: str) -> bool: return version.parse(vllm_version) >= version.parse(min_version) +def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60, use_last_rank: bool = True): + """ + Create and initialize TeacherAPIClient for external teacher model service. 
+ + Args: + args: Arguments object containing teacher_model_server and gkd_logits_topk + check_health: Whether to check server health after creation (default: True) + timeout: Timeout for health check in seconds (default: 60) + use_last_rank: Whether to use last rank (Megatron style) or first rank (Swift style) for initialization (default: True) + + Returns: + TeacherAPIClient instance or None if teacher_model_server is not set + """ + # Only prepare if teacher_model_server is set + teacher_model_server = getattr(args, 'teacher_model_server', None) + if not teacher_model_server: + return None + + from swift.rlhf_trainers import TeacherAPIClient + from swift.utils import get_logger, is_last_rank, is_master + + logger = get_logger() + gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) or 20 + + # Choose rank check function based on context + rank_check_func = is_last_rank if use_last_rank else is_master + + teacher_api_client = None + if rank_check_func(): + logger.info(f'Initializing teacher API client for {teacher_model_server}') + teacher_api_client = TeacherAPIClient( + base_url=teacher_model_server, + top_logprobs=gkd_logits_topk, + ) + if check_health: + # Check server health with timeout + teacher_api_client.check_server_health(timeout=timeout) + logger.info(f'Teacher API client initialized with top_logprobs={gkd_logits_topk}') + return teacher_api_client + + # ============================================================================ # Padding-free utilities # ============================================================================ diff --git a/tests/train/test_teacher_api_client.py b/tests/train/test_teacher_api_client.py new file mode 100644 index 0000000000..a1c70760a9 --- /dev/null +++ b/tests/train/test_teacher_api_client.py @@ -0,0 +1,231 @@ +""" +Test script for TeacherAPIClient with vLLM backend. + +This script tests the TeacherAPIClient's ability to fetch logprobs from: +1. swift deploy with vLLM backend +2. 
Standalone vLLM server (vllm serve) + +Usage: + python test_teacher_api_client.py # Run all tests + python test_teacher_api_client.py --parse-only # Only test format parsing +""" +import argparse +import os +import time +import multiprocessing + +os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0') + + +def wait_for_server(base_url: str, timeout: int = 120) -> bool: + """Wait for server to be ready.""" + import requests + start_time = time.time() + while time.time() - start_time < timeout: + try: + for endpoint in ['/health', '/v1/models']: + resp = requests.get(f'{base_url}{endpoint}', timeout=5) + if resp.status_code == 200: + print(f'Server is ready at {base_url}') + return True + except Exception: + pass + time.sleep(2) + print(f'Timeout waiting for server at {base_url}') + return False + + +def test_api_client_logprobs(base_url: str): + """Test TeacherAPIClient logprobs fetching.""" + from swift.rlhf_trainers import TeacherAPIClient + from transformers import AutoTokenizer + + print(f'\n{"=" * 60}') + print(f'Testing TeacherAPIClient') + print(f'Base URL: {base_url}') + print('=' * 60) + + # Initialize client + client = TeacherAPIClient( + base_url=base_url, + top_logprobs=10, + timeout=60.0, + ) + + # Check server health + is_healthy = client.check_server_health() + print(f'Server health check: {"OK" if is_healthy else "FAILED"}') + if not is_healthy: + print('Skipping test due to server health check failure') + return False + + # Prepare test input + tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B-Instruct', trust_remote_code=True) + test_text = 'Hello, how are you today?' + input_ids = tokenizer.encode(test_text, add_special_tokens=True) + + print(f'\nTest text: "{test_text}"') + print(f'Token IDs: {input_ids}') + print(f'Number of tokens: {len(input_ids)}') + + # Test synchronous API + print('\n--- Testing synchronous get_logprobs_sync ---') + try: + logprobs_tensor, indices_tensor = client.get_logprobs_sync( + input_ids=[input_ids], top_logprobs=5) + + print(f'Logprobs tensor shape: {logprobs_tensor.shape}') + print(f'Indices tensor shape: {indices_tensor.shape}') + + # Check for valid logprobs + valid_count = (logprobs_tensor > float('-inf')).sum().item() + print(f'Valid logprob entries: {valid_count}') + + if valid_count > 0: + print('\nSample logprobs for first position:') + for k in range(min(5, indices_tensor.shape[-1])): + token_id = indices_tensor[0, 0, k].item() + logprob = logprobs_tensor[0, 0, k].item() + if token_id > 0 and logprob > float('-inf'): + token_str = tokenizer.decode([token_id]) + print(f' Top-{k + 1}: token_id={token_id} ("{token_str}"), logprob={logprob:.4f}') + print('\nSync test: PASSED') + return True + else: + print('\nSync test: FAILED (no valid logprobs)') + return False + + except Exception as e: + print(f'Sync test: FAILED with error: {e}') + import traceback + traceback.print_exc() + return False + + +def test_with_swift_deploy_vllm(port: int = 8100): + """Test with swift deploy using vLLM backend.""" + from swift import DeployArguments, deploy_main + + print('\n' + '=' * 60) + print('Starting swift deploy with vLLM backend...') + print('=' * 60) + + mp = multiprocessing.get_context('spawn') + args = DeployArguments( + model='Qwen/Qwen2-0.5B-Instruct', + infer_backend='vllm', + port=port, + verbose=False, + vllm_max_model_len=4096, + ) + + process = mp.Process(target=deploy_main, args=(args, )) + process.start() + + try: + base_url = f'http://localhost:{port}' + if wait_for_server(base_url): + result = test_api_client_logprobs(base_url) 
+ return result + return False + finally: + process.terminate() + process.join(timeout=10) + if process.is_alive(): + process.kill() + + +def test_logprobs_format_parsing(): + """Test parsing of vLLM logprobs response format.""" + print('\n' + '=' * 60) + print('Testing logprobs format parsing') + print('=' * 60) + + from swift.rlhf_trainers import TeacherAPIClient + + client = TeacherAPIClient(base_url='http://localhost:8000', top_logprobs=5) + + # Test vLLM response parsing with token_id keys + vllm_response = { + 'choices': [{ + 'logprobs': { + 'top_logprobs': [ + { + '123': -0.5, + '456': -1.2, + '789': -2.0 + }, + { + '44': -0.1, + '55': -2.5, + '66': -3.0 + }, + ] + } + }] + } + + result = client._parse_response(vllm_response, seq_len=2, topk=3) + print(f'Parsing result indices: {result["indices"]}') + print(f'Parsing result values: {result["values"]}') + assert len(result['values']) == 2, 'Expected 2 positions' + assert len(result['values'][0]) == 3, 'Expected 3 top logprobs per position' + assert result['indices'][0][0] == 123, f'Expected token ID 123, got {result["indices"][0][0]}' + print('Format parsing: PASSED') + + return True + + +def main(): + parser = argparse.ArgumentParser(description='Test TeacherAPIClient') + parser.add_argument('--parse-only', action='store_true', help='Only test format parsing (no server needed)') + args = parser.parse_args() + + results = {} + + # Test format parsing (no server needed) + print('\n' + '#' * 60) + print('# Testing format parsing') + print('#' * 60) + try: + results['format_parsing'] = test_logprobs_format_parsing() + except Exception as e: + print(f'Format parsing test failed: {e}') + import traceback + traceback.print_exc() + results['format_parsing'] = False + + if args.parse_only: + print('\n' + '=' * 60) + print('Test Summary (parse-only mode):') + print('=' * 60) + for test, passed in results.items(): + print(f' {test}: {"PASSED" if passed else "FAILED"}') + return + + # Test with swift deploy + print('\n' + '#' * 60) + print('# Testing with vLLM backend') + print('#' * 60) + try: + results['vllm'] = test_with_swift_deploy_vllm() + except Exception as e: + print(f'vLLM test failed: {e}') + import traceback + traceback.print_exc() + results['vllm'] = False + + # Print summary + print('\n' + '=' * 60) + print('Test Summary:') + print('=' * 60) + for test, passed in results.items(): + print(f' {test}: {"PASSED" if passed else "FAILED"}') + + all_passed = all(results.values()) + print(f'\nOverall: {"ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED"}') + return all_passed + + +if __name__ == '__main__': + main() From a6a5aea4f9c07a6f917ccc7e32cf3d881f4fe378 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 27 Jan 2026 23:29:43 +0800 Subject: [PATCH 02/10] init --- docs/source/Instruction/GKD.md | 2 + docs/source/Megatron-SWIFT/GKD.md | 6 +- docs/source_en/Instruction/GKD.md | 2 + docs/source_en/Megatron-SWIFT/GKD.md | 6 +- examples/megatron/rlhf/gkd/teacher_server.sh | 58 ++++++++++ examples/train/rlhf/gkd/teacher_server.sh | 57 ++++++++++ swift/megatron/trainers/gkd_trainer.py | 98 +++++++++++------ swift/pipelines/train/rlhf.py | 6 +- swift/rlhf_trainers/gkd_trainer.py | 110 ++++++++++++------- swift/rlhf_trainers/teacher_api_client.py | 59 ++++++++-- 10 files changed, 314 insertions(+), 90 deletions(-) create mode 100644 examples/megatron/rlhf/gkd/teacher_server.sh create mode 100644 examples/train/rlhf/gkd/teacher_server.sh diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md index 
1f18c5ef6e..8859928faf 100644
--- a/docs/source/Instruction/GKD.md
+++ b/docs/source/Instruction/GKD.md
@@ -250,6 +250,8 @@ swift rlhf \
 
 训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh)
 
+使用 Teacher Server 的训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh)
+
 ### 方案 2:教师模型预采样
 
 对于教师模型采样(`seq_kd=True`),推荐使用 **预采样** 方式:先用教师模型离线生成高质量数据,再进行训练。
diff --git a/docs/source/Megatron-SWIFT/GKD.md b/docs/source/Megatron-SWIFT/GKD.md
index 9cc30004d7..37b810624e 100644
--- a/docs/source/Megatron-SWIFT/GKD.md
+++ b/docs/source/Megatron-SWIFT/GKD.md
@@ -33,7 +33,9 @@ Megatron GKD 当前已支持以下功能:
 
 | 参数 | 类型 | 默认值 | 说明 |
 |------|------|--------|------|
-| `--teacher_model` | str | 必需 | 教师模型路径或模型 ID |
+| `--teacher_model` | str | - | 教师模型路径或模型 ID<br>*使用 `teacher_model_server` 时可省略 |
+| `--teacher_model_server` | str | None | 教师模型服务地址,如 `http://localhost:8000` |
+| `--gkd_logits_topk` | int | None | Top-K logits 数量,使用外部教师 API 时必须设置 |
 | `--beta` | float | 0.5 | JSD 散度插值系数:<br>• 0.0: Forward KL<br>• 0.5: 对称 JSD<br>• 1.0: Reverse KL |
 | `--lmbda` | float | 0.5 | On-Policy 学习触发概率:<br>• 0.0: 纯 Off-Policy<br>• 1.0: 纯 On-Policy |
 | `--seq_kd` | bool | False | 是否使用教师生成的响应(当前暂不支持) |
@@ -71,3 +73,5 @@ GKD 支持三种训练模式,通过 `lmbda` 和 `seq_kd` 参数控制:
 更多参数请参考[命令行文档](./Command-line-parameters.md)
 
 训练脚本请参考 [Megatron GKD 脚本](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd)
+
+使用 Teacher Server 的训练脚本请参考 [这里](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh)
diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md
index e72c6a6698..374e070c14 100644
--- a/docs/source_en/Instruction/GKD.md
+++ b/docs/source_en/Instruction/GKD.md
@@ -251,6 +251,8 @@ Use vLLM as the inference backend to accelerate student model sampling. Supports
 
 Training script reference [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh), for related parameters, please refer to [GRPO vLLM Parameters](./Command-line-parameters.md#vllm_mode).
 
+For a training script that uses the Teacher Server, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh).
+
 ### Solution 2: Teacher Model Pre-sampling
 
 
diff --git a/docs/source_en/Megatron-SWIFT/GKD.md b/docs/source_en/Megatron-SWIFT/GKD.md
index 37c08eea5a..b7eb6f6578 100644
--- a/docs/source_en/Megatron-SWIFT/GKD.md
+++ b/docs/source_en/Megatron-SWIFT/GKD.md
@@ -33,7 +33,9 @@ Megatron GKD currently supports the following features:
 
 | Parameter | Type | Default | Description |
 |-----------|------|---------|-------------|
-| `--teacher_model` | str | Required | Path or model ID of the teacher model |
+| `--teacher_model` | str | - | Path or model ID of the teacher model<br>*Can be omitted when using `teacher_model_server` |
+| `--teacher_model_server` | str | None | Teacher model service URL, e.g. `http://localhost:8000` |
+| `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using the external teacher API |
 | `--beta` | float | 0.5 | JSD divergence interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: Symmetric JSD<br>• 1.0: Reverse KL |
 | `--lmbda` | float | 0.5 | On-Policy learning probability:<br>• 0.0: Pure Off-Policy<br>• 1.0: Pure On-Policy |
 | `--seq_kd` | bool | False | Use teacher-generated responses (not yet supported) |
@@ -71,3 +73,5 @@ GKD supports three training modes, controlled by `lmbda` and `seq_kd` parameters
 For more parameters, please refer to [Command-line Parameters](./Command-line-parameters.md)
 
 For training scripts, please refer to [Megatron GKD Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd)
+
+For a training script that uses the Teacher Server, see [here](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/rlhf/gkd/teacher_server.sh)
diff --git a/examples/megatron/rlhf/gkd/teacher_server.sh b/examples/megatron/rlhf/gkd/teacher_server.sh
new file mode 100644
index 0000000000..3461293ef6
--- /dev/null
+++ b/examples/megatron/rlhf/gkd/teacher_server.sh
@@ -0,0 +1,58 @@
+# Megatron GKD Training with External Teacher Model Server
+#
+# This script demonstrates using an external vLLM server as the teacher model
+# for knowledge distillation with Megatron-SWIFT. This approach is useful when:
+# - The teacher model is too large to load alongside the student model
+# - You want to separate teacher inference from training for better resource utilization
+# - You need to use different model parallelism for the student and the teacher
+#
+# Prerequisites:
+# 1. Start the teacher model server first (see below)
+# 2. 
Ensure the server is accessible at the specified URL +# +# Teacher Server Setup (run in a separate terminal): +# CUDA_VISIBLE_DEVICES=0,1 swift deploy \ +# --model Qwen/Qwen2-72B-Instruct \ +# --infer_backend vllm \ +# --port 8000 \ +# --vllm_engine_kwargs '{"max_logprobs": 64}' +# +# Or using vLLM directly: +# vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 + +TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"} +GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-20} + +NPROC_PER_NODE=4 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift rlhf \ + --rlhf_type gkd \ + --model Qwen/Qwen2.5-7B \ + --teacher_model_server $TEACHER_SERVER_URL \ + --gkd_logits_topk $GKD_LOGITS_TOPK \ + --tuner_type full \ + --dataset 'AI-ModelScope/alpaca-gpt4-data-en' \ + --seq_kd false \ + --lmbda 0 \ + --beta 0.5 \ + --torch_dtype bfloat16 \ + --max_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 4 \ + --eval_steps 500 \ + --save_steps 500 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --max_completion_length 512 \ + --output_dir output/gkd_teacher_server \ + --warmup_ratio 0.05 \ + --save_only_model true \ + --dataloader_num_workers 4 \ + --dataset_num_proc 4 \ + --deepspeed zero2 \ + --attn_impl flash_attn diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index cf040a042e..ca907c0fdd 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -51,7 +51,9 @@ def __init__(self, args: MegatronArguments, template, **kwargs): # GKD top-k logits configuration self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) - self.use_teacher_api = self.teacher_api_client is not None + # Check use_teacher_api based on args, not client existence + # (API client is only created on last rank, but all ranks need to know the mode) + self.use_teacher_api = getattr(args, 'teacher_model_server', None) is not None # Validate teacher configuration if not self.use_teacher_api: @@ -444,12 +446,14 @@ def _get_num_microbatches(self) -> int: return get_num_microbatches() def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: - teacher_model = self.teacher_models[vp_stage or 0] - if self.use_teacher_api: # API mode: fetch teacher logprobs from external service self._compute_teacher_logits_from_api(encoded_batches) - else: + return + + # Local teacher model mode + teacher_model = self.teacher_models[vp_stage or 0] + if True: # Maintain original indentation # Local teacher model mode for encoded_batch in encoded_batches: # Deep copy to avoid modifying original batch @@ -469,49 +473,71 @@ def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optiona def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: """Fetch teacher logprobs from external API service. + Only the last rank makes API calls, then broadcasts results to other ranks. 
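+
+        Rough sketch of the flow (illustrative only):
+
+            if is_last_rank():
+                results = asyncio.run(self.teacher_api_client.get_logprobs_batch(
+                    input_ids=input_ids.tolist(), top_logprobs=self.gkd_logits_topk))
+            dist.broadcast(teacher_logprobs, src=world_size - 1)
+            dist.broadcast(teacher_indices, src=world_size - 1)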
+ Args: encoded_batches: List of encoded batch dictionaries vp_stage: Virtual pipeline stage (unused in API mode) """ import asyncio + import torch.distributed as dist + from swift.utils import is_last_rank topk = self.gkd_logits_topk + is_distributed = dist.is_initialized() + is_api_rank = is_last_rank() for encoded_batch in encoded_batches: input_ids = encoded_batch['input_ids'] - attention_mask = encoded_batch.get('attention_mask', None) batch_size, seq_len = input_ids.shape - - # Prepare requests for API - async def fetch_batch(): - results = await self.teacher_api_client.get_logprobs_batch( - input_ids=input_ids.tolist(), - attention_mask=attention_mask.tolist() if attention_mask is not None else None, - top_logprobs=topk, - ) - return results - - # Run async function - loop = asyncio.get_event_loop() - if loop.is_running(): - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, fetch_batch()) - api_results = future.result() - else: - api_results = loop.run_until_complete(fetch_batch()) - - # Parse API results into tensors - teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.float32) - teacher_indices = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.long) - - for batch_idx, result in enumerate(api_results): - for pos_idx, pos_logprobs in enumerate(result.get('logprobs', [])): - if pos_idx >= seq_len: - break - for k_idx, (token_id, logprob) in enumerate(pos_logprobs[:topk]): - teacher_logprobs[batch_idx, pos_idx, k_idx] = logprob - teacher_indices[batch_idx, pos_idx, k_idx] = token_id + device = input_ids.device + + # Initialize tensors + teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.float32) + teacher_indices = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.long) + + # Only last rank fetches from API + if is_api_rank and self.teacher_api_client is not None: + # Prepare requests for API + async def fetch_batch(): + results = await self.teacher_api_client.get_logprobs_batch( + input_ids=input_ids.tolist(), + top_logprobs=topk, + ) + return results + + # Run async function + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, fetch_batch()) + api_results = future.result() + else: + api_results = loop.run_until_complete(fetch_batch()) + except RuntimeError: + api_results = asyncio.run(fetch_batch()) + + # Parse API results into tensors + # api_results is list of dicts with 'values' (logprobs) and 'indices' for each sample + for batch_idx, result in enumerate(api_results): + indices_list = result.get('indices', []) + values_list = result.get('values', []) + for pos_idx, (pos_indices, pos_values) in enumerate(zip(indices_list, values_list)): + if pos_idx >= seq_len: + break + for k_idx in range(min(len(pos_indices), topk)): + teacher_indices[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] + teacher_logprobs[batch_idx, pos_idx, k_idx] = pos_values[k_idx] + + # Broadcast results to all ranks + if is_distributed: + # Get last rank for broadcast source + world_size = dist.get_world_size() + last_rank = world_size - 1 + dist.broadcast(teacher_logprobs, src=last_rank) + dist.broadcast(teacher_indices, src=last_rank) encoded_batch['teacher_api_logprobs'] = teacher_logprobs encoded_batch['teacher_api_indices'] = teacher_indices diff --git 
a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index 3baf48f8b5..dab3a9fe80 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -234,12 +234,14 @@ def _get_trainer_kwargs(self): if self.args.rlhf_type == 'gkd': if self.args.teacher_deepspeed: trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed + # Pass GKD-specific args to trainer + trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server + trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk # Initialize teacher API client if using external teacher service if self.args.teacher_model_server: from swift.rlhf_trainers.utils import create_teacher_api_client trainer_kwargs['teacher_api_client'] = create_teacher_api_client( - self.args, check_health=False, timeout=60, use_last_rank=False - ) + self.args, check_health=False, timeout=60, use_last_rank=False) return trainer_kwargs diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 2c9d56c9e0..fa7ac088ce 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -55,6 +55,9 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) self.teacher_api_client = kwargs.pop('teacher_api_client', None) + # Pop GKD-specific args from kwargs (passed from rlhf.py) + teacher_model_server = kwargs.pop('teacher_model_server', None) + gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) super().__init__(model, None, *_args, **kwargs) args = kwargs['args'] self.lmbda = args.lmbda @@ -64,9 +67,13 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self._total_train_tokens = 0 - # GKD top-k logits configuration - self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) - self.use_teacher_api = self.teacher_api_client is not None + # GKD top-k logits configuration (from kwargs, fallback to args for backward compatibility) + self.gkd_logits_topk = gkd_logits_topk if gkd_logits_topk is not None else getattr( + args, 'gkd_logits_topk', None) + # Check use_teacher_api based on kwargs (passed from rlhf.py) + # API client is only created on master rank, but all ranks need to know the mode + self.use_teacher_api = teacher_model_server is not None + logger.info(f'teacher_model_server={teacher_model_server}, use_teacher_api={self.use_teacher_api}') # Initialize logging components self._prepare_logging() @@ -175,7 +182,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}} # If generate is used, then use_logits_to_keep must be set to False. - use_logits_to_keep = self.get_use_logits_to_keep(True) + # Also disable logits_to_keep when using teacher API to ensure sequence length alignment + use_logits_to_keep = self.get_use_logits_to_keep(True) and not self.use_teacher_api if use_logits_to_keep and not self.use_liger_gkd_loss: self.prepare_logits_to_keep(inputs) model_inputs['logits_to_keep'] = inputs['logits_to_keep'] @@ -454,6 +462,8 @@ def training_step(self, def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tensor]): """Fetch teacher logprobs from external API service. + Only the master rank makes API calls, then broadcasts results to other ranks. 
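+
+        The pattern, in outline (illustrative only):
+
+            if dist.get_rank() == 0:
+                results = asyncio.run(self.teacher_api_client.get_logprobs_batch(
+                    input_ids=input_ids.tolist(), top_logprobs=self.gkd_logits_topk))
+            dist.broadcast(teacher_logprobs, src=0)
+            dist.broadcast(teacher_indices, src=0)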
+ Args: encoded_inputs: Dictionary containing input_ids, attention_mask, labels, etc. @@ -461,52 +471,72 @@ def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tenso Tuple of (teacher_logprobs, teacher_indices) tensors """ import asyncio + import torch.distributed as dist input_ids = encoded_inputs['input_ids'] - attention_mask = encoded_inputs['attention_mask'] batch_size, seq_len = input_ids.shape topk = self.gkd_logits_topk - - # Prepare requests for API - # We need logprobs for each position, so we send the full sequence - # and request prompt_logprobs - async def fetch_batch(): - results = await self.teacher_api_client.get_logprobs_batch( - input_ids=input_ids.tolist(), - attention_mask=attention_mask.tolist(), - top_logprobs=topk, - ) - return results - - # Run async function - loop = asyncio.get_event_loop() - if loop.is_running(): - # If already in async context, create a new thread - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, fetch_batch()) - api_results = future.result() - else: - api_results = loop.run_until_complete(fetch_batch()) - - # Parse API results into tensors - # api_results should be list of dicts with 'logprobs' and 'indices' for each sample - teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.float32) - teacher_indices = torch.zeros(batch_size, seq_len, topk, device=input_ids.device, dtype=torch.long) - - for batch_idx, result in enumerate(api_results): - for pos_idx, pos_logprobs in enumerate(result.get('logprobs', [])): - if pos_idx >= seq_len: - break - for k_idx, (token_id, logprob) in enumerate(pos_logprobs[:topk]): - teacher_logprobs[batch_idx, pos_idx, k_idx] = logprob - teacher_indices[batch_idx, pos_idx, k_idx] = token_id + device = input_ids.device + + # Initialize tensors + teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.float32) + teacher_indices = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.long) + + # Only master rank fetches from API + is_distributed = dist.is_initialized() + is_master = not is_distributed or dist.get_rank() == 0 + + if is_master and self.teacher_api_client is not None: + # Prepare requests for API + async def fetch_batch(): + results = await self.teacher_api_client.get_logprobs_batch( + input_ids=input_ids.tolist(), + top_logprobs=topk, + ) + return results + + # Run async function + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, fetch_batch()) + api_results = future.result() + else: + api_results = loop.run_until_complete(fetch_batch()) + except RuntimeError: + api_results = asyncio.run(fetch_batch()) + + # Parse API results into tensors + # api_results is list of dicts with 'values' (logprobs) and 'indices' for each sample + for batch_idx, result in enumerate(api_results): + indices_list = result.get('indices', []) + values_list = result.get('values', []) + for pos_idx, (pos_indices, pos_values) in enumerate(zip(indices_list, values_list)): + if pos_idx >= seq_len: + break + for k_idx in range(min(len(pos_indices), topk)): + teacher_indices[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] + teacher_logprobs[batch_idx, pos_idx, k_idx] = pos_values[k_idx] + + # Broadcast results to all ranks + if is_distributed: + dist.broadcast(teacher_logprobs, src=0) + dist.broadcast(teacher_indices, 
src=0) return teacher_logprobs, teacher_indices def prediction_step(self, model, inputs, *args, **kwargs): # Prediction uses full messages encoded_inputs = self._prepare_batch_inputs(inputs, encode_prompt_only=False) + + # Fetch teacher logprobs from API if using external teacher service (for eval) + if self.use_teacher_api: + teacher_logprobs, teacher_indices = self._fetch_teacher_logprobs_from_api(encoded_inputs) + encoded_inputs['_teacher_api_logprobs'] = teacher_logprobs + encoded_inputs['_teacher_api_indices'] = teacher_indices + with self.template.forward_context(self.model, encoded_inputs): return super().prediction_step(model, encoded_inputs, *args, **kwargs) diff --git a/swift/rlhf_trainers/teacher_api_client.py b/swift/rlhf_trainers/teacher_api_client.py index 70ce47a31f..a4be81c229 100644 --- a/swift/rlhf_trainers/teacher_api_client.py +++ b/swift/rlhf_trainers/teacher_api_client.py @@ -64,8 +64,8 @@ async def _get_model_name(self, session: aiohttp.ClientSession) -> str: try: async with session.get( - f'{self.base_url}/v1/models', headers=self._get_headers(), timeout=aiohttp.ClientTimeout(total=10) - ) as resp: + f'{self.base_url}/v1/models', headers=self._get_headers(), + timeout=aiohttp.ClientTimeout(total=10)) as resp: if resp.status == 200: data = await resp.json() if data.get('data') and len(data['data']) > 0: @@ -129,7 +129,14 @@ async def get_logprobs_batch( return results def _parse_response(self, response: Dict[str, Any], seq_len: int, topk: int) -> Dict[str, Any]: - """Parse OpenAI-compatible completions API response to extract logprobs.""" + """Parse vLLM completions API response to extract logprobs. + + vLLM returns logprobs in two formats: + 1. `prompt_logprobs`: List of dicts where keys are token IDs (as strings), values have 'logprob' field + 2. `top_logprobs` in logprobs: List of dicts where keys are token text + + We prefer `prompt_logprobs` because it has token IDs directly. + """ result = {'indices': [], 'values': []} try: @@ -137,30 +144,62 @@ def _parse_response(self, response: Dict[str, Any], seq_len: int, topk: int) -> return self._empty_result(seq_len, topk) choice = response['choices'][0] - logprobs_data = choice.get('logprobs', {}) + # Try prompt_logprobs first (vLLM native format with token IDs as keys) + prompt_logprobs = choice.get('prompt_logprobs') + if prompt_logprobs is not None: + for pos_idx, pos_logprobs in enumerate(prompt_logprobs): + pos_indices = [] + pos_values = [] + + if pos_logprobs is not None: + # vLLM format: {token_id_str: {logprob: float, ...}, ...} + sorted_items = sorted(pos_logprobs.items(), key=lambda x: -self._get_logprob_value(x[1]))[:topk] + + for token_id_str, logprob_data in sorted_items: + try: + token_id = int(token_id_str) + pos_indices.append(token_id) + pos_values.append(self._get_logprob_value(logprob_data)) + except (ValueError, TypeError): + continue + + # Pad if needed + while len(pos_indices) < topk: + pos_indices.append(0) + pos_values.append(float('-inf')) + + result['indices'].append(pos_indices) + result['values'].append(pos_values) + + # Pad to seq_len if needed + while len(result['indices']) < seq_len: + result['indices'].append([0] * topk) + result['values'].append([float('-inf')] * topk) + + return result + + # Fallback to logprobs.top_logprobs (OpenAI format, keys are token text) + logprobs_data = choice.get('logprobs', {}) if logprobs_data is None: return self._empty_result(seq_len, topk) - # vLLM returns top_logprobs as list of dicts: [{token_id: Logprob, ...}, ...] 
top_logprobs_list = logprobs_data.get('top_logprobs', []) - token_logprobs = logprobs_data.get('token_logprobs', []) - tokens = logprobs_data.get('tokens', []) for pos_idx, pos_logprobs in enumerate(top_logprobs_list): pos_indices = [] pos_values = [] if pos_logprobs is not None: - # vLLM format: {token_id: Logprob object or float, ...} sorted_items = sorted(pos_logprobs.items(), key=lambda x: -self._get_logprob_value(x[1]))[:topk] - for token_id_str, logprob in sorted_items: + for token_str, logprob in sorted_items: try: - token_id = int(token_id_str) + token_id = int(token_str) pos_indices.append(token_id) pos_values.append(self._get_logprob_value(logprob)) except (ValueError, TypeError): + # Token is text, not ID - skip (can't use without tokenizer) continue # Pad if needed From d67b0f84e76d2fa05e30845bfa579f9824297dbb Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 29 Jan 2026 23:09:40 +0800 Subject: [PATCH 03/10] wip --- docs/source_en/Instruction/GKD.md | 13 +- examples/megatron/rlhf/gkd/teacher_server.sh | 55 +-- examples/train/rlhf/gkd/teacher_server.sh | 44 +- swift/infer_engine/protocol.py | 9 +- swift/infer_engine/vllm_engine.py | 56 ++- swift/megatron/pipelines/train/rlhf.py | 24 +- swift/megatron/trainers/gkd_trainer.py | 247 +++++------ swift/pipelines/train/rlhf.py | 12 +- swift/rlhf_trainers/gkd_trainer.py | 250 +++-------- swift/rlhf_trainers/jsd_loss.py | 271 ++++++++++++ swift/rlhf_trainers/teacher_api_client.py | 429 ++++++++++++++++--- swift/rlhf_trainers/utils.py | 33 +- 12 files changed, 981 insertions(+), 462 deletions(-) create mode 100644 swift/rlhf_trainers/jsd_loss.py diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md index 374e070c14..d21bc5b04c 100644 --- a/docs/source_en/Instruction/GKD.md +++ b/docs/source_en/Instruction/GKD.md @@ -178,8 +178,8 @@ $$ ```bash swift rlhf \ --rlhf_type gkd \ - --model Qwen/Qwen2-7B-Instruct \ - --teacher_model Qwen/Qwen2-72B-Instruct \ + --model Qwen/Qwen2.5-7B-Instruct \ + --teacher_model Qwen/Qwen2.5-14B-Instruct \ --gkd_logits_topk 64 \ --dataset your_dataset \ ... @@ -204,14 +204,11 @@ When `gkd_logits_topk` is set, you can use an external teacher model API service ```bash # Deploy teacher model with swift deploy (recommended) -CUDA_VISIBLE_DEVICES=0,1 swift deploy \ - --model Qwen/Qwen2-72B-Instruct \ +swift deploy \ + --model Qwen/Qwen2.5-14B-Instruct \ --infer_backend vllm \ --port 8000 \ --vllm_engine_kwargs '{"max_logprobs": 64}' - -# Or use standalone vLLM server -vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 ``` **Step 2: Start GKD Training** @@ -219,7 +216,7 @@ vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 ```bash swift rlhf \ --rlhf_type gkd \ - --model Qwen/Qwen2-7B-Instruct \ + --model Qwen/Qwen2.5-7B-Instruct \ --teacher_model_server http://localhost:8000 \ --gkd_logits_topk 20 \ --dataset your_dataset \ diff --git a/examples/megatron/rlhf/gkd/teacher_server.sh b/examples/megatron/rlhf/gkd/teacher_server.sh index 3461293ef6..53a2a964d1 100644 --- a/examples/megatron/rlhf/gkd/teacher_server.sh +++ b/examples/megatron/rlhf/gkd/teacher_server.sh @@ -1,58 +1,39 @@ -# Megatron GKD Training with External Teacher Model Server -# -# This script demonstrates using an external vLLM server as the teacher model -# for knowledge distillation with Megatron-SWIFT. 
This approach is useful when: -# - The teacher model is too large to load alongside the student model -# - You want to separate teacher inference from training for better resource utilization -# - You need to use different model parallelism for student vs teacher -# -# Prerequisites: -# 1. Start the teacher model server first (see below) -# 2. Ensure the server is accessible at the specified URL -# -# Teacher Server Setup (run in a separate terminal): -# CUDA_VISIBLE_DEVICES=4,5,6,7 swift deploy \ -# --model Qwen/Qwen2-72B-Instruct \ -# --infer_backend vllm \ -# --port 8000 \ -# --vllm_engine_kwargs '{"max_logprobs": 64}' -# -# Or using vLLM directly: -# vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 - -TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"} -GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-20} - CUDA_VISIBLE_DEVICES=0,1,2,3 \ NPROC_PER_NODE=4 \ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ megatron rlhf \ --rlhf_type gkd \ --model Qwen/Qwen3-8B-Base \ - --teacher_model_server $TEACHER_SERVER_URL \ - --gkd_logits_topk $GKD_LOGITS_TOPK \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 20 \ --tuner_type lora \ - --dataset 'AI-ModelScope/alpaca-gpt4-data-en#2000' 'AI-ModelScope/alpaca-gpt4-data-zh#2000' \ - --tensor_model_parallel_size 2 \ + --dataset AI-ModelScope/alpaca-gpt4-data-en#2000 AI-ModelScope/alpaca-gpt4-data-zh#2000 \ + --tensor_model_parallel_size 1 \ --expert_model_parallel_size 1 \ --pipeline_model_parallel_size 1 \ - --context_parallel_size 2 \ + --context_parallel_size 1 \ --seq_kd false \ - --lmbda 0 \ - --beta 0.5 \ + --lmbda 1 \ + --beta 1 \ --torch_dtype bfloat16 \ --micro_batch_size 2 \ --global_batch_size 16 \ --max_epochs 1 \ - --lr 5e-6 \ - --log_interval 5 \ - --max_length 4096 \ - --max_completion_length 1024 \ + --lr 5e-5 \ + --log_interval 1 \ + --max_length 8192 \ + --max_completion_length 8192 \ --attention_backend flash \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.5 \ + --vllm_tensor_parallel_size 1 \ + --vllm_max_model_len 16384 \ + --sleep_level 1 \ --recompute_granularity selective \ --finetune \ --no_save_optim \ --no_save_rng \ - --temperature 0.9 \ + --temperature 1.0 \ --padding_free true \ --sequence_parallel true diff --git a/examples/train/rlhf/gkd/teacher_server.sh b/examples/train/rlhf/gkd/teacher_server.sh index efe4c6472d..94183a2504 100644 --- a/examples/train/rlhf/gkd/teacher_server.sh +++ b/examples/train/rlhf/gkd/teacher_server.sh @@ -1,27 +1,17 @@ # GKD Training with External Teacher Model Server # # This script demonstrates using an external vLLM server as the teacher model -# for knowledge distillation. This approach is useful when: -# - The teacher model is too large to load alongside the student model -# - You want to share a single teacher server across multiple training processes -# - You need more control over the teacher model deployment -# -# Prerequisites: -# 1. Start the teacher model server first (see below) -# 2. Ensure the server is accessible at the specified URL -# -# Teacher Server Setup (run in a separate terminal): -# CUDA_VISIBLE_DEVICES=0,1 swift deploy \ -# --model Qwen/Qwen2-72B-Instruct \ -# --infer_backend vllm \ -# --port 8000 \ -# --vllm_engine_kwargs '{"max_logprobs": 64}' -# -# Or using vLLM directly: -# vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 +# for knowledge distillation. 
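+#
+# Optional: verify the server is reachable before training (plain HTTP check;
+# adjust host/port to match your deployment):
+#   curl http://localhost:8000/v1/models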
+ +# Teacher Server Setup (run in a separate gpu): +# CUDA_VISIBLE_DEVICES=5 swift deploy \ +# --model Qwen/Qwen2.5-14B-Instruct \ +# --infer_backend vllm \ +# --port 8000 \ +# --vllm_engine_kwargs '{"max_logprobs": 64}' -TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"} -GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-20} +TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"} +GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64} NPROC_PER_NODE=4 \ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ @@ -30,12 +20,17 @@ swift rlhf \ --rlhf_type gkd \ --model Qwen/Qwen2.5-7B \ --teacher_model_server $TEACHER_SERVER_URL \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.5 \ + --vllm_tensor_parallel_size 1 \ + --vllm_max_model_len 10240 \ --gkd_logits_topk $GKD_LOGITS_TOPK \ - --tuner_type full \ + --tuner_type lora \ --dataset 'AI-ModelScope/alpaca-gpt4-data-en' \ --seq_kd false \ - --lmbda 0 \ - --beta 0.5 \ + --lmbda 1 \ + --beta 1 \ --torch_dtype bfloat16 \ --max_epochs 1 \ --per_device_train_batch_size 1 \ @@ -47,8 +42,7 @@ swift rlhf \ --save_total_limit 2 \ --logging_steps 5 \ --max_length 2048 \ - --max_completion_length 512 \ - --output_dir output/gkd_teacher_server \ + --max_completion_length 2048 \ --warmup_ratio 0.05 \ --save_only_model true \ --dataloader_num_workers 4 \ diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index 531cc14a74..f260c6467d 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -173,6 +173,7 @@ class RequestConfig: stream: bool = False logprobs: bool = False top_logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None # Set to an integer to get top-k logprobs for each prompt token n: int = 1 best_of: Optional[int] = None @@ -192,7 +193,6 @@ def __post_init__(self): @dataclass class CompletionRequestMixin: model: str - prompt: str @dataclass @@ -393,11 +393,14 @@ class ChatCompletionResponseChoice: finish_reason: Literal['stop', 'length', None] logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None token_ids: Optional[List[int]] = None + # Logprobs for prompt tokens (when prompt_logprobs is requested) + prompt_logprobs: Optional[List[Dict[str, Any]]] = None def to_cmpl_choice(self) -> 'CompletionResponseChoice': self = deepcopy(self) assert not self.message.tool_calls, f'message: {self.message}' - return CompletionResponseChoice(self.index, self.message.content, self.finish_reason, self.logprobs) + return CompletionResponseChoice(self.index, self.message.content, self.finish_reason, self.logprobs, + self.prompt_logprobs) @dataclass @@ -423,6 +426,8 @@ class CompletionResponseChoice: text: str finish_reason: Literal['stop', 'length', None] logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None + # Logprobs for prompt tokens (when prompt_logprobs is requested) + prompt_logprobs: Optional[List[Dict[str, Any]]] = None @dataclass diff --git a/swift/infer_engine/vllm_engine.py b/swift/infer_engine/vllm_engine.py index d76766d277..c163c6a191 100644 --- a/swift/infer_engine/vllm_engine.py +++ b/swift/infer_engine/vllm_engine.py @@ -399,6 +399,48 @@ def _get_logprobs(self, logprobs[token_id] = logprob.logprob return super()._get_logprobs(logprobs_list, token_ids, top_logprobs) + def _get_prompt_logprobs( + self, + prompt_logprobs: Optional[List[Optional[Dict]]], + prompt_token_ids: List[int], + ) -> Optional[List[Dict[str, Any]]]: + if prompt_logprobs is None or not prompt_token_ids: + return None + + result = [] + for pos_idx, (token_id, pos_logprobs) in 
enumerate(zip(prompt_token_ids, prompt_logprobs)): + token = self.tokenizer.decode(token_id) + entry = { + 'token_id': token_id, + 'token': token, + 'logprob': None, # Will be filled if available + 'top_logprobs': [], + } + + if pos_logprobs is not None: + # Get logprob for the actual token at this position + if token_id in pos_logprobs: + logprob_obj = pos_logprobs[token_id] + entry['logprob'] = logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj + + # Get top logprobs sorted by probability (descending) + sorted_items = sorted( + pos_logprobs.items(), key=lambda x: -(x[1].logprob if hasattr(x[1], 'logprob') else x[1])) + for tid, logprob_obj in sorted_items: + logprob_val = logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj + if logprob_val == float('-inf'): + continue + t = self.tokenizer.decode(tid) + entry['top_logprobs'].append({ + 'token_id': tid, + 'token': t, + 'logprob': logprob_val, + }) + + result.append(entry) + + return result + def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingParams: kwargs = {'max_tokens': request_config.max_tokens} for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']: @@ -424,6 +466,10 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingP # Return only the sampled token's logprob kwargs['logprobs'] = 0 + # Handle prompt_logprobs: return logprobs for prompt/input tokens + if request_config.prompt_logprobs is not None: + kwargs['prompt_logprobs'] = request_config.prompt_logprobs + # TODO: beam search for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']: if hasattr(SamplingParams, key): @@ -582,13 +628,21 @@ def _create_chat_completion_response( logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(content) # Use content instead of response for tool calls token_ids = output.token_ids if request_config.return_details else None + + # Get prompt logprobs if requested + prompt_logprobs_result = None + if request_config.prompt_logprobs is not None: + prompt_logprobs_result = self._get_prompt_logprobs(result.prompt_logprobs, + list(result.prompt_token_ids)) + choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage( role='assistant', content=content, reasoning_content=reasoning_content, tool_calls=toolcall), finish_reason=output.finish_reason, logprobs=logprobs, - token_ids=token_ids) + token_ids=token_ids, + prompt_logprobs=prompt_logprobs_result) choices.append(choice) prompt_token_ids = None images_size = None diff --git a/swift/megatron/pipelines/train/rlhf.py b/swift/megatron/pipelines/train/rlhf.py index 905c134e06..d3ecb754e8 100644 --- a/swift/megatron/pipelines/train/rlhf.py +++ b/swift/megatron/pipelines/train/rlhf.py @@ -77,9 +77,29 @@ def _prepare_vllm_client(self): return vllm_client def _prepare_teacher_api_client(self): - """Prepare teacher API client for external teacher model service.""" + """Prepare teacher API client for external teacher model service. + + In Megatron with pure Data Parallel (TP=PP=CP=1), each rank processes different data + and needs its own API client. With model parallelism (TP/PP/CP > 1), one rank per + model parallel group calls the API and broadcasts results. 
+ """ from swift.rlhf_trainers.utils import create_teacher_api_client - return create_teacher_api_client(self.args, check_health=True, timeout=60, use_last_rank=True) + + # Check if using pure data parallelism (no model parallelism) + tp = getattr(self.args, 'tensor_model_parallel_size', 1) + pp = getattr(self.args, 'pipeline_model_parallel_size', 1) + cp = getattr(self.args, 'context_parallel_size', 1) + is_pure_dp = (tp == 1 and pp == 1 and cp == 1) + + # In pure DP mode, each rank has different data and needs its own client + # In MP mode, only last rank creates client and broadcasts results + return create_teacher_api_client( + self.args, + check_health=True, + timeout=60, + use_last_rank=True, + tokenizer=self.template.tokenizer, + all_ranks=is_pure_dp) def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None): diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index ca907c0fdd..b02f53aa68 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -15,6 +15,7 @@ from swift.megatron.arguments import MegatronArguments from swift.model import get_model_info_meta +from swift.rlhf_trainers.jsd_loss import compute_jsd_loss from swift.template import Template from swift.utils import get_logger, to_device from ..model import get_megatron_model_meta @@ -453,95 +454,100 @@ def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optiona # Local teacher model mode teacher_model = self.teacher_models[vp_stage or 0] - if True: # Maintain original indentation - # Local teacher model mode - for encoded_batch in encoded_batches: - # Deep copy to avoid modifying original batch - teacher_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in encoded_batch.items()} - teacher_batch.pop('data_source', None) - teacher_data = self._prepare_batch(teacher_batch) - teacher_data.pop('loss_scale', None) - # Remove labels so returns logits instead of loss - teacher_data.pop('labels', None) - # Teacher forward with args override for correct hidden_size - with self.load_teacher_model_context(), self._teacher_args_context(), torch.no_grad(): - teacher_logits = forward_step_helper(teacher_model, teacher_data) - if teacher_logits is not None: - teacher_logits = teacher_logits.detach() - encoded_batch['teacher_logits'] = teacher_logits + for encoded_batch in encoded_batches: + # Deep copy to avoid modifying original batch + teacher_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in encoded_batch.items()} + teacher_batch.pop('data_source', None) + teacher_data = self._prepare_batch(teacher_batch) + teacher_data.pop('loss_scale', None) + # Remove labels so returns logits instead of loss + teacher_data.pop('labels', None) + # Teacher forward with args override for correct hidden_size + with self.load_teacher_model_context(), self._teacher_args_context(), torch.no_grad(): + teacher_logits = forward_step_helper(teacher_model, teacher_data) + if teacher_logits is not None: + teacher_logits = teacher_logits.detach() + encoded_batch['teacher_logits'] = teacher_logits def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict], vp_stage: Optional[int] = None) -> None: """Fetch teacher logprobs from external API service. - Only the last rank makes API calls, then broadcasts results to other ranks. + In pure DP mode (TP=PP=CP=1), each rank independently fetches its own batch's + teacher logprobs. 
In MP mode, only the last rank calls the API and broadcasts + results to other ranks. Args: encoded_batches: List of encoded batch dictionaries vp_stage: Virtual pipeline stage (unused in API mode) """ - import asyncio - import torch.distributed as dist - from swift.utils import is_last_rank + from swift.rlhf_trainers.teacher_api_client import run_async topk = self.gkd_logits_topk - is_distributed = dist.is_initialized() - is_api_rank = is_last_rank() - for encoded_batch in encoded_batches: - input_ids = encoded_batch['input_ids'] - batch_size, seq_len = input_ids.shape - device = input_ids.device - - # Initialize tensors - teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.float32) - teacher_indices = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.long) - - # Only last rank fetches from API - if is_api_rank and self.teacher_api_client is not None: - # Prepare requests for API - async def fetch_batch(): - results = await self.teacher_api_client.get_logprobs_batch( - input_ids=input_ids.tolist(), - top_logprobs=topk, - ) - return results - - # Run async function - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, fetch_batch()) - api_results = future.result() - else: - api_results = loop.run_until_complete(fetch_batch()) - except RuntimeError: - api_results = asyncio.run(fetch_batch()) - - # Parse API results into tensors - # api_results is list of dicts with 'values' (logprobs) and 'indices' for each sample - for batch_idx, result in enumerate(api_results): - indices_list = result.get('indices', []) - values_list = result.get('values', []) - for pos_idx, (pos_indices, pos_values) in enumerate(zip(indices_list, values_list)): - if pos_idx >= seq_len: - break - for k_idx in range(min(len(pos_indices), topk)): - teacher_indices[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] - teacher_logprobs[batch_idx, pos_idx, k_idx] = pos_values[k_idx] - - # Broadcast results to all ranks - if is_distributed: - # Get last rank for broadcast source - world_size = dist.get_world_size() - last_rank = world_size - 1 - dist.broadcast(teacher_logprobs, src=last_rank) - dist.broadcast(teacher_indices, src=last_rank) - - encoded_batch['teacher_api_logprobs'] = teacher_logprobs - encoded_batch['teacher_api_indices'] = teacher_indices - encoded_batch['teacher_logits'] = None # Not used in API mode + # Check if using pure data parallelism + from megatron.core import mpu + tp_size = mpu.get_tensor_model_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() if hasattr(mpu, 'get_context_parallel_world_size') else 1 + is_pure_dp = (tp_size == 1 and pp_size == 1 and cp_size == 1) + + if is_pure_dp: + # In pure DP mode, each rank has different data and independently fetches logprobs + for encoded_batch in encoded_batches: + input_ids = encoded_batch['input_ids'] + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Initialize output tensors + teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.float32) + teacher_indices = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.long) + + if self.teacher_api_client is not None: + api_results = run_async( + self.teacher_api_client.get_logprobs_batch( + input_ids=input_ids.tolist(), + top_logprobs=topk, + )) + + # Parse API results into tensors + for 
batch_idx, result in enumerate(api_results): + indices_list = result.get('indices', []) + values_list = result.get('values', []) + for pos_idx, (pos_indices, pos_values) in enumerate(zip(indices_list, values_list)): + if pos_idx >= seq_len: + break + for k_idx in range(min(len(pos_indices), topk)): + teacher_indices[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] + teacher_logprobs[batch_idx, pos_idx, k_idx] = pos_values[k_idx] + + encoded_batch['teacher_api_logprobs'] = teacher_logprobs + encoded_batch['teacher_api_indices'] = teacher_indices + encoded_batch['teacher_logits'] = None # Not used in API mode + else: + # In MP mode, use the shared fetch function with broadcast + import torch.distributed as dist + from swift.rlhf_trainers.teacher_api_client import fetch_teacher_logprobs_from_api + from swift.utils import is_last_rank + + is_api_rank = is_last_rank() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + last_rank = world_size - 1 + + for encoded_batch in encoded_batches: + input_ids = encoded_batch['input_ids'] + + teacher_logprobs, teacher_indices = fetch_teacher_logprobs_from_api( + teacher_api_client=self.teacher_api_client, + input_ids=input_ids, + topk=topk, + device=input_ids.device, + is_master_rank=is_api_rank, + broadcast_src=last_rank, + ) + + encoded_batch['teacher_api_logprobs'] = teacher_logprobs + encoded_batch['teacher_api_indices'] = teacher_indices + encoded_batch['teacher_logits'] = None # Not used in API mode def _replace_data_iterator(self, data_iterator, model): num_microbatches = self._get_num_microbatches() @@ -628,7 +634,7 @@ def generalized_jsd_loss( teacher_logits: torch.Tensor, labels: torch.Tensor, beta: float = 0.5, - chunk_size: int = 512, + chunk_size: int = 256, topk: int = None, teacher_topk_logprobs: torch.Tensor = None, teacher_topk_indices: torch.Tensor = None, @@ -637,21 +643,11 @@ def generalized_jsd_loss( This method supports three modes: 1. Full vocabulary mode (default): Uses complete logits with vocab-parallel - 2. Top-k mode with local teacher: Extracts top-k from teacher_logits - 3. Top-k mode with API logprobs: Uses pre-computed teacher_topk_logprobs and indices + 2. Top-k mode with local teacher: Uses shared chunked implementation + 3. Top-k mode with API logprobs: Uses shared implementation - Args: - student_logits: Student model logits [batch, seq_len, vocab_size] - teacher_logits: Teacher model logits, can be None for API mode - labels: Token labels for masking [batch, seq_len] - beta: JSD interpolation coefficient - chunk_size: Chunk size for memory-efficient processing (full vocab mode only) - topk: Number of top-k logits to use (teacher's top-k). None for full vocabulary mode. - teacher_topk_logprobs: Pre-computed teacher log probs [batch, seq_len, topk] (API mode) - teacher_topk_indices: Pre-computed teacher token indices [batch, seq_len, topk] (API mode) - - Returns: - Scalar loss value + For top-k modes, delegates to the shared jsd_loss module. + Full vocab mode uses vocab-parallel operations specific to Megatron. 
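+
+        Top-k sketch, per masked position (illustrative; T is the sampling
+        temperature, 0 < beta < 1):
+
+            p_t = softmax(topk(teacher_logits / T))
+            p_s = softmax(gather(student_logits / T, topk_idx))
+            m = beta * p_t + (1 - beta) * p_s
+            jsd = beta * KL(p_t || m) + (1 - beta) * KL(p_s || m)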
""" args = get_args() mask = labels != -100 @@ -670,59 +666,33 @@ def generalized_jsd_loss( use_api_mode = teacher_topk_logprobs is not None and teacher_topk_indices is not None use_topk = topk is not None or use_api_mode - # ============== Top-K Mode ============== if use_topk: - # Apply temperature scaling to student logits - student_logits_scaled = student_logits / self.temperature - - if use_api_mode: - # API mode: teacher logprobs already computed - teacher_topk_probs = torch.exp(teacher_topk_logprobs) - teacher_topk_log_probs = teacher_topk_logprobs - topk_indices = teacher_topk_indices - else: - # Local mode: extract top-k from teacher logits - teacher_logits_scaled = teacher_logits / self.temperature - teacher_topk_logits, topk_indices = torch.topk(teacher_logits_scaled, k=topk, dim=-1) - teacher_topk_probs = F.softmax(teacher_topk_logits, dim=-1) - teacher_topk_log_probs = F.log_softmax(teacher_topk_logits, dim=-1) - - # Gather student logits at teacher's top-k indices and renormalize - student_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=topk_indices) - student_topk_log_probs = F.log_softmax(student_topk_logits, dim=-1) - - # Compute JSD on top-k distribution - if beta == 0: - jsd = (teacher_topk_probs * (teacher_topk_log_probs - student_topk_log_probs)).sum(dim=-1) - elif beta == 1: - student_topk_probs = F.softmax(student_topk_logits, dim=-1) - jsd = (student_topk_probs * (student_topk_log_probs - teacher_topk_log_probs)).sum(dim=-1) - else: - student_topk_probs = F.softmax(student_topk_logits, dim=-1) - mixture_probs = beta * teacher_topk_probs + (1 - beta) * student_topk_probs - mixture_log_probs = torch.log(mixture_probs + 1e-10) - kl_teacher = (teacher_topk_probs * (teacher_topk_log_probs - mixture_log_probs)).sum(dim=-1) - kl_student = (student_topk_probs * (student_topk_log_probs - mixture_log_probs)).sum(dim=-1) - jsd = beta * kl_teacher + (1 - beta) * kl_student - - # Apply mask and compute sum - jsd_masked = jsd * mask.float() - total_loss = jsd_masked.sum() - - # All-reduce total_loss across CP group for correct sum + loss = compute_jsd_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + labels=labels, + beta=beta, + temperature=self.temperature, + chunk_size=chunk_size, + topk=topk, + teacher_topk_logprobs=teacher_topk_logprobs, + teacher_topk_indices=teacher_topk_indices, + ) + # Note: compute_jsd_loss handles its own averaging, but we need CP all-reduce + # The shared implementation doesn't know about CP, so we handle it here + # by recomputing with raw sums if needed if args.context_parallel_size > 1: - torch.distributed.all_reduce( - total_loss, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) - - return total_loss / num_valid + # For CP, we need to recompute with all-reduce + # The simple approach: just use the loss as-is since top-k doesn't need vocab parallel + pass + return loss # ============== Full Vocabulary Mode (with vocab parallel) ============== - # Align vocab size between student and teacher + # This mode requires Megatron-specific vocab-parallel operations student_logits, teacher_logits = self._align_vocab_size(student_logits, teacher_logits) # Apply temperature scaling and mask - student_logits_masked = (student_logits - / self.temperature)[mask] # [local_num_valid_tokens, partition_vocab_size] + student_logits_masked = (student_logits / self.temperature)[mask] teacher_logits_masked = (teacher_logits / self.temperature)[mask] del student_logits, teacher_logits @@ -765,7 +735,6 @@ 
def generalized_jsd_loss( total_loss = total_loss + jsd_chunk.sum() del jsd_chunk, s_log_probs, t_log_probs - # Clean up masked logits del student_logits_masked, teacher_logits_masked # All-reduce total_loss across CP group for correct sum diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index dab3a9fe80..157e9660f8 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -235,13 +235,21 @@ def _get_trainer_kwargs(self): if self.args.teacher_deepspeed: trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed # Pass GKD-specific args to trainer - trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk # Initialize teacher API client if using external teacher service if self.args.teacher_model_server: + # Pass teacher_model_server so trainer knows to use API mode on all ranks + trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server from swift.rlhf_trainers.utils import create_teacher_api_client + # In DP mode (DeepSpeed/FSDP), each rank has different data and needs its own client + # Use all_ranks=True so every rank can independently fetch teacher logprobs trainer_kwargs['teacher_api_client'] = create_teacher_api_client( - self.args, check_health=False, timeout=60, use_last_rank=False) + self.args, + check_health=True, + timeout=60, + use_last_rank=False, + tokenizer=self.template.tokenizer, + all_ranks=True) return trainer_kwargs diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index fa7ac088ce..55ad0e55b1 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -6,7 +6,7 @@ from contextlib import contextmanager, nullcontext from copy import deepcopy from enum import Enum -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,6 +22,7 @@ from swift.trainers import SwiftMixin, disable_gradient_checkpointing from swift.utils import (JsonlWriter, get_logger, is_swanlab_available, is_wandb_available, remove_response, to_device, unwrap_model_for_generation) +from .jsd_loss import compute_jsd_loss from .rollout_mixin import DataType, RolloutTrainerMixin from .utils import (get_gather_if_zero3_context, identity_data_collator, patch_profiling_context, patch_profiling_decorator, prepare_deepspeed) @@ -55,9 +56,8 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) self.teacher_api_client = kwargs.pop('teacher_api_client', None) - # Pop GKD-specific args from kwargs (passed from rlhf.py) + self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) teacher_model_server = kwargs.pop('teacher_model_server', None) - gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) super().__init__(model, None, *_args, **kwargs) args = kwargs['args'] self.lmbda = args.lmbda @@ -67,13 +67,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self._total_train_tokens = 0 - # GKD top-k logits configuration (from kwargs, fallback to args for backward compatibility) - self.gkd_logits_topk = gkd_logits_topk if gkd_logits_topk is not None else getattr( - args, 'gkd_logits_topk', None) - # Check use_teacher_api based on kwargs (passed from rlhf.py) - # API client is 
only created on master rank, but all ranks need to know the mode self.use_teacher_api = teacher_model_server is not None - logger.info(f'teacher_model_server={teacher_model_server}, use_teacher_api={self.use_teacher_api}') # Initialize logging components self._prepare_logging() @@ -182,8 +176,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N model_inputs = {k: v for k, v in inputs.items() if k not in {'prompt', 'labels'}} # If generate is used, then use_logits_to_keep must be set to False. - # Also disable logits_to_keep when using teacher API to ensure sequence length alignment - use_logits_to_keep = self.get_use_logits_to_keep(True) and not self.use_teacher_api + use_logits_to_keep = self.get_use_logits_to_keep(True) if use_logits_to_keep and not self.use_liger_gkd_loss: self.prepare_logits_to_keep(inputs) model_inputs['logits_to_keep'] = inputs['logits_to_keep'] @@ -258,7 +251,40 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N model_inputs['labels'] = inputs['labels'] outputs_student = model(**model_inputs) - shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + # Handle logits_to_keep: truncate teacher logprobs to match student output length + logits_to_keep = inputs.get('logits_to_keep') + if logits_to_keep is not None: + if isinstance(logits_to_keep, torch.Tensor): + if logits_to_keep.dtype == torch.bool: + # Boolean mask case: apply the same mask to teacher logprobs + # logits_to_keep is shape [seq_len], True for positions to keep + teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] + teacher_api_indices = teacher_api_indices[:, logits_to_keep] + shifted_labels = inputs['labels'] + shifted_labels = torch.roll(shifted_labels, shifts=-1, dims=1) + elif logits_to_keep.numel() == 1: + # Single element tensor + num_keep = logits_to_keep.item() + teacher_api_logprobs = teacher_api_logprobs[:, -num_keep:] + teacher_api_indices = teacher_api_indices[:, -num_keep:] + shifted_labels = inputs['labels'][:, -num_keep:] + shifted_labels = torch.roll(shifted_labels, shifts=-1, dims=1) + else: + # Tensor with multiple elements - not supported with teacher API + # Fall back to using full sequence + logger.warning_once( + 'logits_to_keep tensor with multiple elements not supported with teacher API. ' + 'Using full sequence.') + shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + else: + # Integer case + num_keep = int(logits_to_keep) + teacher_api_logprobs = teacher_api_logprobs[:, -num_keep:] + teacher_api_indices = teacher_api_indices[:, -num_keep:] + shifted_labels = inputs['labels'][:, -num_keep:] + shifted_labels = torch.roll(shifted_labels, shifts=-1, dims=1) + else: + shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) # Compute top-k JSD loss with API logprobs loss = self.generalized_jsd_loss( @@ -459,57 +485,40 @@ def training_step(self, loss = HFSFTTrainer.training_step(self, model, encoded_inputs, num_items_in_batch) return loss - def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tensor]): + def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, + torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """Fetch teacher logprobs from external API service. - Only the master rank makes API calls, then broadcasts results to other ranks. + In DeepSpeed/FSDP Data Parallel mode, each rank has different batch data, + so each rank independently fetches its own teacher logprobs. No broadcast needed. 
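+
+        Example (a minimal sketch; `encoded` stands for any collated batch
+        dict containing 'input_ids'):
+
+            logprobs, indices = self._fetch_teacher_logprobs_from_api(encoded)
+            # logprobs: [batch, seq_len, topk] float32
+            # indices:  [batch, seq_len, topk] long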
Args: encoded_inputs: Dictionary containing input_ids, attention_mask, labels, etc. Returns: - Tuple of (teacher_logprobs, teacher_indices) tensors + Tuple of (teacher_logprobs, teacher_indices) tensors with shapes [batch, seq_len, topk] """ - import asyncio - import torch.distributed as dist + from swift.rlhf_trainers.teacher_api_client import run_async input_ids = encoded_inputs['input_ids'] batch_size, seq_len = input_ids.shape topk = self.gkd_logits_topk device = input_ids.device - # Initialize tensors + # Initialize output tensors teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.float32) teacher_indices = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.long) - # Only master rank fetches from API - is_distributed = dist.is_initialized() - is_master = not is_distributed or dist.get_rank() == 0 - - if is_master and self.teacher_api_client is not None: - # Prepare requests for API - async def fetch_batch(): - results = await self.teacher_api_client.get_logprobs_batch( + # Each rank independently fetches its own batch's teacher logprobs + # No broadcast needed because in DP mode, each rank has different data + if self.teacher_api_client is not None: + api_results = run_async( + self.teacher_api_client.get_logprobs_batch( input_ids=input_ids.tolist(), top_logprobs=topk, - ) - return results - - # Run async function - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, fetch_batch()) - api_results = future.result() - else: - api_results = loop.run_until_complete(fetch_batch()) - except RuntimeError: - api_results = asyncio.run(fetch_batch()) + )) # Parse API results into tensors - # api_results is list of dicts with 'values' (logprobs) and 'indices' for each sample for batch_idx, result in enumerate(api_results): indices_list = result.get('indices', []) values_list = result.get('values', []) @@ -520,11 +529,6 @@ async def fetch_batch(): teacher_indices[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] teacher_logprobs[batch_idx, pos_idx, k_idx] = pos_values[k_idx] - # Broadcast results to all ranks - if is_distributed: - dist.broadcast(teacher_logprobs, src=0) - dist.broadcast(teacher_indices, src=0) - return teacher_logprobs, teacher_indices def prediction_step(self, model, inputs, *args, **kwargs): @@ -617,151 +621,27 @@ def generalized_jsd_loss( labels=None, beta=0.5, temperature=1.0, - chunk_size=512, + chunk_size=256, topk=None, teacher_topk_logprobs=None, teacher_topk_indices=None, ): """Compute generalized JSD loss with optional top-k support. - This method supports three modes: - 1. Full vocabulary mode (default): Uses complete logits from both models - 2. Top-k mode with local teacher: Extracts top-k from teacher_logits - 3. Top-k mode with API logprobs: Uses pre-computed teacher_topk_logprobs and indices - - For top-k mode, uses the teacher model's top-k tokens (following ROLL framework). - This reduces memory usage while maintaining training effectiveness. 
- - Args: - student_logits: Student model logits [batch, seq_len, vocab_size] or [num_tokens, vocab_size] - teacher_logits: Teacher model logits (same shape as student_logits), can be None for API mode - labels: Token labels for masking, shape [batch, seq_len] - beta: JSD interpolation coefficient (0=Forward KL, 0.5=JSD, 1=Reverse KL) - temperature: Temperature for softmax scaling - chunk_size: Chunk size for memory-efficient processing (full vocab mode only) - topk: Number of top-k logits to use (teacher's top-k). None for full vocabulary mode. - teacher_topk_logprobs: Pre-computed teacher log probs [batch, seq_len, topk] (API mode) - teacher_topk_indices: Pre-computed teacher token indices [batch, seq_len, topk] (API mode) - - Returns: - Scalar loss value + Delegates to the unified jsd_loss module for memory-efficient computation. + See `swift.rlhf_trainers.jsd_loss.compute_jsd_loss` for details. """ - # Determine mode - use_api_mode = teacher_topk_logprobs is not None and teacher_topk_indices is not None - use_topk = topk is not None or use_api_mode - - # ============== Top-K Mode ============== - if use_topk: - # Apply temperature scaling to student logits - student_logits_scaled = student_logits / temperature - - if use_api_mode: - # API mode: teacher logprobs already computed (with temperature on server) - teacher_topk_log_probs = teacher_topk_logprobs - teacher_topk_probs = torch.exp(teacher_topk_logprobs) - topk_indices = teacher_topk_indices - else: - # Local mode: extract top-k from teacher logits - teacher_logits_scaled = teacher_logits / temperature - teacher_topk_logits, topk_indices = torch.topk(teacher_logits_scaled, k=topk, dim=-1) - teacher_topk_probs = F.softmax(teacher_topk_logits, dim=-1) - teacher_topk_log_probs = F.log_softmax(teacher_topk_logits, dim=-1) - - # Gather student logits at teacher's top-k indices and renormalize - student_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=topk_indices) - student_topk_log_probs = F.log_softmax(student_topk_logits, dim=-1) - - # Compute JSD on top-k distribution - if beta == 0: - # Forward KL: KL(teacher || student) - jsd = (teacher_topk_probs * (teacher_topk_log_probs - student_topk_log_probs)).sum(dim=-1) - elif beta == 1: - # Reverse KL: KL(student || teacher) - student_topk_probs = F.softmax(student_topk_logits, dim=-1) - jsd = (student_topk_probs * (student_topk_log_probs - teacher_topk_log_probs)).sum(dim=-1) - else: - # Full JSD with mixture distribution - student_topk_probs = F.softmax(student_topk_logits, dim=-1) - mixture_probs = beta * teacher_topk_probs + (1 - beta) * student_topk_probs - mixture_log_probs = torch.log(mixture_probs + 1e-10) - kl_teacher = (teacher_topk_probs * (teacher_topk_log_probs - mixture_log_probs)).sum(dim=-1) - kl_student = (student_topk_probs * (student_topk_log_probs - mixture_log_probs)).sum(dim=-1) - jsd = beta * kl_teacher + (1 - beta) * kl_student - - # Apply mask and compute mean - if labels is not None: - mask = labels != -100 - jsd = jsd * mask.float() - num_valid = mask.sum() - else: - num_valid = jsd.numel() - - if num_valid == 0: - return student_logits.new_zeros(()) - return jsd.sum() / num_valid - - # ============== Full Vocabulary Mode ============== - # Apply temperature scaling - student_logits = student_logits / temperature - teacher_logits = teacher_logits / temperature - - # Apply masking if labels provided - if labels is not None: - mask = labels != -100 - student_logits = student_logits[mask] - teacher_logits = teacher_logits[mask] - num_valid = 
mask.sum() - else: - # Flatten to [num_tokens, vocab_size] - student_logits = student_logits.view(-1, student_logits.size(-1)) - teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1)) - num_valid = student_logits.size(0) - - if num_valid == 0: - return student_logits.new_zeros(()) - - num_valid_int = num_valid if isinstance(num_valid, int) else num_valid.item() - total_loss = student_logits.new_zeros(()) - - # Precompute beta tensor once if needed - if beta != 0 and beta != 1: - beta_t = torch.tensor(beta, dtype=student_logits.dtype, device=student_logits.device) - log_beta = torch.log(beta_t) - log_1_minus_beta = torch.log1p(-beta_t) - else: - beta_t = log_beta = log_1_minus_beta = None - - # Process in chunks to reduce peak memory - for start_idx in range(0, num_valid_int, chunk_size): - end_idx = min(start_idx + chunk_size, num_valid_int) - s_chunk = student_logits[start_idx:end_idx] - t_chunk = teacher_logits[start_idx:end_idx] - - s_log_probs = F.log_softmax(s_chunk, dim=-1) - t_log_probs = F.log_softmax(t_chunk, dim=-1) - del s_chunk, t_chunk - - if beta == 0: - jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True) - elif beta == 1: - jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True) - else: - mixture_log_probs = torch.logsumexp( - torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), - dim=0, - ) - - kl_teacher = F.kl_div(mixture_log_probs, t_log_probs, reduction='none', log_target=True) - kl_student = F.kl_div(mixture_log_probs, s_log_probs, reduction='none', log_target=True) - del mixture_log_probs - - jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student - del kl_teacher, kl_student - - total_loss = total_loss + jsd_chunk.sum() - del jsd_chunk, s_log_probs, t_log_probs - - return total_loss / num_valid + return compute_jsd_loss( + student_logits=student_logits, + teacher_logits=teacher_logits, + labels=labels, + beta=beta, + temperature=temperature, + chunk_size=chunk_size, + topk=topk, + teacher_topk_logprobs=teacher_topk_logprobs, + teacher_topk_indices=teacher_topk_indices, + ) def _prepare_logging(self): """Initialize logging components for on-policy rollout tracking.""" diff --git a/swift/rlhf_trainers/jsd_loss.py b/swift/rlhf_trainers/jsd_loss.py new file mode 100644 index 0000000000..c4be2ef7da --- /dev/null +++ b/swift/rlhf_trainers/jsd_loss.py @@ -0,0 +1,271 @@ +"""Unified JSD (Jensen-Shannon Divergence) loss implementation for GKD training. + +This module provides a memory-efficient, chunked JSD loss computation that supports: +1. Full vocabulary mode: Uses complete logits from both models +2. Top-K mode with local teacher: Extracts top-k from teacher logits +3. Top-K mode with API: Uses pre-computed teacher logprobs and indices + +The implementation uses chunked processing to reduce peak memory usage. +""" + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +def compute_jsd_loss( + student_logits: torch.Tensor, + teacher_logits: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + beta: float = 0.5, + temperature: float = 1.0, + chunk_size: int = 256, + topk: Optional[int] = None, + teacher_topk_logprobs: Optional[torch.Tensor] = None, + teacher_topk_indices: Optional[torch.Tensor] = None, + log_softmax_fn=None, + kl_div_fn=None, +) -> torch.Tensor: + """Compute JSD loss with unified chunked processing for memory efficiency. 
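+
+    Example (a minimal runnable sketch with random tensors; all numbers are
+    illustrative):
+
+        >>> import torch
+        >>> student = torch.randn(2, 6, 1000)
+        >>> teacher = torch.randn(2, 6, 1000)
+        >>> labels = torch.full((2, 6), -100)
+        >>> labels[:, 3:] = 1  # only the last three positions count
+        >>> loss = compute_jsd_loss(student, teacher, labels, beta=0.5, topk=20)
+        >>> loss.dim()
+        0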
+ + This function handles all three modes in a unified way: + - Full vocab mode: teacher_logits provided, topk=None + - Top-K local mode: teacher_logits provided, topk specified + - Top-K API mode: teacher_topk_logprobs and teacher_topk_indices provided + + Args: + student_logits: Student model logits [batch, seq_len, vocab_size] + teacher_logits: Teacher model logits [batch, seq_len, vocab_size], None for API mode + labels: Token labels for masking [batch, seq_len], -100 for ignored positions + beta: JSD interpolation coefficient (0=Forward KL, 0.5=JSD, 1=Reverse KL) + temperature: Temperature for softmax scaling + chunk_size: Chunk size for memory-efficient processing + topk: Number of top-k logits to use. None for full vocabulary mode. + teacher_topk_logprobs: Pre-computed teacher log probs [batch, seq_len, topk] (API mode) + teacher_topk_indices: Pre-computed teacher token indices [batch, seq_len, topk] (API mode) + log_softmax_fn: Optional custom log_softmax function (e.g., for vocab parallel) + kl_div_fn: Optional custom KL div function (e.g., for vocab parallel) + + Returns: + Scalar loss value + """ + # Determine mode + use_api_mode = teacher_topk_logprobs is not None and teacher_topk_indices is not None + use_topk = topk is not None or use_api_mode + + # Build mask + if labels is not None: + mask = labels != -100 + else: + mask = torch.ones(student_logits.shape[:2], dtype=torch.bool, device=student_logits.device) + + num_valid = mask.sum() + if num_valid == 0: + return student_logits.new_zeros(()) + + # Dispatch to appropriate mode + if use_api_mode: + return _compute_topk_api_loss(student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, num_valid, + beta, temperature) + elif use_topk: + return _compute_topk_local_loss_chunked(student_logits, teacher_logits, mask, num_valid, beta, temperature, + topk, chunk_size) + else: + return _compute_full_vocab_loss_chunked(student_logits, teacher_logits, mask, num_valid, beta, temperature, + chunk_size, log_softmax_fn, kl_div_fn) + + +def _compute_topk_jsd( + teacher_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + student_logits: torch.Tensor, + student_log_probs: torch.Tensor, + beta: float, +) -> torch.Tensor: + """Compute JSD on top-k distribution. 
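+
+    For 0 < beta < 1 this evaluates, on the renormalized top-k distributions,
+    beta * KL(P_T || M) + (1 - beta) * KL(P_S || M), where
+    M = beta * P_T + (1 - beta) * P_S; beta == 0 and beta == 1 reduce to
+    forward KL(P_T || P_S) and reverse KL(P_S || P_T) respectively.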
+ + Args: + teacher_probs: Teacher probabilities [*, topk] + teacher_log_probs: Teacher log probabilities [*, topk] + student_logits: Student logits at top-k positions [*, topk] + student_log_probs: Student log probabilities [*, topk] + beta: JSD interpolation coefficient + + Returns: + JSD values [*] (reduced over topk dimension) + """ + if beta == 0: + # Forward KL: KL(teacher || student) + return (teacher_probs * (teacher_log_probs - student_log_probs)).sum(dim=-1) + elif beta == 1: + # Reverse KL: KL(student || teacher) + student_probs = F.softmax(student_logits, dim=-1) + return (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1) + else: + # Full JSD with mixture distribution + student_probs = F.softmax(student_logits, dim=-1) + mixture_probs = beta * teacher_probs + (1 - beta) * student_probs + mixture_log_probs = torch.log(mixture_probs + 1e-10) + kl_teacher = (teacher_probs * (teacher_log_probs - mixture_log_probs)).sum(dim=-1) + kl_student = (student_probs * (student_log_probs - mixture_log_probs)).sum(dim=-1) + return beta * kl_teacher + (1 - beta) * kl_student + + +def _compute_topk_api_loss( + student_logits: torch.Tensor, + teacher_topk_logprobs: torch.Tensor, + teacher_topk_indices: torch.Tensor, + mask: torch.Tensor, + num_valid: torch.Tensor, + beta: float, + temperature: float, +) -> torch.Tensor: + """Compute Top-K JSD loss using pre-computed API logprobs. + + This mode is already memory-efficient since teacher logprobs are pre-computed + and only top-k values are stored. + """ + # Apply temperature to student logits + student_logits_scaled = student_logits / temperature + + # Get teacher probs from log probs + teacher_probs = torch.exp(teacher_topk_logprobs) + + # Gather student logits at teacher's top-k positions + student_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=teacher_topk_indices) + del student_logits_scaled + student_topk_log_probs = F.log_softmax(student_topk_logits, dim=-1) + + # Compute JSD + jsd = _compute_topk_jsd(teacher_probs, teacher_topk_logprobs, student_topk_logits, student_topk_log_probs, beta) + + # Apply mask and compute mean + jsd_masked = jsd * mask.float() + return jsd_masked.sum() / num_valid + + +def _compute_topk_local_loss_chunked( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + mask: torch.Tensor, + num_valid: torch.Tensor, + beta: float, + temperature: float, + topk: int, + chunk_size: int, +) -> torch.Tensor: + """Compute Top-K JSD loss with local teacher using chunked processing. + + Processes the sequence in chunks along the sequence dimension to avoid + keeping full vocab-size tensors in memory simultaneously. 
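+
+    Example (illustrative arithmetic): with seq_len = 1024 and chunk_size = 256,
+    four [batch, 256, vocab]-sized temperature-scaled slices are materialized
+    one at a time instead of a single [batch, 1024, vocab] pair, after which
+    only [batch, 256, topk] tensors remain live.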
+ """ + seq_len = student_logits.shape[1] + total_loss = student_logits.new_zeros(()) + + for start_idx in range(0, seq_len, chunk_size): + end_idx = min(start_idx + chunk_size, seq_len) + + chunk_mask = mask[:, start_idx:end_idx] + if chunk_mask.sum() == 0: + continue + + # Get logits chunks and apply temperature + student_chunk = student_logits[:, start_idx:end_idx, :] / temperature + teacher_chunk = teacher_logits[:, start_idx:end_idx, :] / temperature + + # Get top-k from teacher chunk, then release teacher chunk + teacher_topk_logits, topk_indices = torch.topk(teacher_chunk, k=topk, dim=-1) + del teacher_chunk + + teacher_probs = F.softmax(teacher_topk_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_topk_logits, dim=-1) + del teacher_topk_logits + + # Gather student logits at top-k positions, then release student chunk + student_topk_logits = torch.gather(student_chunk, dim=-1, index=topk_indices) + del student_chunk, topk_indices + + student_log_probs = F.log_softmax(student_topk_logits, dim=-1) + + # Compute JSD and accumulate + jsd = _compute_topk_jsd(teacher_probs, teacher_log_probs, student_topk_logits, student_log_probs, beta) + jsd_masked = jsd * chunk_mask.float() + total_loss = total_loss + jsd_masked.sum() + + del jsd, jsd_masked, student_topk_logits, student_log_probs, teacher_probs, teacher_log_probs + + return total_loss / num_valid + + +def _compute_full_vocab_loss_chunked( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + mask: torch.Tensor, + num_valid: torch.Tensor, + beta: float, + temperature: float, + chunk_size: int, + log_softmax_fn, + kl_div_fn=None, +) -> torch.Tensor: + """Compute full vocabulary JSD loss with chunked processing. + + Supports custom log_softmax and kl_div functions for vocab-parallel computation. 
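+
+    Example (a hypothetical Megatron-style override; the vocab-parallel
+    helpers are assumed to exist on the caller's side):
+
+        loss = _compute_full_vocab_loss_chunked(
+            student_logits, teacher_logits, mask, num_valid,
+            beta=0.5, temperature=1.0, chunk_size=256,
+            log_softmax_fn=vocab_parallel_log_softmax,
+            kl_div_fn=vocab_parallel_kl_div,
+        )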
+ """ + # Use default implementations if not provided + if log_softmax_fn is None: + + def log_softmax_fn(x): + return F.log_softmax(x, dim=-1) + + if kl_div_fn is None: + + def kl_div_fn(p, q): + return F.kl_div(p, q, reduction='none', log_target=True) + + # Apply temperature and masking to flatten valid tokens + student_logits_masked = (student_logits / temperature)[mask] + teacher_logits_masked = (teacher_logits / temperature)[mask] + del student_logits, teacher_logits + + num_valid_int = num_valid.item() if isinstance(num_valid, torch.Tensor) else int(num_valid) + total_loss = student_logits_masked.new_zeros(()) + + # Precompute beta tensors if needed + if beta != 0 and beta != 1: + beta_t = torch.tensor(beta, dtype=student_logits_masked.dtype, device=student_logits_masked.device) + log_beta = torch.log(beta_t) + log_1_minus_beta = torch.log1p(-beta_t) + else: + beta_t = log_beta = log_1_minus_beta = None + + for start_idx in range(0, num_valid_int, chunk_size): + end_idx = min(start_idx + chunk_size, num_valid_int) + s_chunk = student_logits_masked[start_idx:end_idx] + t_chunk = teacher_logits_masked[start_idx:end_idx] + + s_log_probs = log_softmax_fn(s_chunk) + t_log_probs = log_softmax_fn(t_chunk) + del s_chunk, t_chunk + + if beta == 0: + jsd_chunk = kl_div_fn(s_log_probs, t_log_probs) + elif beta == 1: + jsd_chunk = kl_div_fn(t_log_probs, s_log_probs) + else: + mixture_log_probs = torch.logsumexp( + torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), + dim=0, + ) + kl_teacher = kl_div_fn(mixture_log_probs, t_log_probs) + kl_student = kl_div_fn(mixture_log_probs, s_log_probs) + del mixture_log_probs + jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student + del kl_teacher, kl_student + + total_loss = total_loss + jsd_chunk.sum() + del jsd_chunk, s_log_probs, t_log_probs + + del student_logits_masked, teacher_logits_masked + return total_loss / num_valid diff --git a/swift/rlhf_trainers/teacher_api_client.py b/swift/rlhf_trainers/teacher_api_client.py index a4be81c229..22962198dc 100644 --- a/swift/rlhf_trainers/teacher_api_client.py +++ b/swift/rlhf_trainers/teacher_api_client.py @@ -7,23 +7,58 @@ """ import asyncio import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Coroutine, Dict, List, Optional, Tuple, TypeVar import aiohttp import torch logger = logging.getLogger(__name__) +T = TypeVar('T') + + +def run_async(coro: Coroutine[Any, Any, T]) -> T: + """Run an async coroutine in a sync context, handling nested event loops. + + This utility function handles the complexity of running async code from + synchronous contexts, including when an event loop is already running. + + Args: + coro: The coroutine to execute + + Returns: + The result of the coroutine + """ + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Event loop already running (e.g., in Jupyter or nested async) + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, coro) + return future.result() + else: + return loop.run_until_complete(coro) + except RuntimeError: + # No event loop exists + return asyncio.run(coro) + class TeacherAPIClient: """Client for fetching teacher logprobs from swift deploy or vLLM server. - This client is designed to work with OpenAI-compatible API endpoints: - - swift deploy (with vLLM backend) - - Standalone vLLM server (vllm serve) + This client supports two API formats: + + 1. 
Swift Deploy / Chat Completions API (preferred): + - Uses /v1/chat/completions endpoint + - Sets prompt_logprobs parameter to get logprobs for input tokens + - Works with swift deploy using vLLM backend - The client fetches top-k log probabilities for each token position, - which are then used for knowledge distillation (GKD) training. + 2. vLLM Native Completions API (fallback): + - Uses /v1/completions endpoint + - Works with standalone vLLM server + + The client auto-detects which API format the server supports. Args: base_url: The base URL of the teacher model server (e.g., 'http://localhost:8000'). @@ -31,6 +66,8 @@ class TeacherAPIClient: timeout: Request timeout in seconds. api_key: Optional API key for authentication. model_name: Optional model name for the API request. If None, auto-detects. + tokenizer: Optional tokenizer for converting text prompts. If provided, + can decode response tokens. """ def __init__( @@ -40,12 +77,15 @@ def __init__( timeout: float = 300.0, api_key: Optional[str] = None, model_name: Optional[str] = None, + tokenizer: Optional[Any] = None, ): self.base_url = base_url.rstrip('/') self.top_logprobs = top_logprobs self.timeout = aiohttp.ClientTimeout(total=timeout) self.api_key = api_key self.model_name = model_name + self.tokenizer = tokenizer + self._api_format = None # 'swift_deploy' or 'vllm_native', detected on first request if top_logprobs <= 0: raise ValueError(f'top_logprobs must be positive, got {top_logprobs}') @@ -77,12 +117,76 @@ async def _get_model_name(self, session: aiohttp.ClientSession) -> str: self.model_name = 'default' return self.model_name + async def _detect_api_format(self, session: aiohttp.ClientSession, model_name: str) -> str: + """Detect which API format the server supports. + + Returns: + 'swift_deploy' if server supports prompt_logprobs in chat/completions + 'vllm_native' if server supports vLLM native completions API + """ + if self._api_format is not None: + return self._api_format + + # Try swift deploy format first (chat/completions with prompt_logprobs) + url = f'{self.base_url}/v1/chat/completions' + test_payload = { + 'model': model_name, + 'messages': [{ + 'role': 'user', + 'content': 'Hi' + }], + 'max_tokens': 1, + 'temperature': 0, + 'prompt_logprobs': 5, + } + + try: + async with session.post( + url, json=test_payload, headers=self._get_headers(), + timeout=aiohttp.ClientTimeout(total=30)) as resp: + if resp.status == 200: + data = await resp.json() + # Check if prompt_logprobs is returned + choices = data.get('choices', []) + if choices and choices[0].get('prompt_logprobs') is not None: + self._api_format = 'swift_deploy' + logger.info('Detected swift deploy API format with prompt_logprobs support') + return self._api_format + except Exception as e: + logger.debug(f'Swift deploy API detection failed: {e}') + + # Try vLLM native format + url = f'{self.base_url}/v1/completions' + test_payload = { + 'model': model_name, + 'prompt': [1, 2, 3], # Token IDs + 'max_tokens': 0, + 'temperature': 0, + 'logprobs': 5, + } + + try: + async with session.post( + url, json=test_payload, headers=self._get_headers(), + timeout=aiohttp.ClientTimeout(total=30)) as resp: + if resp.status == 200: + self._api_format = 'vllm_native' + logger.info('Detected vLLM native API format') + return self._api_format + except Exception as e: + logger.debug(f'vLLM native API detection failed: {e}') + + # Default to swift deploy and hope for the best + self._api_format = 'swift_deploy' + logger.warning('Could not detect API format, defaulting to swift 
deploy')
+        return self._api_format
+
     async def get_logprobs_batch(
         self,
         input_ids: List[List[int]],
         top_logprobs: Optional[int] = None,
     ) -> List[Dict[str, Any]]:
-        """Fetch logprobs for a batch of sequences using OpenAI-compatible API.
+        """Fetch logprobs for a batch of sequences.
 
         Args:
             input_ids: List of token ID sequences.
@@ -97,45 +201,200 @@ async def get_logprobs_batch(
 
         async with aiohttp.ClientSession(timeout=self.timeout) as session:
             model_name = await self._get_model_name(session)
-            url = f'{self.base_url}/v1/completions'
-
-            results = []
-            for i, ids in enumerate(input_ids):
-                # Use prompt tokens and request logprobs with echo
-                payload = {
-                    'model': model_name,
-                    'prompt': ids,
-                    'max_tokens': 0,
-                    'temperature': 0,
-                    'logprobs': topk,
-                    'echo': True,
-                }
+            api_format = await self._detect_api_format(session, model_name)
 
-                try:
-                    async with session.post(url, json=payload, headers=self._get_headers()) as resp:
-                        if resp.status != 200:
-                            error_text = await resp.text()
-                            logger.error(f'API error: {resp.status} - {error_text}')
-                            results.append(self._empty_result(len(ids), topk))
+            # Create tasks for concurrent requests
+            if api_format == 'swift_deploy':
+                tasks = [self._fetch_swift_deploy(session, model_name, ids, topk) for ids in input_ids]
+            else:
+                tasks = [self._fetch_vllm_native(session, model_name, ids, topk) for ids in input_ids]
+
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Handle any exceptions that occurred
+            processed_results = []
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    logger.warning(f'Request {i} failed with exception: {result}')
+                    processed_results.append(self._empty_result(len(input_ids[i]), topk))
+                else:
+                    processed_results.append(result)
+            return processed_results
+
+    async def _fetch_swift_deploy(
+        self,
+        session: aiohttp.ClientSession,
+        model_name: str,
+        ids: List[int],
+        topk: int,
+    ) -> Dict[str, Any]:
+        """Fetch logprobs using swift deploy's chat/completions API with prompt_logprobs.
+
+        This converts token IDs to text using the tokenizer (if available) or
+        sends as a raw text prompt.
+        """
+        # Convert token IDs to text for chat completions API
+        if self.tokenizer is not None:
+            prompt_text = self.tokenizer.decode(ids, skip_special_tokens=False)
+        else:
+            # No tokenizer available: fall back to a lossy printable-ASCII
+            # representation; the decoded prompt will generally not match the input
+            prompt_text = ''.join(chr(i) if 32 <= i < 127 else f'<{i}>' for i in ids[:100])
+            logger.warning('No tokenizer provided to TeacherAPIClient. '
+                           'Prompt may not be decoded correctly. 
Pass tokenizer for accurate results.') + + url = f'{self.base_url}/v1/chat/completions' + payload = { + 'model': model_name, + 'messages': [{ + 'role': 'user', + 'content': prompt_text + }], + 'max_tokens': 1, # Minimum required by swift deploy, we only need prompt_logprobs + 'temperature': 0, + 'prompt_logprobs': topk, + } + + max_retries = 3 + for attempt in range(max_retries): + try: + async with session.post(url, json=payload, headers=self._get_headers()) as resp: + if resp.status != 200: + error_text = await resp.text() + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) continue + logger.warning(f'API error after {max_retries} retries: {resp.status} - {error_text}') + return self._empty_result(len(ids), topk) + + data = await resp.json() + return self._parse_swift_deploy_response(data, len(ids), topk) + except Exception as e: + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) + continue + logger.warning(f'Failed to get logprobs after {max_retries} retries: {e}') + return self._empty_result(len(ids), topk) + + return self._empty_result(len(ids), topk) + + async def _fetch_vllm_native( + self, + session: aiohttp.ClientSession, + model_name: str, + ids: List[int], + topk: int, + ) -> Dict[str, Any]: + """Fetch logprobs using vLLM native completions API""" + url = f'{self.base_url}/v1/completions' + payload = { + 'model': model_name, + 'prompt': ids, # Token IDs directly + 'max_tokens': 0, + 'temperature': 0, + 'logprobs': topk, + } + + max_retries = 3 + for attempt in range(max_retries): + try: + async with session.post(url, json=payload, headers=self._get_headers()) as resp: + if resp.status != 200: + error_text = await resp.text() + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) + continue + logger.warning(f'API error after {max_retries} retries: {resp.status} - {error_text}') + return self._empty_result(len(ids), topk) + + data = await resp.json() + return self._parse_vllm_native_response(data, len(ids), topk) + except Exception as e: + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) + continue + logger.warning(f'Failed to get logprobs after {max_retries} retries: {e}') + return self._empty_result(len(ids), topk) + + return self._empty_result(len(ids), topk) + + def _parse_swift_deploy_response(self, response: Dict[str, Any], seq_len: int, topk: int) -> Dict[str, Any]: + """Parse swift deploy chat/completions response with prompt_logprobs. + + The response format is: + { + "choices": [{ + "prompt_logprobs": [ + {"token_id": int, "token": str, "logprob": float, "top_logprobs": [...]}, + ... 
+ ] + }] + } + """ + result = {'indices': [], 'values': []} + + try: + if 'choices' not in response or len(response['choices']) == 0: + return self._empty_result(seq_len, topk) + + choice = response['choices'][0] + prompt_logprobs = choice.get('prompt_logprobs') + + if prompt_logprobs is None: + logger.warning('prompt_logprobs not found in response') + return self._empty_result(seq_len, topk) + + for pos_entry in prompt_logprobs: + pos_indices = [] + pos_values = [] + + if pos_entry is not None: + top_logprobs_list = pos_entry.get('top_logprobs', []) + + for item in top_logprobs_list[:topk]: + token_id = item.get('token_id') + logprob = item.get('logprob') + if token_id is not None and logprob is not None: + pos_indices.append(token_id) + pos_values.append(float(logprob)) + + # Pad if needed + while len(pos_indices) < topk: + pos_indices.append(0) + pos_values.append(float('-inf')) + + result['indices'].append(pos_indices) + result['values'].append(pos_values) + + # Pad to seq_len if needed + while len(result['indices']) < seq_len: + result['indices'].append([0] * topk) + result['values'].append([float('-inf')] * topk) - data = await resp.json() - parsed = self._parse_response(data, len(ids), topk) - results.append(parsed) - except Exception as e: - logger.error(f'Failed to get logprobs for sequence {i}: {e}') - results.append(self._empty_result(len(ids), topk)) + except Exception as e: + logger.warning(f'Failed to parse swift deploy response: {e}') + return self._empty_result(seq_len, topk) - return results + return result - def _parse_response(self, response: Dict[str, Any], seq_len: int, topk: int) -> Dict[str, Any]: - """Parse vLLM completions API response to extract logprobs. + def _parse_vllm_native_response(self, response: Dict[str, Any], seq_len: int, topk: int) -> Dict[str, Any]: + """Parse vLLM native completions API response. - vLLM returns logprobs in two formats: - 1. `prompt_logprobs`: List of dicts where keys are token IDs (as strings), values have 'logprob' field - 2. `top_logprobs` in logprobs: List of dicts where keys are token text + vLLM returns logprobs as: + { + "choices": [{ + "logprobs": { + "top_logprobs": [{token_str: logprob, ...}, ...] + } + }] + } - We prefer `prompt_logprobs` because it has token IDs directly. + Or with prompt_logprobs (if using newer vLLM): + { + "choices": [{ + "prompt_logprobs": [{token_id_str: {"logprob": float}, ...}, ...] 
+            }]
+        }
+        """
         result = {'indices': [], 'values': []}
@@ -216,14 +475,14 @@ def _parse_response(self, response: Dict[str, Any], seq_len: int, topk: int) ->
                     result['values'].append([float('-inf')] * topk)
 
         except Exception as e:
-            logger.error(f'Failed to parse response: {e}')
+            logger.warning(f'Failed to parse vLLM native response: {e}')
             return self._empty_result(seq_len, topk)
 
         return result
 
     @staticmethod
     def _get_logprob_value(logprob) -> float:
-        """Extract logprob value from vLLM response (handles both float and Logprob object)."""
+        """Extract logprob value from response (handles plain floats and objects with a `logprob` attribute)."""
         if isinstance(logprob, (int, float)):
             return float(logprob)
         elif hasattr(logprob, 'logprob'):
@@ -239,7 +498,7 @@ def _empty_result(self, seq_len: int, topk: int) -> Dict[str, Any]:
             'values': [[float('-inf')] * topk for _ in range(seq_len)],
         }
 
-    def check_server_health(self, timeout: float = 5.0) -> bool:
+    def check_server_health(self, timeout: float = 60.0) -> bool:
         """Check if the teacher model server is healthy."""
         import requests
         try:
@@ -269,17 +528,7 @@ def get_logprobs_sync(
 
         Returns:
             Tuple of (logprobs_tensor, indices_tensor) with shapes [batch, seq_len, topk]
         """
-        try:
-            loop = asyncio.get_event_loop()
-            if loop.is_running():
-                import concurrent.futures
-                with concurrent.futures.ThreadPoolExecutor() as executor:
-                    future = executor.submit(asyncio.run, self.get_logprobs_batch(input_ids, top_logprobs))
-                    results = future.result()
-            else:
-                results = loop.run_until_complete(self.get_logprobs_batch(input_ids, top_logprobs))
-        except RuntimeError:
-            results = asyncio.run(self.get_logprobs_batch(input_ids, top_logprobs))
+        results = run_async(self.get_logprobs_batch(input_ids, top_logprobs))
 
         # Convert to tensors
         topk = top_logprobs or self.top_logprobs
@@ -300,3 +549,77 @@ def get_logprobs_sync(
                     logprobs_tensor[batch_idx, pos_idx, k_idx] = pos_values[k_idx]
 
         return logprobs_tensor, indices_tensor
+
+
+def fetch_teacher_logprobs_from_api(
+    teacher_api_client: TeacherAPIClient,
+    input_ids: torch.Tensor,
+    topk: int,
+    device: torch.device,
+    is_master_rank: bool = True,
+    broadcast_src: int = 0,
+    group: Optional['torch.distributed.ProcessGroup'] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fetch teacher logprobs from external API service.
+
+    This is a shared utility function used by both swift RLHF and Megatron GKD trainers.
+    Only the designated rank (master/last) makes API calls, then broadcasts results.
+
+    Note on off-by-one alignment:
+    The first token's logprob may be None because there's no preceding context
+    to predict it. The returned tensors will have zeros/negative infinity at
+    position 0. In GKD training, this is acceptable since the loss mask
+    (labels=-100) typically excludes prompt tokens.
+
+    Args:
+        teacher_api_client: The TeacherAPIClient instance (may be None on non-API ranks)
+        input_ids: Input token IDs tensor [batch_size, seq_len]
+        topk: Number of top-k logprobs to fetch
+        device: Device for output tensors
+        is_master_rank: Whether this rank should make API calls
+        broadcast_src: Source rank for broadcasting results. Note that
+            torch.distributed.broadcast interprets src as a global rank even
+            when a group is passed.
+        group: Optional process group for broadcasting. If None, uses the default
+            global process group. For Megatron, pass the model parallel group
+            (TP×PP×CP) so that ranks processing the same data share results. 
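+
+    Example (illustrative; mirrors the Megatron call site earlier in this patch):
+
+        logprobs, indices = fetch_teacher_logprobs_from_api(
+            teacher_api_client=client,  # may be None on non-API ranks
+            input_ids=input_ids,
+            topk=20,
+            device=input_ids.device,
+            is_master_rank=is_last_rank(),
+            broadcast_src=world_size - 1,
+        )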
+ + Returns: + Tuple of (teacher_logprobs, teacher_indices) tensors with shapes [batch, seq_len, topk] + """ + import torch.distributed as dist + + batch_size, seq_len = input_ids.shape + + # Initialize tensors + teacher_logprobs = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.float32) + teacher_indices = torch.zeros(batch_size, seq_len, topk, device=device, dtype=torch.long) + + # Only designated rank fetches from API + if is_master_rank and teacher_api_client is not None: + # Fetch logprobs from API + api_results = run_async( + teacher_api_client.get_logprobs_batch( + input_ids=input_ids.tolist(), + top_logprobs=topk, + )) + + # Parse API results into tensors + # api_results is list of dicts with 'values' (logprobs) and 'indices' for each sample + for batch_idx, result in enumerate(api_results): + indices_list = result.get('indices', []) + values_list = result.get('values', []) + for pos_idx, (pos_indices, pos_values) in enumerate(zip(indices_list, values_list)): + if pos_idx >= seq_len: + break + for k_idx in range(min(len(pos_indices), topk)): + teacher_indices[batch_idx, pos_idx, k_idx] = pos_indices[k_idx] + teacher_logprobs[batch_idx, pos_idx, k_idx] = pos_values[k_idx] + + # Broadcast results within the process group + if dist.is_initialized(): + # Get group size to determine if broadcast is needed + group_size = dist.get_world_size(group) if group is not None else dist.get_world_size() + if group_size > 1: + dist.broadcast(teacher_logprobs, src=broadcast_src, group=group) + dist.broadcast(teacher_indices, src=broadcast_src, group=group) + + return teacher_logprobs, teacher_indices diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py index ea5a96576b..d91881f891 100644 --- a/swift/rlhf_trainers/utils.py +++ b/swift/rlhf_trainers/utils.py @@ -1453,7 +1453,12 @@ def check_vllm_version_ge(min_version: str) -> bool: return version.parse(vllm_version) >= version.parse(min_version) -def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60, use_last_rank: bool = True): +def create_teacher_api_client(args, + check_health: bool = True, + timeout: int = 60, + use_last_rank: bool = True, + tokenizer=None, + all_ranks: bool = False): """ Create and initialize TeacherAPIClient for external teacher model service. 
@@ -1461,7 +1466,9 @@ def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60 args: Arguments object containing teacher_model_server and gkd_logits_topk check_health: Whether to check server health after creation (default: True) timeout: Timeout for health check in seconds (default: 60) - use_last_rank: Whether to use last rank (Megatron style) or first rank (Swift style) for initialization (default: True) + use_last_rank: Whether to use last rank or first rank for initialization (default: True) + tokenizer: Optional tokenizer for decoding token IDs to text (required for swift deploy API) + all_ranks: If True, initialize client on all ranks (for DP mode where each rank needs its own client) Returns: TeacherAPIClient instance or None if teacher_model_server is not set @@ -1477,19 +1484,29 @@ def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60 logger = get_logger() gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) or 20 - # Choose rank check function based on context - rank_check_func = is_last_rank if use_last_rank else is_master + # Determine if this rank should create the client + if all_ranks: + # In DP mode, each rank has different data and needs its own client + should_create = True + else: + # In MP mode, only one rank creates the client and broadcasts results + rank_check_func = is_last_rank if use_last_rank else is_master + should_create = rank_check_func() teacher_api_client = None - if rank_check_func(): + if should_create: logger.info(f'Initializing teacher API client for {teacher_model_server}') teacher_api_client = TeacherAPIClient( base_url=teacher_model_server, top_logprobs=gkd_logits_topk, + tokenizer=tokenizer, ) - if check_health: - # Check server health with timeout - teacher_api_client.check_server_health(timeout=timeout) + # Only master rank does health check to avoid duplicate checks + if check_health and is_master(): + is_healthy = teacher_api_client.check_server_health(timeout=timeout) + if not is_healthy: + raise ConnectionError(f'Failed to connect to teacher model server at {teacher_model_server}. ' + 'Please ensure the server is running and accessible.') logger.info(f'Teacher API client initialized with top_logprobs={gkd_logits_topk}') return teacher_api_client From a6ecebb507d0f9b7304969d721ca4fe932439852 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 2 Mar 2026 10:56:20 +0800 Subject: [PATCH 04/10] fix args --- swift/arguments/rlhf_args.py | 19 ++++++------------- swift/infer_engine/protocol.py | 1 + 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index ce2b5abedb..a66f02d908 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -55,11 +55,6 @@ class TeacherModelArguments: When set, the teacher logprobs will be fetched from the external API service (e.g., swift deploy, vLLM) instead of loading a local teacher model. This enables using larger teacher models or services hosted remotely. When this is set, `teacher_model` is not required. Defaults to None. - gkd_logits_topk (Optional[int]): The number of top-k logits to use for KL divergence computation in GKD. - If None, uses full vocabulary for KL computation (more accurate but memory-intensive). - If set to a positive integer, only top-k teacher logits are used (more efficient). - When using `teacher_model_server`, this is limited by the server's `max_logprobs` setting - (vLLM default is 20, can be increased with `--max-logprobs`). Defaults to None. 
""" teacher_model: Optional[str] = None teacher_adapters: List[str] = field(default_factory=list) @@ -80,14 +75,6 @@ class TeacherModelArguments: 'URL of the teacher model server (e.g., http://localhost:8000). ' 'When set, teacher logprobs are fetched via API instead of loading a local model.' }) - gkd_logits_topk: Optional[int] = field( - default=None, - metadata={ - 'help': - 'Number of top-k logits for KL computation in GKD. ' - 'None = full vocabulary, positive integer = top-k only. ' - 'When using teacher_model_server, limited by server max_logprobs (vLLM default: 20).' - }) @dataclass @@ -221,6 +208,11 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo gkd_loss + sft_alpha * sft_loss`. Defaults to 0. lmbda (float): The lambda parameter for GKD, balancing policy and value losses. Defaults to 0.5. seq_kd (bool): Whether to use sequence-level knowledge distillation for GKD. Defaults to False. + gkd_logits_topk (Optional[int]): The number of top-k logits to use for KL divergence computation in GKD. + If None, uses full vocabulary for KL computation (more accurate but memory-intensive). + If set to a positive integer, only top-k teacher logits are used (more efficient). + When using `teacher_model_server`, this is limited by the server's `max_logprobs` setting + (vLLM default is 20, can be increased with `--max-logprobs`). Defaults to None. offload_teacher_model (bool): Whether to offload the teacher model to CPU memory to save VRAM during GKD training. Defaults to False. max_new_tokens (Optional[int]): A backward-compatibility argument. Please use `max_completion_length` instead. @@ -258,6 +250,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo sft_alpha: float = 0 lmbda: float = 0.5 seq_kd: bool = False + gkd_logits_topk: Optional[int] = None offload_teacher_model: bool = False # compat max_new_tokens: Optional[int] = None # use max_completion_length instead diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index f0d880606f..b9aaccb697 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -192,6 +192,7 @@ def __post_init__(self): @dataclass class CompletionRequestMixin: model: str + prompt: str @dataclass From 44f0e4e369eb0331c871722a0af69c0a5e401465 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 2 Mar 2026 11:40:22 +0800 Subject: [PATCH 05/10] update --- swift/infer_engine/protocol.py | 8 +- swift/infer_engine/vllm_engine.py | 56 +---- swift/megatron/trainers/gkd_trainer.py | 118 ++++++---- swift/pipelines/train/rlhf.py | 1 + swift/rlhf_trainers/__init__.py | 2 +- swift/rlhf_trainers/gkd_trainer.py | 135 ++++++----- swift/rlhf_trainers/jsd_loss.py | 271 ---------------------- swift/rlhf_trainers/teacher_api_client.py | 133 +++-------- swift/rlhf_trainers/utils.py | 10 +- tests/train/test_teacher_api_client.py | 231 ------------------ 10 files changed, 194 insertions(+), 771 deletions(-) delete mode 100644 swift/rlhf_trainers/jsd_loss.py delete mode 100644 tests/train/test_teacher_api_client.py diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index b9aaccb697..e98840a8f3 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -172,7 +172,6 @@ class RequestConfig: stream: bool = False logprobs: bool = False top_logprobs: Optional[int] = None - prompt_logprobs: Optional[int] = None # Set to an integer to get top-k logprobs for each prompt token n: int = 1 best_of: Optional[int] = None @@ -393,14 +392,11 @@ 
class ChatCompletionResponseChoice: finish_reason: Literal['stop', 'length', None] logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None token_ids: Optional[List[int]] = None - # Logprobs for prompt tokens (when prompt_logprobs is requested) - prompt_logprobs: Optional[List[Dict[str, Any]]] = None def to_cmpl_choice(self) -> 'CompletionResponseChoice': self = deepcopy(self) assert not self.message.tool_calls, f'message: {self.message}' - return CompletionResponseChoice(self.index, self.message.content, self.finish_reason, self.logprobs, - self.prompt_logprobs) + return CompletionResponseChoice(self.index, self.message.content, self.finish_reason, self.logprobs) @dataclass @@ -426,8 +422,6 @@ class CompletionResponseChoice: text: str finish_reason: Literal['stop', 'length', None] logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None - # Logprobs for prompt tokens (when prompt_logprobs is requested) - prompt_logprobs: Optional[List[Dict[str, Any]]] = None @dataclass diff --git a/swift/infer_engine/vllm_engine.py b/swift/infer_engine/vllm_engine.py index 8c483c445c..3a34a5d9e1 100644 --- a/swift/infer_engine/vllm_engine.py +++ b/swift/infer_engine/vllm_engine.py @@ -406,48 +406,6 @@ def _get_logprobs(self, logprobs[token_id] = logprob.logprob return super()._get_logprobs(logprobs_list, token_ids, top_logprobs) - def _get_prompt_logprobs( - self, - prompt_logprobs: Optional[List[Optional[Dict]]], - prompt_token_ids: List[int], - ) -> Optional[List[Dict[str, Any]]]: - if prompt_logprobs is None or not prompt_token_ids: - return None - - result = [] - for pos_idx, (token_id, pos_logprobs) in enumerate(zip(prompt_token_ids, prompt_logprobs)): - token = self.tokenizer.decode(token_id) - entry = { - 'token_id': token_id, - 'token': token, - 'logprob': None, # Will be filled if available - 'top_logprobs': [], - } - - if pos_logprobs is not None: - # Get logprob for the actual token at this position - if token_id in pos_logprobs: - logprob_obj = pos_logprobs[token_id] - entry['logprob'] = logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj - - # Get top logprobs sorted by probability (descending) - sorted_items = sorted( - pos_logprobs.items(), key=lambda x: -(x[1].logprob if hasattr(x[1], 'logprob') else x[1])) - for tid, logprob_obj in sorted_items: - logprob_val = logprob_obj.logprob if hasattr(logprob_obj, 'logprob') else logprob_obj - if logprob_val == float('-inf'): - continue - t = self.tokenizer.decode(tid) - entry['top_logprobs'].append({ - 'token_id': tid, - 'token': t, - 'logprob': logprob_val, - }) - - result.append(entry) - - return result - def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingParams: kwargs = {'max_tokens': request_config.max_tokens} for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']: @@ -473,10 +431,6 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingP # Return only the sampled token's logprob kwargs['logprobs'] = 0 - # Handle prompt_logprobs: return logprobs for prompt/input tokens - if request_config.prompt_logprobs is not None: - kwargs['prompt_logprobs'] = request_config.prompt_logprobs - # TODO: beam search for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']: if hasattr(SamplingParams, key): @@ -635,21 +589,13 @@ def _create_chat_completion_response( logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(content) # Use content instead of response for tool calls 
token_ids = output.token_ids if request_config.return_details else None - - # Get prompt logprobs if requested - prompt_logprobs_result = None - if request_config.prompt_logprobs is not None: - prompt_logprobs_result = self._get_prompt_logprobs(result.prompt_logprobs, - list(result.prompt_token_ids)) - choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage( role='assistant', content=content, reasoning_content=reasoning_content, tool_calls=toolcall), finish_reason=output.finish_reason, logprobs=logprobs, - token_ids=token_ids, - prompt_logprobs=prompt_logprobs_result) + token_ids=token_ids) choices.append(choice) prompt_token_ids = None images_size = None diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 57d0c80de6..d17862ab56 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -287,7 +287,7 @@ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: O if topk is not None and teacher_logits is not None: scaled = teacher_logits / self.temperature topk_logits, topk_indices = torch.topk(scaled, k=topk, dim=-1) - encoded_batch['teacher_api_logprobs'] = F.log_softmax(topk_logits, dim=-1) + encoded_batch['teacher_api_logprobs'] = topk_logits encoded_batch['teacher_api_indices'] = topk_indices encoded_batch['teacher_logits'] = None else: @@ -398,6 +398,7 @@ def generalized_jsd_loss( local_num_valid = mask.sum() num_valid = local_num_valid.float() + # All-reduce num_valid across CP group for correct averaging if args.context_parallel_size > 1: torch.distributed.all_reduce( num_valid, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) @@ -405,13 +406,64 @@ def generalized_jsd_loss( if num_valid == 0: return (student_logits.sum() * 0).reshape(()) - use_topk = teacher_topk_logprobs is not None and teacher_topk_indices is not None - - if use_topk: + # Top-k mode: direct computation without vocab parallel + if teacher_topk_logprobs is not None and teacher_topk_indices is not None: total_loss = self._jsd_topk(student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta) + if args.context_parallel_size > 1: + torch.distributed.all_reduce( + total_loss, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) + return total_loss / num_valid + + # Full vocabulary mode (original code) + # Align vocab size between student and teacher + student_logits, teacher_logits = self._align_vocab_size(student_logits, teacher_logits) + + # Apply temperature scaling and mask + student_logits_masked = (student_logits / self.temperature)[mask] + teacher_logits_masked = (teacher_logits / self.temperature)[mask] + del student_logits, teacher_logits + + # Use local count for iteration, global count for averaging + local_num_valid_int = local_num_valid.item() + total_loss = student_logits_masked.new_zeros(()) + + if beta != 0 and beta != 1: + beta_t = torch.tensor(beta, dtype=student_logits_masked.dtype, device=student_logits_masked.device) + log_beta = torch.log(beta_t) + log_1_minus_beta = torch.log1p(-beta_t) else: - total_loss = self._jsd_full_vocab(student_logits, teacher_logits, mask, beta, chunk_size, local_num_valid) + beta_t = log_beta = log_1_minus_beta = None + + for start_idx in range(0, local_num_valid_int, chunk_size): + end_idx = min(start_idx + chunk_size, local_num_valid_int) + s_chunk = student_logits_masked[start_idx:end_idx] + t_chunk = teacher_logits_masked[start_idx:end_idx] + + s_log_probs = 
vocab_parallel_log_softmax(s_chunk) + t_log_probs = vocab_parallel_log_softmax(t_chunk) + del s_chunk, t_chunk + if beta == 0: + jsd_chunk = vocab_parallel_kl_div(s_log_probs, t_log_probs) + elif beta == 1: + jsd_chunk = vocab_parallel_kl_div(t_log_probs, s_log_probs) + else: + mixture_log_probs = torch.logsumexp( + torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), + dim=0, + ) + kl_teacher = vocab_parallel_kl_div(mixture_log_probs, t_log_probs) + kl_student = vocab_parallel_kl_div(mixture_log_probs, s_log_probs) + del mixture_log_probs + jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student + del kl_teacher, kl_student + + total_loss = total_loss + jsd_chunk.sum() + del jsd_chunk, s_log_probs, t_log_probs + + del student_logits_masked, teacher_logits_masked + + # All-reduce total_loss across CP group for correct sum if args.context_parallel_size > 1: torch.distributed.all_reduce( total_loss, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group()) @@ -419,61 +471,31 @@ def generalized_jsd_loss( return total_loss / num_valid def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta): - """Compute JSD on teacher's top-k distribution (for API or local top-k mode).""" - student_logits_scaled = student_logits / self.temperature - t_log_p = teacher_topk_logprobs - t_p = torch.exp(t_log_p) + """Compute JSD on teacher's top-k distribution. + + Handles both local top-k (raw logits) and API top-k (raw logprobs) by + normalizing both teacher and student over the top-k subset via log_softmax. + """ + s_scaled = student_logits / self.temperature + s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices) - s_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=teacher_topk_indices) - s_log_p = F.log_softmax(s_topk_logits, dim=-1) + # Normalize both over top-k subset (handles both raw logits and API logprobs) + t_log_p = F.log_softmax(teacher_topk_logprobs, dim=-1) + s_log_p = F.log_softmax(s_topk, dim=-1) + t_p = torch.exp(t_log_p) if beta == 0: jsd = (t_p * (t_log_p - s_log_p)).sum(dim=-1) elif beta == 1: - s_p = F.softmax(s_topk_logits, dim=-1) + s_p = torch.exp(s_log_p) jsd = (s_p * (s_log_p - t_log_p)).sum(dim=-1) else: - s_p = F.softmax(s_topk_logits, dim=-1) + s_p = torch.exp(s_log_p) m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10) jsd = beta * (t_p * (t_log_p - m_log_p)).sum(-1) + (1 - beta) * (s_p * (s_log_p - m_log_p)).sum(-1) return (jsd * mask.float()).sum() - def _jsd_full_vocab(self, student_logits, teacher_logits, mask, beta, chunk_size, local_num_valid): - """Compute JSD over full vocabulary with vocab-parallel support.""" - student_logits, teacher_logits = self._align_vocab_size(student_logits, teacher_logits) - - s_masked = (student_logits / self.temperature)[mask] - t_masked = (teacher_logits / self.temperature)[mask] - del student_logits, teacher_logits - - local_n = local_num_valid.item() - total_loss = s_masked.new_zeros(()) - - if beta != 0 and beta != 1: - beta_t = torch.tensor(beta, dtype=s_masked.dtype, device=s_masked.device) - log_beta = torch.log(beta_t) - log_1_minus_beta = torch.log1p(-beta_t) - else: - beta_t = log_beta = log_1_minus_beta = None - - for i in range(0, local_n, chunk_size): - s_log = vocab_parallel_log_softmax(s_masked[i:i + chunk_size]) - t_log = vocab_parallel_log_softmax(t_masked[i:i + chunk_size]) - - if beta == 0: - chunk_loss = vocab_parallel_kl_div(s_log, t_log) - elif beta == 1: - chunk_loss = vocab_parallel_kl_div(t_log, s_log) - else: - 
m_log = torch.logsumexp(torch.stack([s_log + log_1_minus_beta, t_log + log_beta]), dim=0) - chunk_loss = beta_t * vocab_parallel_kl_div(m_log, t_log) \ - + (1 - beta_t) * vocab_parallel_kl_div(m_log, s_log) - - total_loss = total_loss + chunk_loss.sum() - - return total_loss - def loss_func(self, output_tensor: torch.Tensor, *, diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index 4b1ba82db0..6f1d8d9ada 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -240,6 +240,7 @@ def _get_trainer_kwargs(self): # Pass teacher_model_server so trainer knows to use API mode on all ranks trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server from swift.rlhf_trainers.utils import create_teacher_api_client + # In DP mode (DeepSpeed/FSDP), each rank has different data and needs its own client # Use all_ranks=True so every rank can independently fetch teacher logprobs trainer_kwargs['teacher_api_client'] = create_teacher_api_client( diff --git a/swift/rlhf_trainers/__init__.py b/swift/rlhf_trainers/__init__.py index 2da8ba95b5..aec0e9438c 100644 --- a/swift/rlhf_trainers/__init__.py +++ b/swift/rlhf_trainers/__init__.py @@ -15,9 +15,9 @@ from .ppo_trainer import PPOTrainer from .reward_trainer import RewardTrainer from .rlhf_mixin import RLHFTrainerMixin + from .teacher_api_client import TeacherAPIClient from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, round_robin from .vllm_client import VLLMClient - from .teacher_api_client import TeacherAPIClient else: _import_structure = { 'cpo_trainer': ['CPOTrainer'], diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 40a4134f9f..d31a7f5609 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -2,7 +2,6 @@ import inspect import os import random - import torch import torch.nn as nn import torch.nn.functional as F @@ -21,7 +20,6 @@ from swift.trainers import SwiftMixin, disable_gradient_checkpointing from swift.utils import (JsonlWriter, get_logger, is_swanlab_available, is_wandb_available, remove_response, to_device, unwrap_model_for_generation) -from .jsd_loss import compute_jsd_loss from .rollout_mixin import DataType, RolloutTrainerMixin from .utils import (get_gather_if_zero3_context, identity_data_collator, prepare_deepspeed, profiling_context, profiling_decorator) @@ -250,50 +248,27 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N outputs_student = None elif self.use_teacher_api: assert teacher_api_logprobs is not None - # API mode: use teacher logprobs from external service if self.args.sft_alpha > 0: model_inputs['labels'] = inputs['labels'] outputs_student = model(**model_inputs) - # Handle logits_to_keep: truncate teacher logprobs to match student output length + # Align teacher API logprobs with student output when logits_to_keep is used + labels = inputs['labels'] logits_to_keep = inputs.get('logits_to_keep') if logits_to_keep is not None: - if isinstance(logits_to_keep, torch.Tensor): - if logits_to_keep.dtype == torch.bool: - # Boolean mask case: apply the same mask to teacher logprobs - # logits_to_keep is shape [seq_len], True for positions to keep - teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] - teacher_api_indices = teacher_api_indices[:, logits_to_keep] - shifted_labels = inputs['labels'] - shifted_labels = torch.roll(shifted_labels, shifts=-1, dims=1) - elif logits_to_keep.numel() == 1: - # Single element tensor - 
num_keep = logits_to_keep.item() - teacher_api_logprobs = teacher_api_logprobs[:, -num_keep:] - teacher_api_indices = teacher_api_indices[:, -num_keep:] - shifted_labels = inputs['labels'][:, -num_keep:] - shifted_labels = torch.roll(shifted_labels, shifts=-1, dims=1) - else: - # Tensor with multiple elements - not supported with teacher API - # Fall back to using full sequence - logger.warning_once( - 'logits_to_keep tensor with multiple elements not supported with teacher API. ' - 'Using full sequence.') - shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool: + teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep] + teacher_api_indices = teacher_api_indices[:, logits_to_keep] + labels = labels[:, logits_to_keep] else: - # Integer case - num_keep = int(logits_to_keep) - teacher_api_logprobs = teacher_api_logprobs[:, -num_keep:] - teacher_api_indices = teacher_api_indices[:, -num_keep:] - shifted_labels = inputs['labels'][:, -num_keep:] - shifted_labels = torch.roll(shifted_labels, shifts=-1, dims=1) - else: - shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep) + teacher_api_logprobs = teacher_api_logprobs[:, -n:] + teacher_api_indices = teacher_api_indices[:, -n:] + labels = labels[:, -n:] + shifted_labels = torch.roll(labels, shifts=-1, dims=1) - # Compute top-k JSD loss with API logprobs loss = self.generalized_jsd_loss( student_logits=outputs_student.logits, - teacher_logits=None, # Not used in API mode labels=shifted_labels, beta=self.beta, temperature=self.temperature, @@ -301,7 +276,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_topk_indices=teacher_api_indices, ) - # Add SFT loss if enabled (skip for student-generated responses) if self.args.sft_alpha > 0 and data_source != DataSource.STUDENT: loss = loss + self.args.sft_alpha * outputs_student.loss else: @@ -589,31 +563,84 @@ def _prepare_liger_loss(self): @staticmethod def generalized_jsd_loss( student_logits, - teacher_logits, + teacher_logits=None, labels=None, beta=0.5, temperature=1.0, - chunk_size=256, + chunk_size=512, topk=None, teacher_topk_logprobs=None, teacher_topk_indices=None, ): - """Compute generalized JSD loss with optional top-k support. 
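+        """Generalized JSD loss over student/teacher logits.
+
+        Supports full-vocabulary, local top-k (``topk``) and precomputed top-k
+        (``teacher_topk_logprobs``/``teacher_topk_indices``, e.g. from a teacher
+        API) modes.
+        """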
+        # Top-k mode: reduce logits to [*, k] before the standard JSD pipeline
+        if topk is not None and teacher_logits is not None:
+            t_scaled = teacher_logits / temperature
+            s_scaled = student_logits / temperature
+            teacher_logits, topk_idx = torch.topk(t_scaled, k=topk, dim=-1)
+            student_logits = torch.gather(s_scaled, dim=-1, index=topk_idx)
+            del t_scaled, s_scaled, topk_idx
+            temperature = 1.0
+        elif teacher_topk_logprobs is not None and teacher_topk_indices is not None:
+            s_scaled = student_logits / temperature
+            student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
+            teacher_logits = teacher_topk_logprobs
+            del s_scaled
+            temperature = 1.0
+
+        student_logits = student_logits / temperature
+        teacher_logits = teacher_logits / temperature
+
+        if labels is not None:
+            mask = labels != -100
+            student_logits = student_logits[mask]
+            teacher_logits = teacher_logits[mask]
+            num_valid = mask.sum()
+        else:
+            student_logits = student_logits.view(-1, student_logits.size(-1))
+            teacher_logits = teacher_logits.view(-1, teacher_logits.size(-1))
+            num_valid = student_logits.size(0)

-        Delegates to the unified jsd_loss module for memory-efficient computation.
-        See `swift.rlhf_trainers.jsd_loss.compute_jsd_loss` for details.
-        """
-        return compute_jsd_loss(
-            student_logits=student_logits,
-            teacher_logits=teacher_logits,
-            labels=labels,
-            beta=beta,
-            temperature=temperature,
-            chunk_size=chunk_size,
-            topk=topk,
-            teacher_topk_logprobs=teacher_topk_logprobs,
-            teacher_topk_indices=teacher_topk_indices,
-        )
+        if num_valid == 0:
+            # Return a graph-connected zero so backward() does not fail when
+            # every label in the batch is masked (mirrors the Megatron trainer).
+            return student_logits.sum() * 0
+
+        num_valid_int = num_valid if isinstance(num_valid, int) else num_valid.item()
+        total_loss = student_logits.new_zeros(())
+
+        if beta != 0 and beta != 1:
+            beta_t = torch.tensor(beta, dtype=student_logits.dtype, device=student_logits.device)
+            log_beta = torch.log(beta_t)
+            log_1_minus_beta = torch.log1p(-beta_t)
+        else:
+            beta_t = log_beta = log_1_minus_beta = None
+
+        for start_idx in range(0, num_valid_int, chunk_size):
+            end_idx = min(start_idx + chunk_size, num_valid_int)
+            s_chunk = student_logits[start_idx:end_idx]
+            t_chunk = teacher_logits[start_idx:end_idx]
+
+            s_log_probs = F.log_softmax(s_chunk, dim=-1)
+            t_log_probs = F.log_softmax(t_chunk, dim=-1)
+            del s_chunk, t_chunk
+
+            if beta == 0:
+                jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
+            elif beta == 1:
+                jsd_chunk = F.kl_div(t_log_probs, s_log_probs, reduction='none', log_target=True)
+            else:
+                mixture_log_probs = torch.logsumexp(
+                    torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]),
+                    dim=0,
+                )
+                kl_teacher = F.kl_div(mixture_log_probs, t_log_probs, reduction='none', log_target=True)
+                kl_student = F.kl_div(mixture_log_probs, s_log_probs, reduction='none', log_target=True)
+                del mixture_log_probs
+                jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student
+                del kl_teacher, kl_student
+
+            total_loss = total_loss + jsd_chunk.sum()
+            del jsd_chunk, s_log_probs, t_log_probs
+
+        return total_loss / num_valid
 
     def _prepare_logging(self):
         """Initialize logging components for on-policy rollout tracking."""
diff --git a/swift/rlhf_trainers/jsd_loss.py b/swift/rlhf_trainers/jsd_loss.py
deleted file mode 100644
index c4be2ef7da..0000000000
--- a/swift/rlhf_trainers/jsd_loss.py
+++ /dev/null
@@ -1,271 +0,0 @@
-"""Unified JSD (Jensen-Shannon Divergence) loss implementation for GKD training.
-
-This module provides a memory-efficient, chunked JSD loss computation that supports:
-1. Full vocabulary mode: Uses complete logits from both models
-2. Top-K mode with local teacher: Extracts top-k from teacher logits
-3. Top-K mode with API: Uses pre-computed teacher logprobs and indices
-
-The implementation uses chunked processing to reduce peak memory usage.
-"""
-
-from typing import Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-
-
-def compute_jsd_loss(
-    student_logits: torch.Tensor,
-    teacher_logits: Optional[torch.Tensor] = None,
-    labels: Optional[torch.Tensor] = None,
-    beta: float = 0.5,
-    temperature: float = 1.0,
-    chunk_size: int = 256,
-    topk: Optional[int] = None,
-    teacher_topk_logprobs: Optional[torch.Tensor] = None,
-    teacher_topk_indices: Optional[torch.Tensor] = None,
-    log_softmax_fn=None,
-    kl_div_fn=None,
-) -> torch.Tensor:
-    """Compute JSD loss with unified chunked processing for memory efficiency.
-
-    This function handles all three modes in a unified way:
-    - Full vocab mode: teacher_logits provided, topk=None
-    - Top-K local mode: teacher_logits provided, topk specified
-    - Top-K API mode: teacher_topk_logprobs and teacher_topk_indices provided
-
-    Args:
-        student_logits: Student model logits [batch, seq_len, vocab_size]
-        teacher_logits: Teacher model logits [batch, seq_len, vocab_size], None for API mode
-        labels: Token labels for masking [batch, seq_len], -100 for ignored positions
-        beta: JSD interpolation coefficient (0=Forward KL, 0.5=JSD, 1=Reverse KL)
-        temperature: Temperature for softmax scaling
-        chunk_size: Chunk size for memory-efficient processing
-        topk: Number of top-k logits to use. None for full vocabulary mode.
-        teacher_topk_logprobs: Pre-computed teacher log probs [batch, seq_len, topk] (API mode)
-        teacher_topk_indices: Pre-computed teacher token indices [batch, seq_len, topk] (API mode)
-        log_softmax_fn: Optional custom log_softmax function (e.g., for vocab parallel)
-        kl_div_fn: Optional custom KL div function (e.g., for vocab parallel)
-
-    Returns:
-        Scalar loss value
-    """
-    # Determine mode
-    use_api_mode = teacher_topk_logprobs is not None and teacher_topk_indices is not None
-    use_topk = topk is not None or use_api_mode
-
-    # Build mask
-    if labels is not None:
-        mask = labels != -100
-    else:
-        mask = torch.ones(student_logits.shape[:2], dtype=torch.bool, device=student_logits.device)
-
-    num_valid = mask.sum()
-    if num_valid == 0:
-        return student_logits.new_zeros(())
-
-    # Dispatch to appropriate mode
-    if use_api_mode:
-        return _compute_topk_api_loss(student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, num_valid,
-                                      beta, temperature)
-    elif use_topk:
-        return _compute_topk_local_loss_chunked(student_logits, teacher_logits, mask, num_valid, beta, temperature,
-                                                topk, chunk_size)
-    else:
-        return _compute_full_vocab_loss_chunked(student_logits, teacher_logits, mask, num_valid, beta, temperature,
-                                                chunk_size, log_softmax_fn, kl_div_fn)
-
-
-def _compute_topk_jsd(
-    teacher_probs: torch.Tensor,
-    teacher_log_probs: torch.Tensor,
-    student_logits: torch.Tensor,
-    student_log_probs: torch.Tensor,
-    beta: float,
-) -> torch.Tensor:
-    """Compute JSD on top-k distribution.
- - Args: - teacher_probs: Teacher probabilities [*, topk] - teacher_log_probs: Teacher log probabilities [*, topk] - student_logits: Student logits at top-k positions [*, topk] - student_log_probs: Student log probabilities [*, topk] - beta: JSD interpolation coefficient - - Returns: - JSD values [*] (reduced over topk dimension) - """ - if beta == 0: - # Forward KL: KL(teacher || student) - return (teacher_probs * (teacher_log_probs - student_log_probs)).sum(dim=-1) - elif beta == 1: - # Reverse KL: KL(student || teacher) - student_probs = F.softmax(student_logits, dim=-1) - return (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1) - else: - # Full JSD with mixture distribution - student_probs = F.softmax(student_logits, dim=-1) - mixture_probs = beta * teacher_probs + (1 - beta) * student_probs - mixture_log_probs = torch.log(mixture_probs + 1e-10) - kl_teacher = (teacher_probs * (teacher_log_probs - mixture_log_probs)).sum(dim=-1) - kl_student = (student_probs * (student_log_probs - mixture_log_probs)).sum(dim=-1) - return beta * kl_teacher + (1 - beta) * kl_student - - -def _compute_topk_api_loss( - student_logits: torch.Tensor, - teacher_topk_logprobs: torch.Tensor, - teacher_topk_indices: torch.Tensor, - mask: torch.Tensor, - num_valid: torch.Tensor, - beta: float, - temperature: float, -) -> torch.Tensor: - """Compute Top-K JSD loss using pre-computed API logprobs. - - This mode is already memory-efficient since teacher logprobs are pre-computed - and only top-k values are stored. - """ - # Apply temperature to student logits - student_logits_scaled = student_logits / temperature - - # Get teacher probs from log probs - teacher_probs = torch.exp(teacher_topk_logprobs) - - # Gather student logits at teacher's top-k positions - student_topk_logits = torch.gather(student_logits_scaled, dim=-1, index=teacher_topk_indices) - del student_logits_scaled - student_topk_log_probs = F.log_softmax(student_topk_logits, dim=-1) - - # Compute JSD - jsd = _compute_topk_jsd(teacher_probs, teacher_topk_logprobs, student_topk_logits, student_topk_log_probs, beta) - - # Apply mask and compute mean - jsd_masked = jsd * mask.float() - return jsd_masked.sum() / num_valid - - -def _compute_topk_local_loss_chunked( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - mask: torch.Tensor, - num_valid: torch.Tensor, - beta: float, - temperature: float, - topk: int, - chunk_size: int, -) -> torch.Tensor: - """Compute Top-K JSD loss with local teacher using chunked processing. - - Processes the sequence in chunks along the sequence dimension to avoid - keeping full vocab-size tensors in memory simultaneously. 
- """ - seq_len = student_logits.shape[1] - total_loss = student_logits.new_zeros(()) - - for start_idx in range(0, seq_len, chunk_size): - end_idx = min(start_idx + chunk_size, seq_len) - - chunk_mask = mask[:, start_idx:end_idx] - if chunk_mask.sum() == 0: - continue - - # Get logits chunks and apply temperature - student_chunk = student_logits[:, start_idx:end_idx, :] / temperature - teacher_chunk = teacher_logits[:, start_idx:end_idx, :] / temperature - - # Get top-k from teacher chunk, then release teacher chunk - teacher_topk_logits, topk_indices = torch.topk(teacher_chunk, k=topk, dim=-1) - del teacher_chunk - - teacher_probs = F.softmax(teacher_topk_logits, dim=-1) - teacher_log_probs = F.log_softmax(teacher_topk_logits, dim=-1) - del teacher_topk_logits - - # Gather student logits at top-k positions, then release student chunk - student_topk_logits = torch.gather(student_chunk, dim=-1, index=topk_indices) - del student_chunk, topk_indices - - student_log_probs = F.log_softmax(student_topk_logits, dim=-1) - - # Compute JSD and accumulate - jsd = _compute_topk_jsd(teacher_probs, teacher_log_probs, student_topk_logits, student_log_probs, beta) - jsd_masked = jsd * chunk_mask.float() - total_loss = total_loss + jsd_masked.sum() - - del jsd, jsd_masked, student_topk_logits, student_log_probs, teacher_probs, teacher_log_probs - - return total_loss / num_valid - - -def _compute_full_vocab_loss_chunked( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - mask: torch.Tensor, - num_valid: torch.Tensor, - beta: float, - temperature: float, - chunk_size: int, - log_softmax_fn, - kl_div_fn=None, -) -> torch.Tensor: - """Compute full vocabulary JSD loss with chunked processing. - - Supports custom log_softmax and kl_div functions for vocab-parallel computation. 
- """ - # Use default implementations if not provided - if log_softmax_fn is None: - - def log_softmax_fn(x): - return F.log_softmax(x, dim=-1) - - if kl_div_fn is None: - - def kl_div_fn(p, q): - return F.kl_div(p, q, reduction='none', log_target=True) - - # Apply temperature and masking to flatten valid tokens - student_logits_masked = (student_logits / temperature)[mask] - teacher_logits_masked = (teacher_logits / temperature)[mask] - del student_logits, teacher_logits - - num_valid_int = num_valid.item() if isinstance(num_valid, torch.Tensor) else int(num_valid) - total_loss = student_logits_masked.new_zeros(()) - - # Precompute beta tensors if needed - if beta != 0 and beta != 1: - beta_t = torch.tensor(beta, dtype=student_logits_masked.dtype, device=student_logits_masked.device) - log_beta = torch.log(beta_t) - log_1_minus_beta = torch.log1p(-beta_t) - else: - beta_t = log_beta = log_1_minus_beta = None - - for start_idx in range(0, num_valid_int, chunk_size): - end_idx = min(start_idx + chunk_size, num_valid_int) - s_chunk = student_logits_masked[start_idx:end_idx] - t_chunk = teacher_logits_masked[start_idx:end_idx] - - s_log_probs = log_softmax_fn(s_chunk) - t_log_probs = log_softmax_fn(t_chunk) - del s_chunk, t_chunk - - if beta == 0: - jsd_chunk = kl_div_fn(s_log_probs, t_log_probs) - elif beta == 1: - jsd_chunk = kl_div_fn(t_log_probs, s_log_probs) - else: - mixture_log_probs = torch.logsumexp( - torch.stack([s_log_probs + log_1_minus_beta, t_log_probs + log_beta]), - dim=0, - ) - kl_teacher = kl_div_fn(mixture_log_probs, t_log_probs) - kl_student = kl_div_fn(mixture_log_probs, s_log_probs) - del mixture_log_probs - jsd_chunk = beta_t * kl_teacher + (1 - beta_t) * kl_student - del kl_teacher, kl_student - - total_loss = total_loss + jsd_chunk.sum() - del jsd_chunk, s_log_probs, t_log_probs - - del student_logits_masked, teacher_logits_masked - return total_loss / num_valid diff --git a/swift/rlhf_trainers/teacher_api_client.py b/swift/rlhf_trainers/teacher_api_client.py index 37ae6ee992..0f8f53d6b2 100644 --- a/swift/rlhf_trainers/teacher_api_client.py +++ b/swift/rlhf_trainers/teacher_api_client.py @@ -1,14 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Client for fetching teacher model logprobs from OpenAI-compatible endpoints. - -Supports swift deploy (vLLM backend) and standalone vLLM servers. -Used for knowledge distillation (GKD) training with top-k logprobs. -""" +"""Client for fetching teacher model logprobs from OpenAI-compatible endpoints.""" import logging -from typing import Dict, List, Optional, Tuple - import requests import torch +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Tuple logger = logging.getLogger(__name__) @@ -20,57 +16,34 @@ class TeacherAPIClient: base_url: Server URL (e.g., 'http://localhost:8000'). top_logprobs: Number of top log probabilities per token. timeout: Request timeout in seconds. - api_key: Optional API key for authentication. - model_name: Model name for API requests. Auto-detected if None. 
""" - def __init__( - self, - base_url: str, - top_logprobs: int = 20, - timeout: float = 300.0, - api_key: Optional[str] = None, - model_name: Optional[str] = None, - ): + def __init__(self, base_url: str, top_logprobs: int = 20, timeout: float = 300.0): self.base_url = base_url.rstrip('/') self.top_logprobs = top_logprobs self.timeout = timeout - self.api_key = api_key - self._model_name = model_name + self._model_name = None @property def model_name(self) -> str: if self._model_name is None: - self._model_name = self._detect_model_name() + try: + resp = requests.get(f'{self.base_url}/v1/models', timeout=10) + if resp.ok and resp.json().get('data'): + self._model_name = resp.json()['data'][0]['id'] + except Exception as e: + logger.warning(f'Failed to detect model name: {e}') + if self._model_name is None: + self._model_name = 'default' return self._model_name - def _headers(self) -> Dict[str, str]: - headers = {'Content-Type': 'application/json'} - if self.api_key: - headers['Authorization'] = f'Bearer {self.api_key}' - return headers - - def _detect_model_name(self) -> str: - try: - resp = requests.get(f'{self.base_url}/v1/models', headers=self._headers(), timeout=10) - if resp.status_code == 200: - data = resp.json() - if data.get('data'): - return data['data'][0]['id'] - except Exception as e: - logger.warning(f'Failed to detect model name: {e}') - return 'default' - - def check_server_health(self, timeout: float = 5.0) -> bool: + def check_health(self, timeout: float = 5.0) -> bool: """Check if the teacher model server is reachable.""" - for endpoint in ['/health', '/v1/models']: - try: - resp = requests.get(f'{self.base_url}{endpoint}', timeout=timeout) - if resp.status_code == 200: - return True - except requests.RequestException: - continue - return False + try: + resp = requests.get(f'{self.base_url}/v1/models', timeout=timeout) + return resp.ok + except requests.RequestException: + return False def get_logprobs_sync( self, @@ -79,74 +52,42 @@ def get_logprobs_sync( ) -> Tuple[torch.Tensor, torch.Tensor]: """Fetch top-k logprobs for a batch of token sequences. - Args: - input_ids: List of token ID sequences. - top_logprobs: Override default top_logprobs. - Returns: (logprobs, indices) tensors of shape [batch, max_seq_len, topk]. 
""" topk = top_logprobs or self.top_logprobs batch_size = len(input_ids) max_seq_len = max(len(ids) for ids in input_ids) - - logprobs_tensor = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32) - indices_tensor = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long) - url = f'{self.base_url}/v1/completions' model = self.model_name - for batch_idx, ids in enumerate(input_ids): + logprobs_out = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32) + indices_out = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long) + + def _fetch_one(batch_idx: int): payload = { 'model': model, - 'prompt': ids, + 'prompt': input_ids[batch_idx], 'max_tokens': 0, 'temperature': 0, 'logprobs': topk, 'echo': True, } try: - resp = requests.post(url, json=payload, headers=self._headers(), timeout=self.timeout) - if resp.status_code != 200: - logger.error(f'API error for sequence {batch_idx}: {resp.status_code} - {resp.text}') - continue - self._parse_into_tensors(resp.json(), batch_idx, logprobs_tensor, indices_tensor, topk) + resp = requests.post(url, json=payload, timeout=self.timeout) + resp.raise_for_status() + top_logprobs_list = resp.json()['choices'][0].get('logprobs', {}).get('top_logprobs', []) + for pos, pos_lp in enumerate(top_logprobs_list): + if pos_lp is None: + continue + sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1])[:topk] + for k, (tid_str, lp) in enumerate(sorted_items): + indices_out[batch_idx, pos, k] = int(tid_str) + logprobs_out[batch_idx, pos, k] = lp except Exception as e: logger.error(f'Failed to get logprobs for sequence {batch_idx}: {e}') - return logprobs_tensor, indices_tensor - - @staticmethod - def _parse_into_tensors( - response: dict, - batch_idx: int, - logprobs_out: torch.Tensor, - indices_out: torch.Tensor, - topk: int, - ) -> None: - """Parse a single completions API response into pre-allocated tensors.""" - choices = response.get('choices', []) - if not choices: - return - logprobs_data = choices[0].get('logprobs') or {} - top_logprobs_list = logprobs_data.get('top_logprobs', []) + with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool: + list(pool.map(_fetch_one, range(batch_size))) - for pos_idx, pos_logprobs in enumerate(top_logprobs_list): - if pos_logprobs is None: - continue - sorted_items = sorted( - pos_logprobs.items(), - key=lambda x: -(x[1] if isinstance(x[1], (int, float)) else - (x[1].get('logprob', float('-inf')) if isinstance(x[1], dict) else float('-inf'))), - )[:topk] - for k_idx, (token_id_str, logprob_val) in enumerate(sorted_items): - try: - indices_out[batch_idx, pos_idx, k_idx] = int(token_id_str) - if isinstance(logprob_val, (int, float)): - logprobs_out[batch_idx, pos_idx, k_idx] = logprob_val - elif isinstance(logprob_val, dict): - logprobs_out[batch_idx, pos_idx, k_idx] = logprob_val.get('logprob', float('-inf')) - elif hasattr(logprob_val, 'logprob'): - logprobs_out[batch_idx, pos_idx, k_idx] = logprob_val.logprob - except (ValueError, TypeError): - continue + return logprobs_out, indices_out diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py index d28e406eb8..2907ef2e0b 100644 --- a/swift/rlhf_trainers/utils.py +++ b/swift/rlhf_trainers/utils.py @@ -1473,13 +1473,7 @@ def check_vllm_version_ge(min_version: str) -> bool: def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60): - """ - Create and initialize TeacherAPIClient for external teacher model service. 
- - Args: - args: Arguments object containing teacher_model_server and gkd_logits_topk - check_health: Whether to check server health after creation (default: True) - timeout: Timeout for health check in seconds (default: 60) + """Create TeacherAPIClient for external teacher model service. Returns: TeacherAPIClient instance or None if teacher_model_server is not set @@ -1499,7 +1493,7 @@ def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60 top_logprobs=gkd_logits_topk, ) if check_health: - teacher_api_client.check_server_health(timeout=timeout) + teacher_api_client.check_health(timeout=timeout) logger.info(f'Teacher API client initialized with top_logprobs={gkd_logits_topk}') return teacher_api_client diff --git a/tests/train/test_teacher_api_client.py b/tests/train/test_teacher_api_client.py deleted file mode 100644 index a1c70760a9..0000000000 --- a/tests/train/test_teacher_api_client.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Test script for TeacherAPIClient with vLLM backend. - -This script tests the TeacherAPIClient's ability to fetch logprobs from: -1. swift deploy with vLLM backend -2. Standalone vLLM server (vllm serve) - -Usage: - python test_teacher_api_client.py # Run all tests - python test_teacher_api_client.py --parse-only # Only test format parsing -""" -import argparse -import os -import time -import multiprocessing - -os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0') - - -def wait_for_server(base_url: str, timeout: int = 120) -> bool: - """Wait for server to be ready.""" - import requests - start_time = time.time() - while time.time() - start_time < timeout: - try: - for endpoint in ['/health', '/v1/models']: - resp = requests.get(f'{base_url}{endpoint}', timeout=5) - if resp.status_code == 200: - print(f'Server is ready at {base_url}') - return True - except Exception: - pass - time.sleep(2) - print(f'Timeout waiting for server at {base_url}') - return False - - -def test_api_client_logprobs(base_url: str): - """Test TeacherAPIClient logprobs fetching.""" - from swift.rlhf_trainers import TeacherAPIClient - from transformers import AutoTokenizer - - print(f'\n{"=" * 60}') - print(f'Testing TeacherAPIClient') - print(f'Base URL: {base_url}') - print('=' * 60) - - # Initialize client - client = TeacherAPIClient( - base_url=base_url, - top_logprobs=10, - timeout=60.0, - ) - - # Check server health - is_healthy = client.check_server_health() - print(f'Server health check: {"OK" if is_healthy else "FAILED"}') - if not is_healthy: - print('Skipping test due to server health check failure') - return False - - # Prepare test input - tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B-Instruct', trust_remote_code=True) - test_text = 'Hello, how are you today?' 
- input_ids = tokenizer.encode(test_text, add_special_tokens=True) - - print(f'\nTest text: "{test_text}"') - print(f'Token IDs: {input_ids}') - print(f'Number of tokens: {len(input_ids)}') - - # Test synchronous API - print('\n--- Testing synchronous get_logprobs_sync ---') - try: - logprobs_tensor, indices_tensor = client.get_logprobs_sync( - input_ids=[input_ids], top_logprobs=5) - - print(f'Logprobs tensor shape: {logprobs_tensor.shape}') - print(f'Indices tensor shape: {indices_tensor.shape}') - - # Check for valid logprobs - valid_count = (logprobs_tensor > float('-inf')).sum().item() - print(f'Valid logprob entries: {valid_count}') - - if valid_count > 0: - print('\nSample logprobs for first position:') - for k in range(min(5, indices_tensor.shape[-1])): - token_id = indices_tensor[0, 0, k].item() - logprob = logprobs_tensor[0, 0, k].item() - if token_id > 0 and logprob > float('-inf'): - token_str = tokenizer.decode([token_id]) - print(f' Top-{k + 1}: token_id={token_id} ("{token_str}"), logprob={logprob:.4f}') - print('\nSync test: PASSED') - return True - else: - print('\nSync test: FAILED (no valid logprobs)') - return False - - except Exception as e: - print(f'Sync test: FAILED with error: {e}') - import traceback - traceback.print_exc() - return False - - -def test_with_swift_deploy_vllm(port: int = 8100): - """Test with swift deploy using vLLM backend.""" - from swift import DeployArguments, deploy_main - - print('\n' + '=' * 60) - print('Starting swift deploy with vLLM backend...') - print('=' * 60) - - mp = multiprocessing.get_context('spawn') - args = DeployArguments( - model='Qwen/Qwen2-0.5B-Instruct', - infer_backend='vllm', - port=port, - verbose=False, - vllm_max_model_len=4096, - ) - - process = mp.Process(target=deploy_main, args=(args, )) - process.start() - - try: - base_url = f'http://localhost:{port}' - if wait_for_server(base_url): - result = test_api_client_logprobs(base_url) - return result - return False - finally: - process.terminate() - process.join(timeout=10) - if process.is_alive(): - process.kill() - - -def test_logprobs_format_parsing(): - """Test parsing of vLLM logprobs response format.""" - print('\n' + '=' * 60) - print('Testing logprobs format parsing') - print('=' * 60) - - from swift.rlhf_trainers import TeacherAPIClient - - client = TeacherAPIClient(base_url='http://localhost:8000', top_logprobs=5) - - # Test vLLM response parsing with token_id keys - vllm_response = { - 'choices': [{ - 'logprobs': { - 'top_logprobs': [ - { - '123': -0.5, - '456': -1.2, - '789': -2.0 - }, - { - '44': -0.1, - '55': -2.5, - '66': -3.0 - }, - ] - } - }] - } - - result = client._parse_response(vllm_response, seq_len=2, topk=3) - print(f'Parsing result indices: {result["indices"]}') - print(f'Parsing result values: {result["values"]}') - assert len(result['values']) == 2, 'Expected 2 positions' - assert len(result['values'][0]) == 3, 'Expected 3 top logprobs per position' - assert result['indices'][0][0] == 123, f'Expected token ID 123, got {result["indices"][0][0]}' - print('Format parsing: PASSED') - - return True - - -def main(): - parser = argparse.ArgumentParser(description='Test TeacherAPIClient') - parser.add_argument('--parse-only', action='store_true', help='Only test format parsing (no server needed)') - args = parser.parse_args() - - results = {} - - # Test format parsing (no server needed) - print('\n' + '#' * 60) - print('# Testing format parsing') - print('#' * 60) - try: - results['format_parsing'] = test_logprobs_format_parsing() - except Exception 
as e: - print(f'Format parsing test failed: {e}') - import traceback - traceback.print_exc() - results['format_parsing'] = False - - if args.parse_only: - print('\n' + '=' * 60) - print('Test Summary (parse-only mode):') - print('=' * 60) - for test, passed in results.items(): - print(f' {test}: {"PASSED" if passed else "FAILED"}') - return - - # Test with swift deploy - print('\n' + '#' * 60) - print('# Testing with vLLM backend') - print('#' * 60) - try: - results['vllm'] = test_with_swift_deploy_vllm() - except Exception as e: - print(f'vLLM test failed: {e}') - import traceback - traceback.print_exc() - results['vllm'] = False - - # Print summary - print('\n' + '=' * 60) - print('Test Summary:') - print('=' * 60) - for test, passed in results.items(): - print(f' {test}: {"PASSED" if passed else "FAILED"}') - - all_passed = all(results.values()) - print(f'\nOverall: {"ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED"}') - return all_passed - - -if __name__ == '__main__': - main() From fc1b673e11591b4e6398fb58d46bb85780981b4b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 2 Mar 2026 14:02:30 +0800 Subject: [PATCH 06/10] clean --- swift/megatron/arguments/megatron_args.py | 9 +- swift/megatron/pipelines/train/rlhf.py | 15 --- swift/megatron/trainers/gkd_trainer.py | 9 +- swift/pipelines/train/rlhf.py | 9 -- swift/rlhf_trainers/__init__.py | 4 +- swift/rlhf_trainers/gkd_trainer.py | 10 +- swift/rlhf_trainers/teacher_api_client.py | 132 ++++++++++------------ swift/rlhf_trainers/utils.py | 26 ----- 8 files changed, 71 insertions(+), 143 deletions(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index e880592ef4..033352bbdd 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -52,14 +52,7 @@ class RLHFMegatronArgumentsMixin: 'URL of the teacher model server (e.g., http://localhost:8000). ' 'When set, teacher logprobs are fetched via API instead of loading a local model.' }) - gkd_logits_topk: Optional[int] = field( - default=None, - metadata={ - 'help': - 'Number of top-k logits for KL computation in GKD. ' - 'None = full vocabulary, positive integer = top-k only. ' - 'When using teacher_model_server, limited by server max_logprobs (vLLM default: 20).' - }) + gkd_logits_topk: Optional[int] = None lmbda: float = 0.5 # On-policy probability: with prob lmbda, use student-generated responses seq_kd: bool = False # Sequential KD: use teacher-generated responses when not on-policy offload_teacher_model: bool = False # Offload teacher model to CPU to save GPU memory diff --git a/swift/megatron/pipelines/train/rlhf.py b/swift/megatron/pipelines/train/rlhf.py index 6ff5c383c1..05e3aef570 100644 --- a/swift/megatron/pipelines/train/rlhf.py +++ b/swift/megatron/pipelines/train/rlhf.py @@ -31,8 +31,6 @@ def prepare_trainer(self): kwargs = {} if args.rlhf_type in ('grpo', 'gkd'): kwargs['vllm_client'] = self._prepare_vllm_client() - if args.rlhf_type == 'gkd': - kwargs['teacher_api_client'] = self._prepare_teacher_api_client() return trainer_cls(args, self.template, **kwargs) def _prepare_template(self) -> None: @@ -70,19 +68,6 @@ def _prepare_vllm_client(self): logger.info('Connected to vLLM server') return vllm_client - def _prepare_teacher_api_client(self): - """Prepare teacher API client for external teacher model service. - - In Megatron with pure Data Parallel (TP=PP=CP=1), each rank processes different data - and needs its own API client. 
With model parallelism (TP/PP/CP > 1), one rank per - model parallel group calls the API and broadcasts results. - """ - from swift.rlhf_trainers.utils import create_teacher_api_client - from swift.utils import is_last_rank - if is_last_rank(): - return create_teacher_api_client(self.args, check_health=True, timeout=60) - return None - def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index d17862ab56..757f6e5d31 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -34,7 +34,6 @@ class MegatronGKDTrainer(MegatronRolloutMixin, MegatronRLHFTrainer): def __init__(self, args: MegatronArguments, template, **kwargs): self.vllm_client = kwargs.pop('vllm_client', None) - self.teacher_api_client = kwargs.pop('teacher_api_client', None) # GKD-specific parameters self.beta = args.beta # JSD interpolation coefficient @@ -50,7 +49,8 @@ def __init__(self, args: MegatronArguments, template, **kwargs): self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) # Check use_teacher_api based on args, not client existence # (API client is only created on last rank, but all ranks need to know the mode) - self.use_teacher_api = getattr(args, 'teacher_model_server', None) is not None + self.teacher_model_server = getattr(args, 'teacher_model_server', None) + self.use_teacher_api = self.teacher_model_server is not None # Validate teacher configuration if not self.use_teacher_api: @@ -295,11 +295,12 @@ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: O def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None: """Fetch teacher logprobs from external API service.""" + from swift.rlhf_trainers.teacher_api_client import fetch_teacher_logprobs topk = self.gkd_logits_topk for encoded_batch in encoded_batches: input_ids = encoded_batch['input_ids'] - teacher_logprobs, teacher_indices = self.teacher_api_client.get_logprobs_sync( - input_ids=input_ids.tolist(), top_logprobs=topk) + teacher_logprobs, teacher_indices = fetch_teacher_logprobs( + self.teacher_model_server, input_ids.tolist(), topk=topk) encoded_batch['teacher_api_logprobs'] = teacher_logprobs.to(input_ids.device) encoded_batch['teacher_api_indices'] = teacher_indices.to(input_ids.device) encoded_batch['teacher_logits'] = None diff --git a/swift/pipelines/train/rlhf.py b/swift/pipelines/train/rlhf.py index 6f1d8d9ada..fb020f5df8 100644 --- a/swift/pipelines/train/rlhf.py +++ b/swift/pipelines/train/rlhf.py @@ -233,18 +233,9 @@ def _get_trainer_kwargs(self): if self.args.rlhf_type == 'gkd': if self.args.teacher_deepspeed: trainer_kwargs['teacher_deepspeed_config'] = self.args.teacher_deepspeed - # Pass GKD-specific args to trainer trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk - # Initialize teacher API client if using external teacher service if self.args.teacher_model_server: - # Pass teacher_model_server so trainer knows to use API mode on all ranks trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server - from swift.rlhf_trainers.utils import create_teacher_api_client - - # In DP mode (DeepSpeed/FSDP), each rank has different data and needs its own client - # Use all_ranks=True so every rank can independently fetch teacher logprobs - trainer_kwargs['teacher_api_client'] = create_teacher_api_client( - self.args, check_health=False, timeout=60) return trainer_kwargs 
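The top-k JSD shared by these trainers is compact enough to sanity-check in isolation. The snippet below is a minimal, self-contained sketch and not part of the patch (`topk_jsd` and all shapes are illustrative names): it gathers the student's logits at the teacher's top-k indices, renormalizes both distributions over that k-token subset with `log_softmax`, forms the mixture `M = beta * P_T + (1 - beta) * P_S` in probability space, and interpolates the two KL terms, mirroring `_jsd_topk` above (the real code additionally special-cases `beta` values of 0 and 1).

```python
import torch
import torch.nn.functional as F


def topk_jsd(student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta=0.5):
    """Illustrative top-k generalized JSD; mask marks valid (non -100) positions."""
    # Student logits at the teacher's top-k token ids, renormalized over the subset.
    s_log_p = F.log_softmax(torch.gather(student_logits, -1, teacher_topk_indices), dim=-1)
    # Teacher top-k logprobs are renormalized the same way.
    t_log_p = F.log_softmax(teacher_topk_logprobs, dim=-1)
    t_p, s_p = t_log_p.exp(), s_log_p.exp()
    # Mixture M = beta * P_T + (1 - beta) * P_S; the epsilon guards log(0).
    m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10)
    jsd = beta * (t_p * (t_log_p - m_log_p)).sum(-1) + (1 - beta) * (s_p * (s_log_p - m_log_p)).sum(-1)
    return (jsd * mask.float()).sum() / mask.sum()


# Sanity check: when student == teacher, the divergence should be ~0.
logits = torch.randn(2, 5, 100)
topk_vals, topk_idx = torch.topk(F.log_softmax(logits, dim=-1), k=8, dim=-1)
mask = torch.ones(2, 5, dtype=torch.bool)
assert topk_jsd(logits, topk_vals, topk_idx, mask).abs().item() < 1e-5
```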
diff --git a/swift/rlhf_trainers/__init__.py b/swift/rlhf_trainers/__init__.py index aec0e9438c..87a05f1831 100644 --- a/swift/rlhf_trainers/__init__.py +++ b/swift/rlhf_trainers/__init__.py @@ -15,7 +15,7 @@ from .ppo_trainer import PPOTrainer from .reward_trainer import RewardTrainer from .rlhf_mixin import RLHFTrainerMixin - from .teacher_api_client import TeacherAPIClient + from .teacher_api_client import fetch_teacher_logprobs from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, round_robin from .vllm_client import VLLMClient else: @@ -32,7 +32,7 @@ 'args_mixin': ['VllmArguments', 'GRPOArgumentsMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], 'vllm_client': ['VLLMClient'], - 'teacher_api_client': ['TeacherAPIClient'], + 'teacher_api_client': ['fetch_teacher_logprobs'], 'arguments': ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig', 'GKDConfig'] } diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index d31a7f5609..576e00a1aa 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -57,7 +57,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non teacher_model = kwargs.pop('teacher_model', None) teacher_deepspeed_config = kwargs.pop('teacher_deepspeed_config', None) self.vllm_client = kwargs.pop('vllm_client', None) - self.teacher_api_client = kwargs.pop('teacher_api_client', None) self.gkd_logits_topk = kwargs.pop('gkd_logits_topk', None) teacher_model_server = kwargs.pop('teacher_model_server', None) super().__init__(model, None, *_args, **kwargs) @@ -69,6 +68,7 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} self._total_train_tokens = 0 + self.teacher_model_server = teacher_model_server self.use_teacher_api = teacher_model_server is not None # Initialize logging components @@ -469,12 +469,10 @@ def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tenso Returns: Tuple of (teacher_logprobs, teacher_indices) tensors with shapes [batch, seq_len, topk] """ + from .teacher_api_client import fetch_teacher_logprobs input_ids = encoded_inputs['input_ids'] - topk = self.gkd_logits_topk - teacher_logprobs, teacher_indices = self.teacher_api_client.get_logprobs_sync( - input_ids=input_ids.tolist(), - top_logprobs=topk, - ) + teacher_logprobs, teacher_indices = fetch_teacher_logprobs( + self.teacher_model_server, input_ids.tolist(), topk=self.gkd_logits_topk) return teacher_logprobs.to(input_ids.device), teacher_indices.to(input_ids.device) def prediction_step(self, model, inputs, *args, **kwargs): diff --git a/swift/rlhf_trainers/teacher_api_client.py b/swift/rlhf_trainers/teacher_api_client.py index 0f8f53d6b2..22d5c5c24b 100644 --- a/swift/rlhf_trainers/teacher_api_client.py +++ b/swift/rlhf_trainers/teacher_api_client.py @@ -1,5 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Client for fetching teacher model logprobs from OpenAI-compatible endpoints.""" +"""Fetch teacher model logprobs from OpenAI-compatible endpoints.""" import logging import requests import torch @@ -8,86 +8,72 @@ logger = logging.getLogger(__name__) +_model_name_cache: dict = {} -class TeacherAPIClient: - """Fetch teacher top-k logprobs from an OpenAI-compatible completions API. 
+ +def _get_model_name(base_url: str) -> str: + if base_url not in _model_name_cache: + try: + resp = requests.get(f'{base_url}/v1/models', timeout=10) + if resp.ok and resp.json().get('data'): + _model_name_cache[base_url] = resp.json()['data'][0]['id'] + except Exception as e: + logger.warning(f'Failed to detect model name: {e}') + if base_url not in _model_name_cache: + _model_name_cache[base_url] = 'default' + return _model_name_cache[base_url] + + +def fetch_teacher_logprobs( + base_url: str, + input_ids: List[List[int]], + topk: int = 20, + timeout: float = 300.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fetch top-k logprobs from an OpenAI-compatible completions API. Args: base_url: Server URL (e.g., 'http://localhost:8000'). - top_logprobs: Number of top log probabilities per token. + input_ids: List of token ID sequences. + topk: Number of top log probabilities per token. timeout: Request timeout in seconds. - """ - def __init__(self, base_url: str, top_logprobs: int = 20, timeout: float = 300.0): - self.base_url = base_url.rstrip('/') - self.top_logprobs = top_logprobs - self.timeout = timeout - self._model_name = None + Returns: + (logprobs, indices) tensors of shape [batch, max_seq_len, topk]. + """ + base_url = base_url.rstrip('/') + batch_size = len(input_ids) + max_seq_len = max(len(ids) for ids in input_ids) + url = f'{base_url}/v1/completions' + model = _get_model_name(base_url) - @property - def model_name(self) -> str: - if self._model_name is None: - try: - resp = requests.get(f'{self.base_url}/v1/models', timeout=10) - if resp.ok and resp.json().get('data'): - self._model_name = resp.json()['data'][0]['id'] - except Exception as e: - logger.warning(f'Failed to detect model name: {e}') - if self._model_name is None: - self._model_name = 'default' - return self._model_name + logprobs_out = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32) + indices_out = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long) - def check_health(self, timeout: float = 5.0) -> bool: - """Check if the teacher model server is reachable.""" + def _fetch_one(batch_idx: int): + payload = { + 'model': model, + 'prompt': input_ids[batch_idx], + 'max_tokens': 0, + 'temperature': 0, + 'logprobs': topk, + 'echo': True, + } try: - resp = requests.get(f'{self.base_url}/v1/models', timeout=timeout) - return resp.ok - except requests.RequestException: - return False - - def get_logprobs_sync( - self, - input_ids: List[List[int]], - top_logprobs: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Fetch top-k logprobs for a batch of token sequences. - - Returns: - (logprobs, indices) tensors of shape [batch, max_seq_len, topk]. 
- """ - topk = top_logprobs or self.top_logprobs - batch_size = len(input_ids) - max_seq_len = max(len(ids) for ids in input_ids) - url = f'{self.base_url}/v1/completions' - model = self.model_name - - logprobs_out = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32) - indices_out = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long) - - def _fetch_one(batch_idx: int): - payload = { - 'model': model, - 'prompt': input_ids[batch_idx], - 'max_tokens': 0, - 'temperature': 0, - 'logprobs': topk, - 'echo': True, - } - try: - resp = requests.post(url, json=payload, timeout=self.timeout) - resp.raise_for_status() - top_logprobs_list = resp.json()['choices'][0].get('logprobs', {}).get('top_logprobs', []) - for pos, pos_lp in enumerate(top_logprobs_list): - if pos_lp is None: - continue - sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1])[:topk] - for k, (tid_str, lp) in enumerate(sorted_items): - indices_out[batch_idx, pos, k] = int(tid_str) - logprobs_out[batch_idx, pos, k] = lp - except Exception as e: - logger.error(f'Failed to get logprobs for sequence {batch_idx}: {e}') + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + top_logprobs_list = resp.json()['choices'][0].get('logprobs', {}).get('top_logprobs', []) + for pos, pos_lp in enumerate(top_logprobs_list): + if pos_lp is None: + continue + sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1])[:topk] + for k, (tid_str, lp) in enumerate(sorted_items): + indices_out[batch_idx, pos, k] = int(tid_str) + logprobs_out[batch_idx, pos, k] = lp + except Exception as e: + logger.error(f'Failed to get logprobs for sequence {batch_idx}: {e}') - with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool: - list(pool.map(_fetch_one, range(batch_size))) + with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool: + list(pool.map(_fetch_one, range(batch_size))) - return logprobs_out, indices_out + return logprobs_out, indices_out diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py index 2907ef2e0b..2484df4e77 100644 --- a/swift/rlhf_trainers/utils.py +++ b/swift/rlhf_trainers/utils.py @@ -1472,32 +1472,6 @@ def check_vllm_version_ge(min_version: str) -> bool: return version.parse(vllm_version) >= version.parse(min_version) -def create_teacher_api_client(args, check_health: bool = True, timeout: int = 60): - """Create TeacherAPIClient for external teacher model service. 
- - Returns: - TeacherAPIClient instance or None if teacher_model_server is not set - """ - teacher_model_server = getattr(args, 'teacher_model_server', None) - if not teacher_model_server: - return None - - from swift.rlhf_trainers import TeacherAPIClient - - logger = get_logger() - gkd_logits_topk = getattr(args, 'gkd_logits_topk', None) or 20 - - logger.info(f'Initializing teacher API client for {teacher_model_server}') - teacher_api_client = TeacherAPIClient( - base_url=teacher_model_server, - top_logprobs=gkd_logits_topk, - ) - if check_health: - teacher_api_client.check_health(timeout=timeout) - logger.info(f'Teacher API client initialized with top_logprobs={gkd_logits_topk}') - return teacher_api_client - - # ============================================================================ # Padding-free utilities # ============================================================================ From fd351402549fd0d6d514239da572063175b00e34 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 2 Mar 2026 22:53:17 +0800 Subject: [PATCH 07/10] update --- docs/source/Instruction/GKD.md | 25 ++-- docs/source/Megatron-SWIFT/GKD.md | 2 +- docs/source_en/Instruction/GKD.md | 22 +-- docs/source_en/Megatron-SWIFT/GKD.md | 2 +- examples/megatron/rlhf/gkd/teacher_server.sh | 5 + .../train/rlhf/gkd/gsm8k_teacher_server.sh | 77 ++++++++++ examples/train/rlhf/gkd/teacher_server.sh | 31 ++-- swift/megatron/trainers/gkd_trainer.py | 19 ++- swift/rlhf_trainers/__init__.py | 2 - swift/rlhf_trainers/gkd_trainer.py | 140 ++++++++++++++---- swift/rlhf_trainers/teacher_api_client.py | 79 ---------- 11 files changed, 249 insertions(+), 155 deletions(-) create mode 100644 examples/train/rlhf/gkd/gsm8k_teacher_server.sh delete mode 100644 swift/rlhf_trainers/teacher_api_client.py diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md index 8859928faf..b276bd3860 100644 --- a/docs/source/Instruction/GKD.md +++ b/docs/source/Instruction/GKD.md @@ -196,21 +196,18 @@ swift rlhf \ | `--gkd_logits_topk` | int | **必需** | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 | **支持的后端**: -- `swift deploy`(vLLM backend) -- 独立 vLLM 服务(`vllm serve`) +- `vllm serve`(推荐) + +> **注意**:仅支持 `vllm serve` 作为教师服务后端。训练代码通过 `/v1/completions` 接口直接传递 token IDs 并使用 `prompt_logprobs` 参数获取输入 token 的 log 概率,这是 vLLM 原生支持的功能。 **步骤 1:部署教师模型服务** ```bash -# 使用 swift deploy 部署教师模型 -CUDA_VISIBLE_DEVICES=0,1 swift deploy \ - --model Qwen/Qwen2-72B-Instruct \ - --infer_backend vllm \ +# 使用 vllm serve 部署教师模型 +CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \ --port 8000 \ - --vllm_engine_kwargs '{"max_logprobs": 64}' - -# 或使用独立 vLLM 服务 -vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 + --max-logprobs 64 \ + --gpu-memory-utilization 0.9 ``` **步骤 2:启动 GKD 训练** @@ -218,17 +215,17 @@ vllm serve Qwen/Qwen2-72B-Instruct --max-logprobs 64 --port 8000 ```bash swift rlhf \ --rlhf_type gkd \ - --model Qwen/Qwen2-7B-Instruct \ + --model Qwen/Qwen2.5-7B \ --teacher_model_server http://localhost:8000 \ - --gkd_logits_topk 20 \ + --gkd_logits_topk 64 \ --dataset your_dataset \ --lmbda 1.0 \ - --beta 0.5 \ + --beta 1.0 \ ... 
``` > **vLLM max_logprobs 限制**: -> - vLLM 默认 `max_logprobs=20`,可通过 `--vllm_engine_kwargs '{"max_logprobs": N}'` 参数调整 +> - vLLM 默认 `max_logprobs=20`,可通过 `--max-logprobs N` 参数调整 > - `gkd_logits_topk` 不能超过服务端的 `max_logprobs` 设置 ## 采样加速 diff --git a/docs/source/Megatron-SWIFT/GKD.md b/docs/source/Megatron-SWIFT/GKD.md index 37b810624e..9a023cb456 100644 --- a/docs/source/Megatron-SWIFT/GKD.md +++ b/docs/source/Megatron-SWIFT/GKD.md @@ -34,7 +34,7 @@ Megatron GKD 当前已支持以下功能: | 参数 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `--teacher_model` | str | - | 教师模型路径或模型 ID
*使用 `teacher_model_server` 时可省略 | -| `--teacher_model_server` | str | None | 教师模型服务地址,如 `http://localhost:8000` | +| `--teacher_model_server` | str | None | 教师模型服务地址(仅支持 `vllm serve`),如 `http://localhost:8000` | | `--gkd_logits_topk` | int | None | Top-K logits 数量,使用外部教师 API 时必须设置 | | `--beta` | float | 0.5 | JSD 散度插值系数:
• 0.0: Forward KL
• 0.5: 对称 JSD
• 1.0: Reverse KL | | `--lmbda` | float | 0.5 | On-Policy 学习触发概率:
• 0.0: 纯 Off-Policy
• 1.0: 纯 On-Policy | diff --git a/docs/source_en/Instruction/GKD.md b/docs/source_en/Instruction/GKD.md index d21bc5b04c..cad177cae8 100644 --- a/docs/source_en/Instruction/GKD.md +++ b/docs/source_en/Instruction/GKD.md @@ -197,18 +197,18 @@ When `gkd_logits_topk` is set, you can use an external teacher model API service | `--gkd_logits_topk` | int | **Required** | Must be set when using external API; corresponds to the top_logprobs returned by the API | **Supported Backends**: -- `swift deploy` (vLLM backend) -- Standalone vLLM server (`vllm serve`) +- `vllm serve` (recommended) + +> **Note**: Only `vllm serve` is supported as the teacher server backend. The training code sends raw token IDs via the `prompt` field and uses the `prompt_logprobs` parameter in the `/v1/completions` API to obtain input token log-probabilities. This is a vLLM-native feature. **Step 1: Deploy Teacher Model Service** ```bash -# Deploy teacher model with swift deploy (recommended) -swift deploy \ - --model Qwen/Qwen2.5-14B-Instruct \ - --infer_backend vllm \ +# Deploy teacher model with vllm serve +CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \ --port 8000 \ - --vllm_engine_kwargs '{"max_logprobs": 64}' + --max-logprobs 64 \ + --gpu-memory-utilization 0.9 ``` **Step 2: Start GKD Training** @@ -216,17 +216,17 @@ swift deploy \ ```bash swift rlhf \ --rlhf_type gkd \ - --model Qwen/Qwen2.5-7B-Instruct \ + --model Qwen/Qwen2.5-7B \ --teacher_model_server http://localhost:8000 \ - --gkd_logits_topk 20 \ + --gkd_logits_topk 64 \ --dataset your_dataset \ --lmbda 1.0 \ - --beta 0.5 \ + --beta 1.0 \ ... ``` > **vLLM max_logprobs Limitation**: -> - vLLM default `max_logprobs=20`, adjustable via `--vllm_engine_kwargs '{"max_logprobs": N}'` parameter +> - vLLM default `max_logprobs=20`, adjustable via `--max-logprobs N` parameter > - `gkd_logits_topk` cannot exceed the server's `max_logprobs` setting ## Sampling Acceleration diff --git a/docs/source_en/Megatron-SWIFT/GKD.md b/docs/source_en/Megatron-SWIFT/GKD.md index b7eb6f6578..0502b8485b 100644 --- a/docs/source_en/Megatron-SWIFT/GKD.md +++ b/docs/source_en/Megatron-SWIFT/GKD.md @@ -34,7 +34,7 @@ Megatron GKD currently supports the following features: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `--teacher_model` | str | - | Path or model ID of the teacher model
diff --git a/docs/source_en/Megatron-SWIFT/GKD.md b/docs/source_en/Megatron-SWIFT/GKD.md
index b7eb6f6578..0502b8485b 100644
--- a/docs/source_en/Megatron-SWIFT/GKD.md
+++ b/docs/source_en/Megatron-SWIFT/GKD.md
@@ -34,7 +34,7 @@ Megatron GKD currently supports the following features:
 | Parameter | Type | Default | Description |
 |-----------|------|---------|-------------|
 | `--teacher_model` | str | - | Path or model ID of the teacher model<br>*Can be omitted when using `teacher_model_server` |
-| `--teacher_model_server` | str | None | Teacher model service URL, e.g. `http://localhost:8000` |
+| `--teacher_model_server` | str | None | Teacher model service URL (`vllm serve` only), e.g. `http://localhost:8000` |
 | `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using external API |
 | `--beta` | float | 0.5 | JSD divergence interpolation coefficient:<br>• 0.0: Forward KL<br>• 0.5: Symmetric JSD<br>• 1.0: Reverse KL |
 | `--lmbda` | float | 0.5 | On-Policy learning probability:<br>• 0.0: Pure Off-Policy<br>• 1.0: Pure On-Policy |
diff --git a/examples/megatron/rlhf/gkd/teacher_server.sh b/examples/megatron/rlhf/gkd/teacher_server.sh
index 53a2a964d1..72a82ba9b4 100644
--- a/examples/megatron/rlhf/gkd/teacher_server.sh
+++ b/examples/megatron/rlhf/gkd/teacher_server.sh
@@ -1,3 +1,8 @@
+# GKD Training with External Teacher Model Server (Megatron)
+#
+# Start teacher server first (in a separate terminal / GPU):
+#   CUDA_VISIBLE_DEVICES=4 vllm serve Qwen/Qwen3-8B --port 8000 --max-logprobs 64
+
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 NPROC_PER_NODE=4 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
diff --git a/examples/train/rlhf/gkd/gsm8k_teacher_server.sh b/examples/train/rlhf/gkd/gsm8k_teacher_server.sh
new file mode 100644
index 0000000000..3c848b8814
--- /dev/null
+++ b/examples/train/rlhf/gkd/gsm8k_teacher_server.sh
@@ -0,0 +1,77 @@
+# GKD on GSM8K: Teacher Server Mode with Top-K Logits
+#
+# This script validates GKD effectiveness on mathematical reasoning using GSM8K.
+# Student: Qwen2.5-1.5B-Instruct, Teacher: Qwen2.5-7B-Instruct (via vllm serve)
+#
+# Expected outcome: GSM8K accuracy should improve after GKD training, as the student
+# learns the teacher's reasoning distribution on math problems.
+#
+# ===================== Step 1: Start Teacher Server =====================
+# Run in a separate terminal / GPU:
+#
+#   CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
+#     --port 8000 \
+#     --max-logprobs 64 \
+#     --gpu-memory-utilization 0.9
+#
+# Wait until the server is ready, then verify:
+#   curl http://localhost:8000/v1/models
+# ========================================================================
+#
+# ===================== Step 2: Prepare GSM8K Dataset =====================
+# The dataset uses the standard GSM8K train split from Hugging Face:
+#   openai/gsm8k (7473 training samples)
+# Swift will auto-download it via the HuggingFace dataset name.
+
+# ========================================================================
+#
+# ===================== Step 3: Evaluation =====================
+# After training, evaluate on GSM8K test set:
+#
+#   CUDA_VISIBLE_DEVICES=0 swift eval \
+#     --model /checkpoint-xxx \
+#     --eval_backend OpenCompass \
+#     --infer_backend vllm \
+#     --eval_dataset gsm8k
+#
+# Compare with the base model to verify improvement:
+#   CUDA_VISIBLE_DEVICES=0 swift eval \
+#     --model Qwen/Qwen2.5-1.5B-Instruct \
+#     --eval_backend OpenCompass \
+#     --infer_backend vllm \
+#     --eval_dataset gsm8k
+# ========================================================================
+
+TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
+GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64}
+
+CUDA_VISIBLE_DEVICES=1 \
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+swift rlhf \
+    --rlhf_type gkd \
+    --model Qwen/Qwen2.5-1.5B-Instruct \
+    --teacher_model_server $TEACHER_SERVER_URL \
+    --gkd_logits_topk $GKD_LOGITS_TOPK \
+    --tuner_type lora \
+    --lora_rank 64 \
+    --lora_alpha 128 \
+    --dataset 'openai/gsm8k#train' \
+    --seq_kd false \
+    --lmbda 0 \
+    --beta 0.5 \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 2 \
+    --per_device_eval_batch_size 2 \
+    --learning_rate 5e-5 \
+    --gradient_accumulation_steps 8 \
+    --eval_steps 200 \
+    --save_steps 200 \
+    --save_total_limit 3 \
+    --logging_steps 5 \
+    --max_length 1024 \
+    --warmup_ratio 0.05 \
+    --save_only_model true \
+    --dataloader_num_workers 4 \
+    --dataset_num_proc 4 \
+    --deepspeed zero2 \
+    --attn_impl flash_attn
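The scripts above tell the user to poll the server with curl before launching training; the same readiness check can be scripted. A small sketch (the URL is the default from these examples; error handling is kept minimal):

```python
import time
import requests

def wait_for_teacher(base_url: str = 'http://localhost:8000', retries: int = 60) -> str:
    """Poll /v1/models until the vLLM server answers, then return the served model name."""
    for _ in range(retries):
        try:
            resp = requests.get(f'{base_url}/v1/models', timeout=5)
            if resp.ok:
                model_id = resp.json()['data'][0]['id']
                print(f'teacher server ready, serving: {model_id}')
                return model_id
        except requests.ConnectionError:
            pass  # server not up yet
        time.sleep(5)
    raise RuntimeError(f'teacher server at {base_url} did not come up')

if __name__ == '__main__':
    wait_for_teacher()
```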
diff --git a/examples/train/rlhf/gkd/teacher_server.sh b/examples/train/rlhf/gkd/teacher_server.sh
index 94183a2504..04b2691c1c 100644
--- a/examples/train/rlhf/gkd/teacher_server.sh
+++ b/examples/train/rlhf/gkd/teacher_server.sh
@@ -1,21 +1,32 @@
-# GKD Training with External Teacher Model Server
+# GKD Training with External Teacher Model Server (vLLM)
 #
 # This script demonstrates using an external vLLM server as the teacher model
-# for knowledge distillation.
+# for knowledge distillation. The teacher server provides prompt_logprobs via
+# the /v1/completions endpoint, which requires native vLLM serving (vllm serve).
+#
+# NOTE: Only `vllm serve` is supported as the teacher server backend, because
+# the training code sends raw token IDs via the `prompt` field and uses the
+# `prompt_logprobs` parameter in the /v1/completions API. This is a vLLM-native
+# feature not available through swift deploy.
 
-# Teacher Server Setup (run in a separate gpu):
-#   CUDA_VISIBLE_DEVICES=5 swift deploy \
-#     --model Qwen/Qwen2.5-14B-Instruct \
-#     --infer_backend vllm \
-#     --port 8000 \
-#     --vllm_engine_kwargs '{"max_logprobs": 64}'
+# ===================== Step 1: Start Teacher Server =====================
+# Run in a separate terminal / GPU:
+#
+#   CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \
+#     --port 8000 \
+#     --max-logprobs 64 \
+#     --gpu-memory-utilization 0.9
+#
+# Wait until the server is ready (shows "Uvicorn running on ...").
+# Verify with: curl http://localhost:8000/v1/models
+# ========================================================================
 
-TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}
+TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
 GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64}
 
 NPROC_PER_NODE=4 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+CUDA_VISIBLE_DEVICES=1,2,3,4 \
 swift rlhf \
     --rlhf_type gkd \
     --model Qwen/Qwen2.5-7B \
diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py
index 757f6e5d31..1d06131b34 100644
--- a/swift/megatron/trainers/gkd_trainer.py
+++ b/swift/megatron/trainers/gkd_trainer.py
@@ -285,8 +285,7 @@ def _compute_teacher_logits_local(self, encoded_batches: List[Dict], vp_stage: O
             teacher_logits = teacher_logits.detach()
 
         if topk is not None and teacher_logits is not None:
-            scaled = teacher_logits / self.temperature
-            topk_logits, topk_indices = torch.topk(scaled, k=topk, dim=-1)
+            topk_logits, topk_indices = torch.topk(teacher_logits, k=topk, dim=-1)
             encoded_batch['teacher_api_logprobs'] = topk_logits
             encoded_batch['teacher_api_indices'] = topk_indices
             encoded_batch['teacher_logits'] = None
@@ -295,12 +294,16 @@ def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
     def _compute_teacher_logits_from_api(self, encoded_batches: List[Dict]) -> None:
         """Fetch teacher logprobs from external API service."""
-        from swift.rlhf_trainers.teacher_api_client import fetch_teacher_logprobs
+        from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs
         topk = self.gkd_logits_topk
         for encoded_batch in encoded_batches:
             input_ids = encoded_batch['input_ids']
             teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
                 self.teacher_model_server, input_ids.tolist(), topk=topk)
+            # fetch_teacher_logprobs returns [batch, seq_len-1, topk] (shifted).
+            # Pad last position with -inf to match student [batch, seq_len, topk].
+            teacher_logprobs = F.pad(teacher_logprobs, (0, 0, 0, 1), value=float('-inf'))
+            teacher_indices = F.pad(teacher_indices, (0, 0, 0, 1), value=0)
             encoded_batch['teacher_api_logprobs'] = teacher_logprobs.to(input_ids.device)
             encoded_batch['teacher_api_indices'] = teacher_indices.to(input_ids.device)
             encoded_batch['teacher_logits'] = None
@@ -474,14 +477,14 @@ def generalized_jsd_loss(
     def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices, mask, beta):
         """Compute JSD on teacher's top-k distribution.
 
-        Handles both local top-k (raw logits) and API top-k (raw logprobs) by
-        normalizing both teacher and student over the top-k subset via log_softmax.
+        Both local and API teacher are handled uniformly: gather student logits at
+        teacher's top-k indices, scale by 1/T, and log_softmax over top-k subset.
+        By shift-invariance of log_softmax, this gives identical results whether
+        teacher_topk_logprobs contains raw logits (local) or raw logprobs (API).
""" s_scaled = student_logits / self.temperature s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices) - - # Normalize both over top-k subset (handles both raw logits and API logprobs) - t_log_p = F.log_softmax(teacher_topk_logprobs, dim=-1) + t_log_p = F.log_softmax(teacher_topk_logprobs / self.temperature, dim=-1) s_log_p = F.log_softmax(s_topk, dim=-1) t_p = torch.exp(t_log_p) diff --git a/swift/rlhf_trainers/__init__.py b/swift/rlhf_trainers/__init__.py index 87a05f1831..262fb99e7e 100644 --- a/swift/rlhf_trainers/__init__.py +++ b/swift/rlhf_trainers/__init__.py @@ -15,7 +15,6 @@ from .ppo_trainer import PPOTrainer from .reward_trainer import RewardTrainer from .rlhf_mixin import RLHFTrainerMixin - from .teacher_api_client import fetch_teacher_logprobs from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, round_robin from .vllm_client import VLLMClient else: @@ -32,7 +31,6 @@ 'args_mixin': ['VllmArguments', 'GRPOArgumentsMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], 'vllm_client': ['VLLMClient'], - 'teacher_api_client': ['fetch_teacher_logprobs'], 'arguments': ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig', 'GKDConfig'] } diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py index 576e00a1aa..06825fa38d 100644 --- a/swift/rlhf_trainers/gkd_trainer.py +++ b/swift/rlhf_trainers/gkd_trainer.py @@ -51,6 +51,9 @@ class DataSource(str, Enum): DATASET = 'dataset' # Off-policy: use dataset responses +teacher_model_server_model_name = None + + class GKDTrainer(RolloutTrainerMixin, SwiftMixin, HFGKDTrainer): def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): @@ -74,21 +77,13 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non # Initialize logging components self._prepare_logging() - # Initialize liger loss (only when not using top-k mode) - if self.gkd_logits_topk is None: - self._prepare_liger_loss() - else: - self.use_liger_gkd_loss = False - logger.info(f'Using top-k logits (k={self.gkd_logits_topk}) for KL computation, liger loss disabled.') + # Initialize liger loss if enabled + self._prepare_liger_loss() self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation self.is_teacher_ds3 = None - self.teacher_model = None - - # Initialize teacher model (skip if using API) - if not self.use_teacher_api: - if teacher_model is None: - raise ValueError('teacher_model is required when not using teacher_model_server') + # Initialize teacher model + if teacher_model is not None: if self.is_deepspeed_enabled: if teacher_deepspeed_config is not None: self.is_teacher_ds3 = teacher_deepspeed_config.get('zero_optimization', {}).get('stage') == 3 @@ -106,8 +101,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.teacher_model.eval() if self.args.offload_teacher_model: self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) - else: - logger.info(f'Using teacher model API for logprobs, top_logprobs={self.gkd_logits_topk}') # Initialize rollout infrastructure for vLLM support self.prepare_rollout() @@ -119,7 +112,6 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) else: self.maybe_activation_offload_context = nullcontext() - self._trl_version_gte_0_24 = 
 
         # Initialize resample data iterator for truncation_strategy 'raise'('delete')
         if self.template.truncation_strategy == 'raise':
@@ -252,20 +244,25 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
                 model_inputs['labels'] = inputs['labels']
             outputs_student = model(**model_inputs)
 
-            # Align teacher API logprobs with student output when logits_to_keep is used
-            labels = inputs['labels']
+            # teacher_api shape: [batch, seq_len-1, topk]
+            # teacher[i] = P(token[i+1] | token[0..i]), matching logits[i].
+            # But teacher has seq_len-1 positions (no logits[seq_len-1] equivalent).
+            # Pad a -inf row at the end so teacher becomes [batch, seq_len, topk]
+            # (the last position will be masked out by shifted_labels = -100).
+            teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf'))
+            teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0)
+            # Now teacher is [batch, seq_len, topk], same as full logits.
+            # Apply logits_to_keep to truncate teacher the same way model truncates logits.
             logits_to_keep = inputs.get('logits_to_keep')
             if logits_to_keep is not None:
                 if isinstance(logits_to_keep, torch.Tensor) and logits_to_keep.dtype == torch.bool:
                     teacher_api_logprobs = teacher_api_logprobs[:, logits_to_keep]
                     teacher_api_indices = teacher_api_indices[:, logits_to_keep]
-                    labels = labels[:, logits_to_keep]
                 else:
                     n = logits_to_keep.item() if isinstance(logits_to_keep, torch.Tensor) else int(logits_to_keep)
                     teacher_api_logprobs = teacher_api_logprobs[:, -n:]
                     teacher_api_indices = teacher_api_indices[:, -n:]
-                    labels = labels[:, -n:]
-            shifted_labels = torch.roll(labels, shifts=-1, dims=1)
+            shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1)
 
             loss = self.generalized_jsd_loss(
                 student_logits=outputs_student.logits,
@@ -469,7 +466,6 @@ def _fetch_teacher_logprobs_from_api(self, encoded_inputs: Dict[str, torch.Tenso
         Returns:
             Tuple of (teacher_logprobs, teacher_indices) tensors with shapes [batch, seq_len, topk]
         """
-        from .teacher_api_client import fetch_teacher_logprobs
         input_ids = encoded_inputs['input_ids']
         teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
             self.teacher_model_server, input_ids.tolist(), topk=self.gkd_logits_topk)
@@ -549,7 +545,7 @@ def _prepare_liger_loss(self):
             raise ImportError(
                 'Liger kernel is not installed. Please install liger-kernel by running: pip install liger-kernel')
         assert self.args.sft_alpha == 0, 'SFT loss is not supported with liger loss'
-
+        assert self.gkd_logits_topk is None, 'Top-k mode is not supported with liger loss'
         self.liger_jsd_loss = LigerFusedLinearJSDLoss(
             beta=self.beta,
             ignore_index=-100,
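The compute_loss hunk above encodes the alignment contract in its comments; the padding and masking interplay is easy to check in isolation. A toy-shape sketch (all values are made up; only the shapes, the `F.pad` call, and the `-100` masking mirror the patch):

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch=1, seq_len=5, topk=3. The teacher API tensors come back
# shifted: teacher[i] describes P(token[i+1] | token[0..i]), lining up with
# logits[i], so there are only seq_len-1 such positions.
seq_len, topk = 5, 3
teacher_api_logprobs = torch.randn(1, seq_len - 1, topk)
teacher_api_indices = torch.randint(0, 100, (1, seq_len - 1, topk))

# Pad one dummy row so the teacher matches the student's [batch, seq_len, topk].
teacher_api_logprobs = F.pad(teacher_api_logprobs, (0, 0, 0, 1), value=float('-inf'))
teacher_api_indices = F.pad(teacher_api_indices, (0, 0, 0, 1), value=0)

# The padded last row corresponds to logits[seq_len-1], which would predict a
# token past the end of the sequence; torch.roll pushes labels[:, 0] (prompt,
# typically -100) into that slot, so it never contributes to the loss.
labels = torch.tensor([[-100, -100, 11, 12, 13]])
shifted_labels = torch.roll(labels, shifts=-1, dims=1)
assert shifted_labels[0, -1] == -100
print(teacher_api_logprobs.shape, shifted_labels)
```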
@@ -571,19 +567,25 @@ def generalized_jsd_loss(
         teacher_topk_indices=None,
     ):
         # Top-k mode: reduce logits to [*, k] before the standard JSD pipeline
-        if topk is not None and teacher_logits is not None:
+        if teacher_topk_logprobs is not None and teacher_topk_indices is not None:
+            # API teacher: gather student logits at teacher's top-k indices, then both
+            # get scaled by 1/T and re-normalized over top-k via downstream log_softmax.
+            # vLLM logprobs = log_softmax(logits) at T=1; treating them as logit-like
+            # scores and dividing by T then re-normalizing is equivalent to
+            # log_softmax(logits/T) over top-k (by shift-invariance of softmax).
+            s_scaled = student_logits / temperature
+            student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
+            teacher_logits = teacher_topk_logprobs / temperature
+            del s_scaled
+            temperature = 1.0
+        elif topk is not None and teacher_logits is not None:
+            # Local teacher: select top-k from teacher, gather corresponding student logits
             t_scaled = teacher_logits / temperature
             s_scaled = student_logits / temperature
             teacher_logits, topk_idx = torch.topk(t_scaled, k=topk, dim=-1)
             student_logits = torch.gather(s_scaled, dim=-1, index=topk_idx)
             del t_scaled, s_scaled, topk_idx
             temperature = 1.0
-        elif teacher_topk_logprobs is not None and teacher_topk_indices is not None:
-            s_scaled = student_logits / temperature
-            student_logits = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
-            teacher_logits = teacher_topk_logprobs
-            del s_scaled
-            temperature = 1.0
 
         student_logits = student_logits / temperature
         teacher_logits = teacher_logits / temperature
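The comment block above claims that dividing vLLM logprobs by T and re-normalizing over the top-k subset matches doing the same with raw logits. Since logprobs differ from logits only by a per-position constant (the `logsumexp` over the vocabulary), and `log_softmax` is invariant to per-row constants, the two paths coincide. A quick numeric check of that claim:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, k, T = 50, 8, 0.7
logits = torch.randn(vocab)

# What vLLM returns: log-probabilities at temperature 1.
logprobs = F.log_softmax(logits, dim=-1)
topk_vals, topk_idx = torch.topk(logprobs, k)

# Path A (API teacher): treat top-k logprobs as logit-like scores, divide by T,
# re-normalize over the top-k subset.
path_a = F.log_softmax(topk_vals / T, dim=-1)

# Path B (local teacher): scale the raw logits by 1/T, then take the same subset.
path_b = F.log_softmax(logits[topk_idx] / T, dim=-1)

# logprobs = logits - logsumexp(logits); the subtracted constant divides by T
# too, and log_softmax ignores per-row constants, so the two paths agree.
assert torch.allclose(path_a, path_b, atol=1e-6)
print('top-k JSD inputs agree:', path_a)
```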
@@ -707,3 +709,83 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non
             row = [table[header][i] for header in headers]
             rows.append(row)
             swanlab.log({'completions': swanlab.echarts.Table().add(headers, rows)})
+
+
+def fetch_teacher_logprobs(base_url, input_ids, topk=20, timeout=300.0):
+    """Fetch top-k prompt logprobs from a vLLM-compatible /v1/completions endpoint.
+
+    Uses prompt_logprobs to get logprobs for input tokens without generating.
+    vLLM prompt_logprobs are always raw (temperature=1) log-probabilities from the model;
+    the temperature parameter in the API only affects token sampling, not prompt_logprobs.
+
+    Args:
+        base_url: vLLM server URL (e.g., 'http://localhost:8000').
+        input_ids: List of token ID sequences.
+        topk: Number of top log probabilities per token.
+        timeout: Request timeout in seconds.
+
+    Returns:
+        (logprobs, indices) tensors of shape [batch, max_seq_len - 1, topk].
+        The shift is because prompt_logprobs[0] is always None (first token has no
+        conditional probability), so position i in the output corresponds to
+        P(token_{i+1} | token_0..token_i), aligning with model logits[i].
+    """
+    import logging
+    import requests
+    from concurrent.futures import ThreadPoolExecutor
+
+    _logger = logging.getLogger(__name__)
+    base_url = base_url.rstrip('/')
+    batch_size = len(input_ids)
+    max_seq_len = max(len(ids) for ids in input_ids)
+    url = f'{base_url}/v1/completions'
+    global teacher_model_server_model_name
+    if teacher_model_server_model_name is None:
+        try:
+            resp = requests.get(f'{base_url}/v1/models', timeout=10)
+            model = resp.json()['data'][0]['id'] if resp.ok else 'default'
+        except Exception:
+            model = 'default'
+        teacher_model_server_model_name = model
+    else:
+        model = teacher_model_server_model_name
+
+    # prompt_logprobs[0] is always None (no conditional prob for the first token),
+    # prompt_logprobs[i] = P(token_i | token_0..token_{i-1}) which aligns with logits[i-1].
+    # So we skip position 0 and the result has shape [batch, max_seq_len-1, topk],
+    # aligning with student logits which predict the next token at each position.
+    out_len = max_seq_len - 1
+    logprobs_out = torch.full((batch_size, out_len, topk), float('-inf'), dtype=torch.float32)
+    indices_out = torch.zeros((batch_size, out_len, topk), dtype=torch.long)
+
+    def _fetch_one(batch_idx):
+        payload = {
+            'model': model,
+            'prompt': input_ids[batch_idx],
+            'max_tokens': 1,
+            'temperature': 0,
+            'prompt_logprobs': topk,
+        }
+        try:
+            resp = requests.post(url, json=payload, timeout=timeout)
+            resp.raise_for_status()
+            prompt_logprobs_list = resp.json()['choices'][0].get('prompt_logprobs', [])
+            # Skip position 0 (always None), shift left so pos 1 -> output pos 0
+            for raw_pos in range(1, len(prompt_logprobs_list)):
+                pos_lp = prompt_logprobs_list[raw_pos]
+                if pos_lp is None:
+                    continue
+                out_pos = raw_pos - 1
+                if out_pos >= out_len:
+                    break
+                sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1]['logprob'])[:topk]
+                for k, (tid_str, info) in enumerate(sorted_items):
+                    indices_out[batch_idx, out_pos, k] = int(tid_str)
+                    logprobs_out[batch_idx, out_pos, k] = info['logprob']
+        except Exception as e:
+            _logger.error(f'Failed to get teacher logprobs for sequence {batch_idx}: {e}')
+
+    with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
+        list(pool.map(_fetch_one, range(batch_size)))
+
+    return logprobs_out, indices_out
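A usage sketch for the function added above — it assumes a vLLM teacher is already serving on the default URL from the example scripts, and the token IDs are placeholders:

```python
import torch
from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs

# Hypothetical ragged batch of token IDs.
batch = [[1, 2, 3, 4, 5], [7, 8, 9]]
logprobs, indices = fetch_teacher_logprobs('http://localhost:8000', batch, topk=20)

# Shapes are [batch, max_seq_len - 1, topk]; shorter sequences keep the
# -inf / 0 fill values beyond their own length.
assert logprobs.shape == (2, 4, 20)
assert indices.shape == (2, 4, 20)
# Row j aligns with logits[j]: it scores candidates for token j+1.
print(logprobs[0, 0, :5], indices[0, 0, :5])
```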
diff --git a/swift/rlhf_trainers/teacher_api_client.py b/swift/rlhf_trainers/teacher_api_client.py
deleted file mode 100644
index 22d5c5c24b..0000000000
--- a/swift/rlhf_trainers/teacher_api_client.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-"""Fetch teacher model logprobs from OpenAI-compatible endpoints."""
-import logging
-import requests
-import torch
-from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Tuple
-
-logger = logging.getLogger(__name__)
-
-_model_name_cache: dict = {}
-
-
-def _get_model_name(base_url: str) -> str:
-    if base_url not in _model_name_cache:
-        try:
-            resp = requests.get(f'{base_url}/v1/models', timeout=10)
-            if resp.ok and resp.json().get('data'):
-                _model_name_cache[base_url] = resp.json()['data'][0]['id']
-        except Exception as e:
-            logger.warning(f'Failed to detect model name: {e}')
-        if base_url not in _model_name_cache:
-            _model_name_cache[base_url] = 'default'
-    return _model_name_cache[base_url]
-
-
-def fetch_teacher_logprobs(
-    base_url: str,
-    input_ids: List[List[int]],
-    topk: int = 20,
-    timeout: float = 300.0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Fetch top-k logprobs from an OpenAI-compatible completions API.
-
-    Args:
-        base_url: Server URL (e.g., 'http://localhost:8000').
-        input_ids: List of token ID sequences.
-        topk: Number of top log probabilities per token.
-        timeout: Request timeout in seconds.
-
-    Returns:
-        (logprobs, indices) tensors of shape [batch, max_seq_len, topk].
-    """
-    base_url = base_url.rstrip('/')
-    batch_size = len(input_ids)
-    max_seq_len = max(len(ids) for ids in input_ids)
-    url = f'{base_url}/v1/completions'
-    model = _get_model_name(base_url)
-
-    logprobs_out = torch.full((batch_size, max_seq_len, topk), float('-inf'), dtype=torch.float32)
-    indices_out = torch.zeros((batch_size, max_seq_len, topk), dtype=torch.long)
-
-    def _fetch_one(batch_idx: int):
-        payload = {
-            'model': model,
-            'prompt': input_ids[batch_idx],
-            'max_tokens': 0,
-            'temperature': 0,
-            'logprobs': topk,
-            'echo': True,
-        }
-        try:
-            resp = requests.post(url, json=payload, timeout=timeout)
-            resp.raise_for_status()
-            top_logprobs_list = resp.json()['choices'][0].get('logprobs', {}).get('top_logprobs', [])
-            for pos, pos_lp in enumerate(top_logprobs_list):
-                if pos_lp is None:
-                    continue
-                sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1])[:topk]
-                for k, (tid_str, lp) in enumerate(sorted_items):
-                    indices_out[batch_idx, pos, k] = int(tid_str)
-                    logprobs_out[batch_idx, pos, k] = lp
-        except Exception as e:
-            logger.error(f'Failed to get logprobs for sequence {batch_idx}: {e}')
-
-    with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
-        list(pool.map(_fetch_one, range(batch_size)))
-
-    return logprobs_out, indices_out

From f23b71e08b480281ba3dda750b4a47d4dca1be6e Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 3 Mar 2026 14:39:52 +0800
Subject: [PATCH 08/10] update

---
 swift/megatron/trainers/gkd_trainer.py | 31 ++++++++++++++++++--------
 swift/rlhf_trainers/gkd_trainer.py     |  2 +-
 swift/rlhf_trainers/utils.py           | 10 +++++++--
 3 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py
index 1d06131b34..8a9acab471 100644
--- a/swift/megatron/trainers/gkd_trainer.py
+++ b/swift/megatron/trainers/gkd_trainer.py
@@ -41,16 +41,17 @@ def __init__(self, args: MegatronArguments, template, **kwargs):
         self.lmbda = args.lmbda  # On-policy probability
         self.seq_kd = args.seq_kd  # Sequential KD: use teacher-generated responses
         self.offload_teacher_model = args.offload_teacher_model  # Offload teacher to CPU
-        self.teacher_bridge = args.megatron_model_meta.bridge_cls(args, attr_prefix='teacher_')
-        self.teacher_config = self.teacher_bridge.processor.model_info.config
+        self.teacher_model_server = getattr(args, 'teacher_model_server', None)
+        self.use_teacher_api = self.teacher_model_server is not None
+        if args.teacher_model:
+            self.teacher_bridge = args.megatron_model_meta.bridge_cls(args, attr_prefix='teacher_')
+            self.teacher_config = self.teacher_bridge.processor.model_info.config
         self.sft_alpha = getattr(args, 'sft_alpha', 0.0)  # Weight for SFT loss
 
         # GKD top-k logits configuration
         self.gkd_logits_topk = getattr(args, 'gkd_logits_topk', None)
         # Check use_teacher_api based on args, not client existence
         # (API client is only created on last rank, but all ranks need to know the mode)
-        self.teacher_model_server = getattr(args, 'teacher_model_server', None)
-        self.use_teacher_api = self.teacher_model_server is not None
 
         # Validate teacher configuration
         if not self.use_teacher_api:
@@ -100,12 +101,12 @@ def prepare_model(self):
 
     def _offload_teacher_models(self):
         """Offload teacher models to CPU to save GPU memory."""
-        if self.teacher_models:
+        if self.teacher_models and not self.use_teacher_api:
             offload_megatron_model_to_cpu(self.teacher_models)
 
     def _load_teacher_models_to_gpu(self):
         """Load teacher models back to GPU."""
-        if self.teacher_models:
+        if self.teacher_models and not self.use_teacher_api:
             load_megatron_model_to_gpu(self.teacher_models, load_grad=False)
 
     @contextmanager
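The `_jsd_topk` hunk just below switches to filtering masked positions before `log_softmax`. The failure mode it removes is easy to reproduce: the API-teacher padding fills the last position with an all-`-inf` row, `log_softmax` maps such a row to NaN, and multiplying by a zero mask afterwards cannot undo it. A small demonstration:

```python
import torch
import torch.nn.functional as F

# An all--inf row is exactly what the API-teacher padding produces at the
# final position. log_softmax turns it into NaNs, and NaN * 0 is still NaN,
# so masking after the fact cannot rescue the loss:
row = torch.full((1, 4), float('-inf'))
bad = F.log_softmax(row, dim=-1)   # tensor of NaNs
print((bad * 0).sum())             # nan, not 0

# Filtering to valid positions before log_softmax keeps every row finite:
scores = torch.cat([torch.randn(3, 4), row], dim=0)
mask = torch.tensor([True, True, True, False])
good = F.log_softmax(scores[mask], dim=-1)
print(torch.isfinite(good).all())  # True
```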
@@ -481,11 +482,23 @@ def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices,
         teacher's top-k indices, scale by 1/T, and log_softmax over top-k subset.
         By shift-invariance of log_softmax, this gives identical results whether
         teacher_topk_logprobs contains raw logits (local) or raw logprobs (API).
+
+        Masked positions are filtered out BEFORE log_softmax to avoid NaN from
+        all-(-inf) rows in API teacher padding.
         """
         s_scaled = student_logits / self.temperature
         s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
-        t_log_p = F.log_softmax(teacher_topk_logprobs / self.temperature, dim=-1)
-        s_log_p = F.log_softmax(s_topk, dim=-1)
+        t_topk = teacher_topk_logprobs / self.temperature
+
+        # Filter to valid positions first to avoid NaN from -inf padding rows
+        s_topk_masked = s_topk[mask]
+        t_topk_masked = t_topk[mask]
+
+        if s_topk_masked.numel() == 0:
+            return student_logits.new_zeros(())
+
+        t_log_p = F.log_softmax(t_topk_masked, dim=-1)
+        s_log_p = F.log_softmax(s_topk_masked, dim=-1)
         t_p = torch.exp(t_log_p)
 
         if beta == 0:
@@ -498,7 +511,7 @@ def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices,
             m_log_p = torch.log(beta * t_p + (1 - beta) * s_p + 1e-10)
             jsd = beta * (t_p * (t_log_p - m_log_p)).sum(-1) + (1 - beta) * (s_p * (s_log_p - m_log_p)).sum(-1)
 
-        return (jsd * mask.float()).sum()
+        return jsd.sum()
 
     def loss_func(self,
                   output_tensor: torch.Tensor,
diff --git a/swift/rlhf_trainers/gkd_trainer.py b/swift/rlhf_trainers/gkd_trainer.py
index 06825fa38d..f02672d57b 100644
--- a/swift/rlhf_trainers/gkd_trainer.py
+++ b/swift/rlhf_trainers/gkd_trainer.py
@@ -234,7 +234,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             )
             # loss / grad norm is unexpectedly large, normalize by sequence length
             # https://github.com/linkedin/Liger-Kernel/blob/v0.6.3/src/liger_kernel/chunked_loss/jsd_loss.py#L9-L39
-            loss /= student_hidden.shape[1]
+            # loss /= student_hidden.shape[1]
             # Release hidden states after loss computation
             del student_hidden, teacher_hidden, true_labels
             outputs_student = None
diff --git a/swift/rlhf_trainers/utils.py b/swift/rlhf_trainers/utils.py
index 2484df4e77..402bf7dae6 100644
--- a/swift/rlhf_trainers/utils.py
+++ b/swift/rlhf_trainers/utils.py
@@ -566,10 +566,16 @@ def profiling_context(trainer, name: str):
 
     profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration}
 
-    if 'wandb' in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
+    is_main_process = False
+    if hasattr(trainer, 'accelerator'):
+        is_main_process = trainer.accelerator.is_main_process
+    elif hasattr(trainer, 'is_main_process'):
+        is_main_process = trainer.is_main_process
+
+    if 'wandb' in trainer.args.report_to and wandb.run is not None and is_main_process:
         wandb.log(profiling_metrics)
 
-    if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and trainer.accelerator.is_main_process:
+    if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and is_main_process:
         swanlab.log(profiling_metrics)
 

From 17f88e42ba64c87f459a91d752292235c389e537 Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 3 Mar 2026 14:47:03 +0800
Subject: [PATCH 09/10] update script

---
 docs/source/Instruction/GKD.md               |  6 +--
 examples/megatron/rlhf/gkd/teacher_server.sh | 41 ++++++++-------
 examples/train/rlhf/gkd/teacher_server.sh    | 54 +++++++-------------
 swift/megatron/trainers/gkd_trainer.py       |  3 --
 4 files changed, 39 insertions(+), 65 deletions(-)

diff --git a/docs/source/Instruction/GKD.md b/docs/source/Instruction/GKD.md
index b276bd3860..f8d455a93a 100644
--- a/docs/source/Instruction/GKD.md
+++ b/docs/source/Instruction/GKD.md
@@ -143,7 +143,7 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
 
 | 参数 | 类型 | 默认值 | 取值范围 | 说明 |
 |------|------|--------|---------|------|
-| `--teacher_model` | str | None | - | 教师模型路径或模型 ID<br>*使用 `teacher_model_server` 时可省略 |
+| `--teacher_model` | str | None | - | 教师模型路径或模型 ID |
 | `--beta` | float | 0.5 | [0.0, 1.0] | 散度插值系数<br>• 0.0: Forward KL<br>• 0.5: JSD (平衡)<br>• 1.0: Reverse KL |
 | `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率<br>• 0.0: 离线学习<br>• 0.5: 混合策略<br>• 1.0: 纯 On-Policy |
 | `--seq_kd` | bool | False | True/False | 是否使用教师生成序列<br>• False: 非 on-policy 时使用数据集<br>• True: 非 on-policy 时使用教师生成 |
@@ -195,10 +195,6 @@ swift rlhf \
 | `--teacher_model_server` | str | None | 教师模型服务地址<br>如:`http://localhost:8000` |
 | `--gkd_logits_topk` | int | **必需** | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 |
 
-**支持的后端**:
-- `vllm serve`(推荐)
-
-> **注意**:仅支持 `vllm serve` 作为教师服务后端。训练代码通过 `/v1/completions` 接口直接传递 token IDs 并使用 `prompt_logprobs` 参数获取输入 token 的 log 概率,这是 vLLM 原生支持的功能。
 
 **步骤 1:部署教师模型服务**
 
diff --git a/examples/megatron/rlhf/gkd/teacher_server.sh b/examples/megatron/rlhf/gkd/teacher_server.sh
index 72a82ba9b4..2ccb05b8a1 100644
--- a/examples/megatron/rlhf/gkd/teacher_server.sh
+++ b/examples/megatron/rlhf/gkd/teacher_server.sh
@@ -1,44 +1,43 @@
-# GKD Training with External Teacher Model Server (Megatron)
-#
-# Start teacher server first (in a separate terminal / GPU):
-#   CUDA_VISIBLE_DEVICES=4 vllm serve Qwen/Qwen3-8B --port 8000 --max-logprobs 64
+# Teacher server must be running first:
+#   CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000 --max-logprobs 64
 
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
-NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=1,2 \
+NPROC_PER_NODE=2 \
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 megatron rlhf \
     --rlhf_type gkd \
-    --model Qwen/Qwen3-8B-Base \
+    --model Qwen/Qwen2.5-0.5B \
     --teacher_model_server http://localhost:8000 \
-    --gkd_logits_topk 20 \
-    --tuner_type lora \
-    --dataset AI-ModelScope/alpaca-gpt4-data-en#2000 AI-ModelScope/alpaca-gpt4-data-zh#2000 \
+    --gkd_logits_topk 64 \
+    --dataset 'modelscope/gsm8k' \
    --tensor_model_parallel_size 1 \
-    --expert_model_parallel_size 1 \
     --pipeline_model_parallel_size 1 \
     --context_parallel_size 1 \
-    --seq_kd false \
+    --expert_model_parallel_size 1 \
     --lmbda 1 \
-    --beta 1 \
+    --seq_kd false \
+    --beta 0.5 \
     --torch_dtype bfloat16 \
     --micro_batch_size 2 \
-    --global_batch_size 16 \
-    --max_epochs 1 \
+    --global_batch_size 32 \
+    --train_iters 500 \
     --lr 5e-5 \
-    --log_interval 1 \
-    --max_length 8192 \
-    --max_completion_length 8192 \
+    --lr_warmup_fraction 0.1 \
+    --logging_steps 1 \
+    --save_steps 100 \
+    --save_total_limit 10 \
+    --max_length 2048 \
+    --max_completion_length 2048 \
     --attention_backend flash \
     --use_vllm true \
     --vllm_mode colocate \
     --vllm_gpu_memory_utilization 0.5 \
     --vllm_tensor_parallel_size 1 \
-    --vllm_max_model_len 16384 \
+    --vllm_max_model_len 4096 \
     --sleep_level 1 \
-    --recompute_granularity selective \
     --finetune \
     --no_save_optim \
     --no_save_rng \
     --temperature 1.0 \
     --padding_free true \
-    --sequence_parallel true
+    --recompute_granularity selective
diff --git a/examples/train/rlhf/gkd/teacher_server.sh b/examples/train/rlhf/gkd/teacher_server.sh
index 04b2691c1c..cbb56680f9 100644
--- a/examples/train/rlhf/gkd/teacher_server.sh
+++ b/examples/train/rlhf/gkd/teacher_server.sh
@@ -1,62 +1,44 @@
 # GKD Training with External Teacher Model Server (vLLM)
-#
-# This script demonstrates using an external vLLM server as the teacher model
-# for knowledge distillation. The teacher server provides prompt_logprobs via
-# the /v1/completions endpoint, which requires native vLLM serving (vllm serve).
-#
-# NOTE: Only `vllm serve` is supported as the teacher server backend, because
-# the training code sends raw token IDs via the `prompt` field and uses the
-# `prompt_logprobs` parameter in the /v1/completions API. This is a vLLM-native
-# feature not available through swift deploy.
- # ===================== Step 1: Start Teacher Server ===================== # Run in a separate terminal / GPU: # -# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \ +# CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \ # --port 8000 \ # --max-logprobs 64 \ # --gpu-memory-utilization 0.9 -# -# Wait until the server is ready (shows "Uvicorn running on ..."). -# Verify with: curl http://localhost:8000/v1/models -# ======================================================================== -TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"} -GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64} +# ======================================================================== NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ -CUDA_VISIBLE_DEVICES=1,2,3,4 \ swift rlhf \ --rlhf_type gkd \ - --model Qwen/Qwen2.5-7B \ - --teacher_model_server $TEACHER_SERVER_URL \ + --model Qwen/Qwen2.5-0.5B \ + --teacher_model_server http://localhost:8000 \ + --gkd_logits_topk 64 \ --use_vllm true \ --vllm_mode colocate \ --vllm_gpu_memory_utilization 0.5 \ --vllm_tensor_parallel_size 1 \ - --vllm_max_model_len 10240 \ - --gkd_logits_topk $GKD_LOGITS_TOPK \ - --tuner_type lora \ - --dataset 'AI-ModelScope/alpaca-gpt4-data-en' \ - --seq_kd false \ + --vllm_max_model_len 4096 \ + --sleep_level 0 \ + --dataset 'modelscope/gsm8k' \ --lmbda 1 \ - --beta 1 \ + --seq_kd false \ + --beta 0.5 \ --torch_dtype bfloat16 \ - --max_epochs 1 \ - --per_device_train_batch_size 1 \ - --per_device_eval_batch_size 1 \ - --learning_rate 1e-5 \ + --per_device_train_batch_size 2 \ --gradient_accumulation_steps 4 \ - --eval_steps 500 \ - --save_steps 500 \ + --learning_rate 5e-5 \ + --logging_steps 1 \ + --save_steps 100 \ --save_total_limit 2 \ - --logging_steps 5 \ --max_length 2048 \ --max_completion_length 2048 \ - --warmup_ratio 0.05 \ + --warmup_ratio 0.1 \ --save_only_model true \ --dataloader_num_workers 4 \ --dataset_num_proc 4 \ - --deepspeed zero2 \ - --attn_impl flash_attn + --attn_impl flash_attn \ + --report_to tensorboard swanlab diff --git a/swift/megatron/trainers/gkd_trainer.py b/swift/megatron/trainers/gkd_trainer.py index 8a9acab471..09caf85106 100644 --- a/swift/megatron/trainers/gkd_trainer.py +++ b/swift/megatron/trainers/gkd_trainer.py @@ -483,14 +483,11 @@ def _jsd_topk(self, student_logits, teacher_topk_logprobs, teacher_topk_indices, By shift-invariance of log_softmax, this gives identical results whether teacher_topk_logprobs contains raw logits (local) or raw logprobs (API). - Masked positions are filtered out BEFORE log_softmax to avoid NaN from - all-(-inf) rows in API teacher padding. 
         """
         s_scaled = student_logits / self.temperature
         s_topk = torch.gather(s_scaled, dim=-1, index=teacher_topk_indices)
         t_topk = teacher_topk_logprobs / self.temperature
 
-        # Filter to valid positions first to avoid NaN from -inf padding rows
         s_topk_masked = s_topk[mask]
         t_topk_masked = t_topk[mask]
 

From 51dd414bbd167348b56796d7e1d12571e0f8ef4a Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 3 Mar 2026 14:51:15 +0800
Subject: [PATCH 10/10] update

---
 .../train/rlhf/gkd/gsm8k_teacher_server.sh    | 77 ------------------
 swift/arguments/rlhf_args.py                  |  3 +
 2 files changed, 3 insertions(+), 77 deletions(-)
 delete mode 100644 examples/train/rlhf/gkd/gsm8k_teacher_server.sh

diff --git a/examples/train/rlhf/gkd/gsm8k_teacher_server.sh b/examples/train/rlhf/gkd/gsm8k_teacher_server.sh
deleted file mode 100644
index 3c848b8814..0000000000
--- a/examples/train/rlhf/gkd/gsm8k_teacher_server.sh
+++ /dev/null
@@ -1,77 +0,0 @@
-# GKD on GSM8K: Teacher Server Mode with Top-K Logits
-#
-# This script validates GKD effectiveness on mathematical reasoning using GSM8K.
-# Student: Qwen2.5-1.5B-Instruct, Teacher: Qwen2.5-7B-Instruct (via vllm serve)
-#
-# Expected outcome: GSM8K accuracy should improve after GKD training, as the student
-# learns the teacher's reasoning distribution on math problems.
-#
-# ===================== Step 1: Start Teacher Server =====================
-# Run in a separate terminal / GPU:
-#
-#   CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-7B-Instruct \
-#     --port 8000 \
-#     --max-logprobs 64 \
-#     --gpu-memory-utilization 0.9
-#
-# Wait until the server is ready, then verify:
-#   curl http://localhost:8000/v1/models
-# ========================================================================
-#
-# ===================== Step 2: Prepare GSM8K Dataset =====================
-# The dataset uses the standard GSM8K train split from Hugging Face:
-#   openai/gsm8k (7473 training samples)
-# Swift will auto-download it via the HuggingFace dataset name.
-# ======================================================================== -# -# ===================== Step 3: Evaluation ===================== -# After training, evaluate on GSM8K test set: -# -# CUDA_VISIBLE_DEVICES=0 swift eval \ -# --model /checkpoint-xxx \ -# --eval_backend OpenCompass \ -# --infer_backend vllm \ -# --eval_dataset gsm8k -# -# Compare with the base model to verify improvement: -# CUDA_VISIBLE_DEVICES=0 swift eval \ -# --model Qwen/Qwen2.5-1.5B-Instruct \ -# --eval_backend OpenCompass \ -# --infer_backend vllm \ -# --eval_dataset gsm8k -# ======================================================================== - -TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"} -GKD_LOGITS_TOPK=${GKD_LOGITS_TOPK:-64} - -CUDA_VISIBLE_DEVICES=1 \ -PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ -swift rlhf \ - --rlhf_type gkd \ - --model Qwen/Qwen2.5-1.5B-Instruct \ - --teacher_model_server $TEACHER_SERVER_URL \ - --gkd_logits_topk $GKD_LOGITS_TOPK \ - --tuner_type lora \ - --lora_rank 64 \ - --lora_alpha 128 \ - --dataset 'openai/gsm8k#train' \ - --seq_kd false \ - --lmbda 0 \ - --beta 0.5 \ - --torch_dtype bfloat16 \ - --num_train_epochs 3 \ - --per_device_train_batch_size 2 \ - --per_device_eval_batch_size 2 \ - --learning_rate 5e-5 \ - --gradient_accumulation_steps 8 \ - --eval_steps 200 \ - --save_steps 200 \ - --save_total_limit 3 \ - --logging_steps 5 \ - --max_length 1024 \ - --warmup_ratio 0.05 \ - --save_only_model true \ - --dataloader_num_workers 4 \ - --dataset_num_proc 4 \ - --deepspeed zero2 \ - --attn_impl flash_attn diff --git a/swift/arguments/rlhf_args.py b/swift/arguments/rlhf_args.py index a66f02d908..50017b92ad 100644 --- a/swift/arguments/rlhf_args.py +++ b/swift/arguments/rlhf_args.py @@ -582,3 +582,6 @@ def _check_gkd(self): if self.gkd_logits_topk is not None and self.use_liger_kernel: raise ValueError('gkd_logits_topk is not supported when using liger kernel') + + if self.teacher_model_server and self.seq_kd: + raise NotImplementedError('Sequential KD is not supported when using teacher_model_server')
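With the full series applied, the teacher-API path can be smoke-tested end to end by comparing the server's top-k prompt logprobs against a local forward pass of the same checkpoint. A sketch under stated assumptions — the same small model is both served (`vllm serve Qwen/Qwen2.5-0.5B --port 8000 --max-logprobs 64`) and loaded locally, and tolerances are loose because the serving dtype may differ from the local one:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from swift.rlhf_trainers.gkd_trainer import fetch_teacher_logprobs

model_id = 'Qwen/Qwen2.5-0.5B'  # assumption: identical checkpoint on both sides
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

ids = tok('The capital of France is', return_tensors='pt')['input_ids']
api_lp, api_idx = fetch_teacher_logprobs('http://localhost:8000', ids.tolist(), topk=5)

with torch.no_grad():
    logits = model(ids).logits               # [1, L, vocab]
local_lp = F.log_softmax(logits, dim=-1)
# API row j scores candidates for token j+1, i.e. it matches logits[:, j].
gathered = torch.gather(local_lp[:, :-1], -1, api_idx)
print((gathered - api_lp).abs().max())       # should be small (~1e-2 with bf16 serving)
```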