[gkd] top-k-logits & teacher server #7918
Conversation
Summary of Changes

Hello @hjh0119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant enhancements to the Generalized Knowledge Distillation (GKD) framework by optimizing resource utilization and increasing deployment flexibility. It enables GKD training to be more efficient by allowing the use of a subset of teacher model logits and by supporting external teacher model services, which is crucial for scaling up to very large teacher models and distributed training environments.
/gemini review
Code Review
This pull request introduces significant enhancements to GKD training by adding support for top-k logits and an external teacher server. These features improve memory efficiency and enable the use of very large teacher models. The implementation is robust, with changes spanning argument parsing, trainer logic, and a new API client. The documentation has also been clearly updated. My review includes a few suggestions to improve the new example scripts to prevent potential user confusion and misconfiguration.
```
--max_length 8192 \
--max_completion_length 8192 \
```
The values for `max_length` and `max_completion_length` are both set to 8192. `max_length` represents the total sequence length (prompt + completion). Setting them to the same value implies that the prompt length must be zero to avoid truncation, which is usually not the case. This could lead to unexpected behavior or errors. To allow for non-empty prompts, `max_completion_length` should be smaller than `max_length`.
Suggested change:

```diff
     --max_length 8192 \
-    --max_completion_length 8192 \
+    --max_completion_length 4096 \
```
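To make the budget arithmetic concrete, here is a minimal sketch (plain Python; the helper name is illustrative, not part of the codebase) of how the prompt budget follows from these two settings:

```python
# Illustrative sketch: max_length is the total (prompt + completion)
# token budget, so whatever is not reserved for the completion is all
# that remains for the prompt.
def prompt_budget(max_length: int, max_completion_length: int) -> int:
    """Tokens left for the prompt before truncation kicks in."""
    return max_length - max_completion_length

# With both set to 8192, zero tokens remain for the prompt, so any
# non-empty prompt would be truncated.
assert prompt_budget(8192, 8192) == 0
# With the suggested values, 4096 tokens remain for the prompt.
assert prompt_budget(8192, 4096) == 4096
```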
```
# --port 8000 \
# --vllm_engine_kwargs '{"max_logprobs": 64}'

TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}
```
There's a port mismatch between the example server setup in the comment (port 8000) and the default `TEACHER_SERVER_URL` (port 8001). To avoid confusion for users copying the setup commands, it would be best to make them consistent.
Suggested change:

```diff
-TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8001"}
+TEACHER_SERVER_URL=${TEACHER_SERVER_URL:-"http://localhost:8000"}
```
```
--max_length 2048 \
--max_completion_length 2048 \
```
The values for `max_length` and `max_completion_length` are both set to 2048. `max_length` represents the total sequence length (prompt + completion). Setting them to the same value implies that the prompt length must be zero to avoid truncation, which is usually not the case. This could lead to unexpected behavior. Please consider reducing `max_completion_length` to a smaller value to allow for non-empty prompts.
Suggested change:

```diff
     --max_length 2048 \
-    --max_completion_length 2048 \
+    --max_completion_length 1024 \
```
/gemini review
Code Review
This pull request introduces significant enhancements to the GKD training algorithm by adding support for top-k logits computation and an external teacher model server. These features are crucial for improving efficiency and scalability, especially when dealing with large models. The implementation is well-structured, and the documentation updates are clear and comprehensive.
My review includes a high-severity suggestion to refactor the global state management for the teacher model name to improve code robustness and maintainability. I've also pointed out a minor grammatical correction in the documentation.
```python
def fetch_teacher_logprobs(base_url, input_ids, topk=20, timeout=300.0):
    """Fetch top-k prompt logprobs from a vLLM-compatible /v1/completions endpoint.

    Uses prompt_logprobs to get logprobs for input tokens without generating.
    vLLM prompt_logprobs are always raw (temperature=1) log-probabilities from the model;
    the temperature parameter in the API only affects token sampling, not prompt_logprobs.

    Args:
        base_url: vLLM server URL (e.g., 'http://localhost:8000').
        input_ids: List of token ID sequences.
        topk: Number of top log probabilities per token.
        timeout: Request timeout in seconds.

    Returns:
        (logprobs, indices) tensors of shape [batch, max_seq_len - 1, topk].
        The shift is because prompt_logprobs[0] is always None (first token has no
        conditional probability), so position i in the output corresponds to
        P(token_{i+1} | token_0..token_i), aligning with model logits[i].
    """
    import logging
    from concurrent.futures import ThreadPoolExecutor

    import requests

    _logger = logging.getLogger(__name__)
    base_url = base_url.rstrip('/')
    batch_size = len(input_ids)
    max_seq_len = max(len(ids) for ids in input_ids)
    url = f'{base_url}/v1/completions'
    global teacher_model_server_model_name
    if teacher_model_server_model_name is None:
        try:
            resp = requests.get(f'{base_url}/v1/models', timeout=10)
            model = resp.json()['data'][0]['id'] if resp.ok else 'default'
        except Exception:
            model = 'default'
        teacher_model_server_model_name = model
    else:
        model = teacher_model_server_model_name

    # prompt_logprobs[0] is always None (no conditional prob for the first token),
    # prompt_logprobs[i] = P(token_i | token_0..token_{i-1}) which aligns with logits[i-1].
    # So we skip position 0 and the result has shape [batch, max_seq_len-1, topk],
    # aligning with student logits which predict the next token at each position.
    out_len = max_seq_len - 1
    logprobs_out = torch.full((batch_size, out_len, topk), float('-inf'), dtype=torch.float32)
    indices_out = torch.zeros((batch_size, out_len, topk), dtype=torch.long)

    def _fetch_one(batch_idx):
        payload = {
            'model': model,
            'prompt': input_ids[batch_idx],
            'max_tokens': 1,
            'temperature': 0,
            'prompt_logprobs': topk,
        }
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            prompt_logprobs_list = resp.json()['choices'][0].get('prompt_logprobs', [])
            # Skip position 0 (always None), shift left so pos 1 -> output pos 0
            for raw_pos in range(1, len(prompt_logprobs_list)):
                pos_lp = prompt_logprobs_list[raw_pos]
                if pos_lp is None:
                    continue
                out_pos = raw_pos - 1
                if out_pos >= out_len:
                    break
                sorted_items = sorted(pos_lp.items(), key=lambda x: -x[1]['logprob'])[:topk]
                for k, (tid_str, info) in enumerate(sorted_items):
                    indices_out[batch_idx, out_pos, k] = int(tid_str)
                    logprobs_out[batch_idx, out_pos, k] = info['logprob']
        except Exception as e:
            _logger.error(f'Failed to get teacher logprobs for sequence {batch_idx}: {e}')

    with ThreadPoolExecutor(max_workers=min(batch_size, 8)) as pool:
        list(pool.map(_fetch_one, range(batch_size)))

    return logprobs_out, indices_out
```
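The shift-by-one alignment described in the docstring can be illustrated with a toy example (pure Python; the token IDs and logprob values are hypothetical, not real server output):

```python
# Toy illustration of the prompt_logprobs alignment used above.
# The server returns one entry per input token; entry 0 is None because
# the first token has no conditioning context.
input_tokens = [101, 7, 42, 9]           # hypothetical token IDs
prompt_logprobs = [
    None,                                # token 0: no conditional probability
    {'7': {'logprob': -0.1}},            # P(token_1 | token_0)
    {'42': {'logprob': -0.5}},           # P(token_2 | token_0..1)
    {'9': {'logprob': -0.2}},            # P(token_3 | token_0..2)
]

# Skipping entry 0 and shifting left gives len(input_tokens) - 1 rows,
# where row i corresponds to P(token_{i+1} | token_0..token_i),
# i.e. the same position as model logits[i].
aligned = prompt_logprobs[1:]
assert len(aligned) == len(input_tokens) - 1
assert all(entry is not None for entry in aligned)
```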
The use of a global variable `teacher_model_server_model_name` to cache the model name from the teacher server is not ideal. Global state can lead to race conditions, makes the code harder to test, and can cause unexpected behavior, especially if multiple trainers were to run in the same process.
A better approach would be to manage this state within the trainer instance. I suggest the following refactoring:
- Modify `fetch_teacher_logprobs` to be a stateless utility function that accepts the model name as a parameter.
- Have the `GKDTrainer` and `MegatronGKDTrainer` classes be responsible for fetching and caching the model name in an instance attribute (e.g., `self.teacher_model_server_model_name`) during initialization.
- Pass this cached model name to `fetch_teacher_logprobs` when calling it.
This would look something like this:
```python
# In swift/rlhf_trainers/gkd_trainer.py

# In GKDTrainer.__init__
self.teacher_model_server_model_name = None
if self.use_teacher_api:
    try:
        import requests
        resp = requests.get(f'{self.teacher_model_server.rstrip("/")}/v1/models', timeout=10)
        self.teacher_model_server_model_name = resp.json()['data'][0]['id'] if resp.ok else 'default'
    except Exception as e:
        logger.warning(f'Failed to get model name from teacher server, using "default". Error: {e}')
        self.teacher_model_server_model_name = 'default'

# In GKDTrainer._fetch_teacher_logprobs_from_api
teacher_logprobs, teacher_indices = fetch_teacher_logprobs(
    self.teacher_model_server,
    self.teacher_model_server_model_name,  # Pass cached model name
    input_ids.tolist(),
    topk=self.gkd_logits_topk,
)

# Refactored fetch_teacher_logprobs
def fetch_teacher_logprobs(base_url, model_name, input_ids, topk=20, timeout=300.0):
    # ... remove global variable and model name fetching logic
    payload = {
        'model': model_name,
        # ...
    }
    # ...
```

A similar change would be needed in `MegatronGKDTrainer`.
> **Top-K Mode Principle**:
>
> In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It use the teacher model's top-k indices to gather logits from both models, then renormalize over the top-k subset before computing JSD.
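As a rough sketch of the gather-renormalize-JSD step described in that passage (pure Python on toy logits for a single position; this is an illustration of the idea, not the trainer's actual implementation, and the generalized JSD mixture weight `beta` here is an assumed convention):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a small list of logits.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def kl(p, q):
    # KL(p || q) over aligned probability lists.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def topk_jsd(teacher_logits, student_logits, k=2, beta=0.5):
    # 1. Select top-k indices from the *teacher* logits.
    idx = sorted(range(len(teacher_logits)),
                 key=lambda i: -teacher_logits[i])[:k]
    # 2. Gather both models' logits at those indices and renormalize
    #    over the top-k subset.
    p = softmax([teacher_logits[i] for i in idx])  # teacher, renormalized
    q = softmax([student_logits[i] for i in idx])  # student, renormalized
    # 3. Generalized JSD against the beta-weighted mixture.
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)

# Identical distributions give zero divergence on the top-k subset.
assert topk_jsd([2.0, 1.0, 0.1], [2.0, 1.0, 0.1]) < 1e-9
```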
There's a minor grammatical error here. "It use" should be "It uses".
Suggested change:

```diff
-In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It use the teacher model's top-k indices to gather logits from both models, then renormalize over the top-k subset before computing JSD.
+In Top-K mode, the top-K token indices are selected from the **teacher model**, and the KL divergence is computed on both models' logits at these positions. It uses the teacher model's top-k indices to gather logits from both models, then renormalize over the top-k subset before computing JSD.
```
No description provided.