"""
VLLMPolicy - Policy for TRL's VLLMClient or colocated vLLM LLM.

Thin adapter that turns Eval Protocol-style message lists into a single prompt,
then calls either:

- TRL's VLLMClient (server mode), or
- a colocated vLLM LLM instance (SamplingParams mode).
"""

import logging
from typing import Any, Dict, List, Optional


logger = logging.getLogger(__name__)
| 17 | + |
class VLLMPolicy:
    """
    Policy that uses TRL's VLLMClient for generation.

    This is designed to work with `trl vllm-serve` which provides
    custom /generate/ and /chat/ endpoints. It also supports a colocated
    vLLM ``LLM`` object (detected via its ``llm_engine`` attribute), in
    which case generation goes through ``SamplingParams`` directly.
    """

    def __init__(
        self,
        vllm_client,  # trainer.vllm_client (server mode) or a vLLM LLM (colocate mode)
        tokenizer=None,  # Optional tokenizer for chat templating / decoding
        temperature: float = 1.0,
        max_tokens: int = 100,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize VLLMPolicy.

        Args:
            vllm_client: TRL's VLLMClient instance (from trainer.vllm_client)
                or a colocated vLLM ``LLM`` object.
            tokenizer: Optional tokenizer for applying the chat template and
                for decoding token IDs back to text.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            top_p: Top-p sampling (None means 1.0, i.e. disabled).
            top_k: Top-k sampling (None means -1, vLLM's "disabled" value).
            **kwargs: Additional generation parameters, forwarded to the
                backend in both server and colocate modes.
        """
        self.vllm_client = vllm_client
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Normalize "not set" to the values vLLM treats as disabled.
        self.top_p = top_p if top_p is not None else 1.0
        self.top_k = top_k if top_k is not None else -1
        self.kwargs = kwargs

    def _messages_to_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """
        Render a message list into a single prompt string.

        Uses the tokenizer's chat template when a tokenizer is available;
        falls back to a simple "role: content" concatenation when there is
        no tokenizer or when applying the template fails.
        """
        if self.tokenizer is not None:
            try:
                # Use tokenizer's chat template
                prompt_text = self.tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=False,
                )
                logger.debug(
                    "[VLLMPolicy] Chat template applied for %d messages (prompt length=%d)",
                    len(messages),
                    len(prompt_text),
                )
                return prompt_text
            except Exception as e:
                logger.warning(
                    "[VLLMPolicy] Failed to apply chat template: %s",
                    e,
                    exc_info=True,
                )
        # Fallback: simple concatenation (defensive .get access)
        return "\n".join(f"{m.get('role', '?')}: {m.get('content', '')}" for m in messages)

    def _generation_params(self) -> Dict[str, Any]:
        """
        Build the sampling-parameter dict: base values merged with **kwargs.

        Extra kwargs are applied for BOTH backends so documented overrides
        are never silently dropped (previously only server mode honored them).
        """
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": 1,
        }
        params.update(self.kwargs)
        return params

    def _decode_completion(self, prompt_ids, completion_ids) -> str:
        """
        Decode completion token IDs to text, with safe fallbacks.

        Returns a placeholder string when no tokenizer is configured or
        when decoding raises, so callers always receive some text.
        """
        if self.tokenizer is None:
            # Fallback: just indicate number of tokens
            return f"<{len(completion_ids)}_tokens>"
        try:
            completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
            logger.debug(
                "[VLLMPolicy] Generation result: prompt_tokens=%d, completion_tokens=%d, completion_chars=%d",
                len(prompt_ids),
                len(completion_ids),
                len(completion_text),
            )
            return completion_text
        except Exception as e:
            logger.warning(
                "[VLLMPolicy] Failed to decode completion: %s",
                e,
                exc_info=True,
            )
            return f"<decoded_error:{len(completion_ids)}_tokens>"

    async def _make_llm_call(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List] = None,
    ) -> Dict[str, Any]:
        """
        Make LLM call using TRL's VLLMClient or a colocated vLLM LLM.

        Args:
            messages: List of message dicts with 'role' and 'content'
            tools: Not used (for compatibility)

        Returns:
            OpenAI-compatible response dict, augmented with raw
            ``prompt_ids`` / ``completion_ids`` for TRL integration.
        """
        prompt_text = self._messages_to_prompt(messages)
        gen_params = self._generation_params()

        # Check if vllm_client is LLM (colocate mode) or VLLMClient (server
        # mode): a colocated vLLM LLM exposes `llm_engine`; VLLMClient does not.
        if hasattr(self.vllm_client, "llm_engine"):
            # Colocate mode: use SamplingParams
            logger.debug("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams")
            from vllm import SamplingParams

            # NOTE: unknown kwargs now raise here instead of being silently
            # ignored, surfacing caller configuration errors.
            sampling_params = SamplingParams(**gen_params)

            logger.debug("[VLLMPolicy] Calling LLM.generate()")
            outputs = self.vllm_client.generate([prompt_text], sampling_params=sampling_params, use_tqdm=False)

            # Extract from vLLM output format (single prompt, n=1)
            output = outputs[0]
            prompt_ids = output.prompt_token_ids
            completion_ids = output.outputs[0].token_ids
        else:
            # Server mode: use VLLMClient with keyword parameters
            logger.debug("[VLLMPolicy] Using VLLMClient (server mode)")
            logger.debug("[VLLMPolicy] Calling vllm_client.generate()")
            response = self.vllm_client.generate(
                prompts=[prompt_text],
                **gen_params,
            )
            # Extract first result
            prompt_ids = response["prompt_ids"][0]
            completion_ids = response["completion_ids"][0]

        completion_text = self._decode_completion(prompt_ids, completion_ids)

        # Convert to OpenAI-compatible format for compatibility with OpenEnvRolloutProcessor
        # Also include raw token IDs for TRL integration (avoids double encoding)
        return {
            "choices": [
                {
                    "message": {
                        "content": completion_text,
                        "role": "assistant",
                    }
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt_ids),
                "completion_tokens": len(completion_ids),
                "total_tokens": len(prompt_ids) + len(completion_ids),
            },
            # Include raw token IDs for TRL (avoids re-encoding)
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
        }