
[gkd] top-k-logits & teacher server #7918

Merged

hjh0119 merged 13 commits into modelscope:main from hjh0119:gkd_top_k_logits on Mar 3, 2026


Conversation

@hjh0119 (Collaborator) commented Jan 27, 2026

No description provided.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @hjh0119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant enhancements to the Generalized Knowledge Distillation (GKD) framework by optimizing resource utilization and increasing deployment flexibility. It enables GKD training to be more efficient by allowing the use of a subset of teacher model logits and by supporting external teacher model services, which is crucial for scaling up to very large teacher models and distributed training environments.

Highlights

  • Top-K Logits for GKD: Introduced a 'Top-K' mode for Generalized Knowledge Distillation (GKD) to compute KL divergence using only the top-K most probable tokens from the teacher model. This significantly reduces memory usage and computational overhead, especially for models with large vocabularies.
  • External Teacher Model API Support: Added functionality to use an external API service (e.g., vLLM) as the teacher model. This allows for decoupling the teacher model from the training process, enabling the use of larger teacher models or shared teacher services, and providing more flexible deployment options.
  • Comprehensive Documentation and Examples: Updated both Chinese and English documentation with detailed explanations of the Top-K mode and external teacher API, including mathematical principles, usage examples, and notes on limitations. New example shell scripts are provided for both standard and Megatron-SWIFT GKD training with an external teacher server.
  • Robust Argument Validation: Implemented new command-line arguments (--teacher_model_server, --gkd_logits_topk) and added validation logic to ensure correct configuration, such as requiring gkd_logits_topk when using an external server and preventing conflicts with other features like liger kernel.
  • Dedicated Teacher API Client: Developed a new TeacherAPIClient to handle asynchronous and synchronous fetching of log probabilities from OpenAI-compatible API endpoints, complete with response parsing and server health checks.
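The Top-K mechanism highlighted above — gather both models' logits at the teacher's top-k token indices, renormalize over that subset, then compute the generalized JSD — can be sketched as follows. This is a minimal illustration with hypothetical names (`topk_jsd` is not the PR's actual function), not the trainer's implementation:

```python
import math
import torch
import torch.nn.functional as F

def topk_jsd(student_logits, teacher_logits, k=4, beta=0.5):
    """Generalized JSD over the teacher's top-k token subset (illustrative sketch)."""
    # 1. Select top-k token indices from the teacher distribution.
    _, idx = teacher_logits.topk(k, dim=-1)               # [batch, seq, k]
    # 2. Gather BOTH models' logits at the teacher's top-k token ids.
    t = teacher_logits.gather(-1, idx)
    s = student_logits.gather(-1, idx)
    # 3. Renormalize over the k-subset only (softmax on the gathered logits).
    t_logp = F.log_softmax(t, dim=-1)
    s_logp = F.log_softmax(s, dim=-1)
    # 4. Mixture m = beta * teacher + (1 - beta) * student, in log space.
    m_logp = torch.logsumexp(
        torch.stack([t_logp + math.log(beta), s_logp + math.log(1 - beta)]), dim=0)
    # 5. JSD_beta = beta * KL(teacher || m) + (1 - beta) * KL(student || m).
    kl_t = F.kl_div(m_logp, t_logp, log_target=True, reduction='batchmean')
    kl_s = F.kl_div(m_logp, s_logp, log_target=True, reduction='batchmean')
    return beta * kl_t + (1 - beta) * kl_s
```

Because only k columns per position are gathered and normalized, memory scales with k rather than the full vocabulary size, which is the source of the savings described above.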


@gemini-code-assist (bot) left a comment

Code Review

This pull request significantly enhances the GKD training algorithm by introducing support for an external teacher model API and a Top-K KL-divergence computation mode. These improvements provide greater flexibility and memory efficiency, especially for large teacher models. All related documentation (Chinese and English) has been updated, and new example scripts are provided. The implementation includes strict validation of argument configuration and a dedicated API client with comprehensive parsing and distributed-communication logic. Unit and integration tests for the API client are also included to ensure the reliability of the new features.

@hjh0119 (Collaborator, Author) commented Mar 2, 2026

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces significant enhancements to GKD training by adding support for top-k logits and an external teacher server. These features improve memory efficiency and enable the use of very large teacher models. The implementation is robust, with changes spanning argument parsing, trainer logic, and a new API client. The documentation has also been updated clearly. My review includes a few suggestions to improve the new example scripts to prevent potential user confusion and misconfiguration.

Comment on lines +24 to +25
--max_length 8192 \
--max_completion_length 8192 \

Severity: medium

The values for max_length and max_completion_length are both set to 8192. max_length represents the total sequence length (prompt + completion). Setting them to the same value implies that the prompt length must be zero to avoid truncation, which is usually not the case. This could lead to unexpected behavior or errors. To allow for non-empty prompts, max_completion_length should be smaller than max_length.

Suggested change
- --max_length 8192 \
- --max_completion_length 8192 \
+ --max_length 8192 \
+ --max_completion_length 4096 \

# --port 8000 \
# --vllm_engine_kwargs '{"max_logprobs": 64}'

TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}

Severity: medium

There's a port mismatch between the example server setup in the comment (port 8000) and the default TEACHER_SERVER_URL (port 8001). To avoid confusion for users copying the setup commands, it would be best to make them consistent.

Suggested change
- TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}
+ TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}

Comment on lines +44 to +45
--max_length 2048 \
--max_completion_length 2048 \

Severity: medium

The values for max_length and max_completion_length are both set to 2048. max_length represents the total sequence length (prompt + completion). Setting them to the same value implies that the prompt length must be zero to avoid truncation, which is usually not the case. This could lead to unexpected behavior. Please consider reducing max_completion_length to a smaller value to allow for non-empty prompts.

Suggested change
- --max_length 2048 \
- --max_completion_length 2048 \
+ --max_length 2048 \
+ --max_completion_length 1024 \

@hjh0119 (Collaborator, Author) commented Mar 3, 2026

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces significant enhancements to the GKD training algorithm by adding support for top-k logits computation and an external teacher model server. These features are crucial for improving efficiency and scalability, especially when dealing with large models. The implementation is well-structured, and the documentation updates are clear and comprehensive.

My review includes a high-severity suggestion to refactor the global state management for the teacher model name to improve code robustness and maintainability. I've also pointed out a minor grammatical correction in the documentation.

Comment on lines +714 to +791
def fetch_teacher_logprobs(base_url, input_ids, topk=20, timeout=300.0):
    """Fetch top-k prompt logprobs from a vLLM-compatible /v1/completions endpoint.

    Uses prompt_logprobs to get logprobs for input tokens without generating.
    vLLM prompt_logprobs are always raw (temperature=1) log-probabilities from the model;
    the temperature parameter in the API only affects token sampling, not prompt_logprobs.

    Args:
        base_url: vLLM server URL (e.g., 'http://localhost:8000').
        input_ids: List of token ID sequences.
        topk: Number of top log probabilities per token.
        timeout: Request timeout in seconds.

    Returns:
        (logprobs, indices) tensors of shape [batch, max_seq_len - 1, topk].
        The shift is because prompt_logprobs[0] is always None (first token has no
        conditional probability), so position i in the output corresponds to
        P(token_{i+1} | token_0..token_i), aligning with model logits[i].
    """
    import logging
    import requests
    from concurrent.futures import ThreadPoolExecutor

    _logger = logging.getLogger(__name__)
    base_url = base_url.rstrip('/')
    batch_size = len(input_ids)
    max_seq_len = max(len(ids) for ids in input_ids)
    url = f'{base_url}/v1/completions'
    global teacher_model_server_model_name
    if teacher_model_server_model_name is None:
        try:
            resp = requests.get(f'{base_url}/v1/models', timeout=10)
            model = resp.json()['data'][0]['id'] if resp.ok else 'default'
        except Exception:
            model = 'default'
        teacher_model_server_model_name = model
    else:
        model = teacher_model_server_model_name

    # prompt_logprobs[0] is always None (no conditional prob for the first token),
    # prompt_logprobs[i] = P(token_i | token_0..token_{i-1}) which aligns with logits[i-1].
    # So we skip position 0 and the result has shape [batch, max_seq_len-1, topk],
    # aligning with student logits which predict the next token at each position.
    out_len = max_seq_len - 1
    logprobs_out = torch.full((batch_size, out_len, topk), float('-inf'), dtype=torch.float32)
    indices_out = torch.zeros((batch_size, out_len, topk), dtype=torch.long)

    def _fetch_one(batch_idx):
        payload = {
            'model': model,
            'prompt': input_ids[batch_idx],
            'max_tokens': 1,
            'temperature': 0,
            'prompt_logprobs': topk,
        }
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            prompt_logprobs_list = resp.json()['choices'][0].get('prompt_logprobs', [])
            # Skip position 0 (always None), shift left so pos 1 -> output pos 0
            for raw_pos in range(1, len(prompt_logprobs_list)):
                pos_lp = prompt_logprobs_list[raw_pos]
                if pos_lp is None:
                    continue
                out_pos = raw_pos - 1
                if out_pos >= out_len:
                    break
                sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1]['logprob'])[:topk]
                for k, (tid_str, info) in enumerate(sorted_items):
                    indices_out[batch_idx, out_pos, k] = int(tid_str)
                    logprobs_out[batch_idx, out_pos, k] = info['logprob']
        except Exception as e:
            _logger.error(f'Failed to get teacher logprobs for sequence {batch_idx}: {e}')

    with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
        list(pool.map(_fetch_one, range(batch_size)))

    return logprobs_out, indices_out

Severity: high

The use of a global variable teacher_model_server_model_name to cache the model name from the teacher server is not ideal. Global state can lead to race conditions, makes the code harder to test, and can cause unexpected behavior, especially if multiple trainers were to run in the same process.

A better approach would be to manage this state within the trainer instance. I suggest the following refactoring:

  1. Modify fetch_teacher_logprobs to be a stateless utility function that accepts the model name as a parameter.
  2. Have the GKDTrainer and MegatronGKDTrainer classes be responsible for fetching and caching the model name in an instance attribute (e.g., self.teacher_model_server_model_name) during initialization.
  3. Pass this cached model name to fetch_teacher_logprobs when calling it.

This would look something like this:

# In swift/rlhf_trainers/gkd_trainer.py

# In GKDTrainer.__init__
self.teacher_model_server_model_name = None
if self.use_teacher_api:
    try:
        import requests
        resp = requests.get(f'{self.teacher_model_server.rstrip("/")}/v1/models', timeout=10)
        self.teacher_model_server_model_name = resp.json()['data'][0]['id'] if resp.ok else 'default'
    except Exception as e:
        logger.warning(f'Failed to get model name from teacher server, using "default". Error: {e}')
        self.teacher_model_server_model_name = 'default'

# In GKDTrainer._fetch_teacher_logprobs_from_api
teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
    self.teacher_model_server,
    self.teacher_model_server_model_name,  # Pass cached model name
    input_ids.tolist(),
    topk=self.gkd_logits_topk
)

# Refactored fetch_teacher_logprobs
def fetch_teacher_logprobs(base_url, model_name, input_ids, topk=20, timeout=300.0):
    # ... remove global variable and model name fetching logic
    payload = {
        'model': model_name,
        # ...
    }
    # ...

A similar change would be needed in MegatronGKDTrainer.
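The prompt_logprobs alignment convention used by fetch_teacher_logprobs above (position 0 is always None; output position i aligns with model logits[i]) can be illustrated with a mock response. All token ids and logprob values below are made up for demonstration; only the parsing logic mirrors the function:

```python
import torch

# Mock of the prompt_logprobs field from a vLLM /v1/completions response
# for a 4-token prompt with top-2 candidates per position (values invented).
prompt_logprobs = [
    None,                                              # first token: no conditional prob
    {'5': {'logprob': -0.1}, '9': {'logprob': -2.3}},  # P(token_1 | token_0)
    {'7': {'logprob': -0.5}, '5': {'logprob': -1.2}},  # P(token_2 | token_0..1)
    {'2': {'logprob': -0.2}, '8': {'logprob': -1.9}},  # P(token_3 | token_0..2)
]

topk = 2
out_len = len(prompt_logprobs) - 1                     # drop the leading None
logprobs = torch.full((out_len, topk), float('-inf'))
indices = torch.zeros((out_len, topk), dtype=torch.long)
for raw_pos in range(1, len(prompt_logprobs)):
    items = sorted(prompt_logprobs[raw_pos].items(),
                   key=lambda x: -x[1]['logprob'])[:topk]
    for k, (tid, info) in enumerate(items):
        indices[raw_pos - 1, k] = int(tid)             # output pos i aligns with logits[i]
        logprobs[raw_pos - 1, k] = info['logprob']

print(indices.tolist())   # [[5, 9], [7, 5], [2, 8]]
```

The left shift by one is what lets these teacher tensors line up position-for-position with the student's next-token logits when the top-k loss is computed.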


**Top-K Mode Principle**:

In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It use the teacher model's top-k indices to gather logits from both models, then renormalize over the top-k subset before computing JSD.

Severity: medium

There's a minor grammatical error here. "It use" should be "It uses".

Suggested change
- In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It use the teacher model's top-k indices to gather logits from both models, then renormalize over the top-k subset before computing JSD.
+ In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It uses the teacher model's top-k indices to gather logits from both models, then renormalize over the top-k subset before computing JSD.

@hjh0119 hjh0119 merged commit 0d7c9f5 into modelscope:main Mar 3, 2026
2 of 3 checks passed
@hjh0119 hjh0119 deleted the gkd_top_k_logits branch March 3, 2026 07:01