Skip to content

Commit 425b882

Browse files
shreymodi1 (Shrey Modi)
and authored
openenvrolloutprocessor (#336)
* openenvrolloutprocessor * openenvrolloutprocessor * trl integration * comments * updates * final * finalll * updates * reward * lint --------- Co-authored-by: Shrey Modi <shrey@fireworks.ai>
1 parent f10c29f commit 425b882

File tree

10 files changed

+2410
-167
lines changed

10 files changed

+2410
-167
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
VLLMPolicy - Policy for TRL's VLLMClient or colocated vLLM LLM.
3+
4+
Thin adapter that turns Eval Protocol-style message lists into a single prompt,
5+
then calls either:
6+
7+
- TRL's VLLMClient (server mode), or
8+
- a colocated vLLM LLM instance (SamplingParams mode).
9+
"""
10+
11+
import logging
12+
from typing import Any, Dict, List, Optional
13+
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class VLLMPolicy:
19+
"""
20+
Policy that uses TRL's VLLMClient for generation.
21+
22+
This is designed to work with `trl vllm-serve` which provides
23+
custom /generate/ and /chat/ endpoints.
24+
"""
25+
26+
def __init__(
27+
self,
28+
vllm_client, # trainer.vllm_client
29+
tokenizer=None, # Optional tokenizer for decoding
30+
temperature: float = 1.0,
31+
max_tokens: int = 100,
32+
top_p: Optional[float] = None,
33+
top_k: Optional[int] = None,
34+
**kwargs,
35+
):
36+
"""
37+
Initialize VLLMPolicy.
38+
39+
Args:
40+
vllm_client: TRL's VLLMClient instance (from trainer.vllm_client)
41+
tokenizer: Optional tokenizer for decoding token IDs to text
42+
temperature: Sampling temperature
43+
max_tokens: Maximum tokens to generate
44+
top_p: Top-p sampling
45+
top_k: Top-k sampling
46+
**kwargs: Additional generation parameters
47+
"""
48+
self.vllm_client = vllm_client
49+
self.tokenizer = tokenizer
50+
self.temperature = temperature
51+
self.max_tokens = max_tokens
52+
self.top_p = top_p if top_p is not None else 1.0
53+
self.top_k = top_k if top_k is not None else -1
54+
self.kwargs = kwargs
55+
56+
async def _make_llm_call(
57+
self,
58+
messages: List[Dict[str, Any]],
59+
tools: Optional[List] = None,
60+
) -> Dict[str, Any]:
61+
"""
62+
Make LLM call using TRL's VLLMClient or a colocated vLLM LLM.
63+
64+
Args:
65+
messages: List of message dicts with 'role' and 'content'
66+
tools: Not used (for compatibility)
67+
68+
Returns:
69+
OpenAI-compatible response dict
70+
"""
71+
# Apply chat template to convert messages to a prompt string
72+
if self.tokenizer is not None:
73+
try:
74+
# Use tokenizer's chat template
75+
prompt_text = self.tokenizer.apply_chat_template(
76+
messages,
77+
add_generation_prompt=True,
78+
tokenize=False,
79+
)
80+
logger.debug(
81+
"[VLLMPolicy] Chat template applied for %d messages (prompt length=%d)",
82+
len(messages),
83+
len(prompt_text),
84+
)
85+
except Exception as e:
86+
logger.warning(
87+
"[VLLMPolicy] Failed to apply chat template: %s",
88+
e,
89+
exc_info=True,
90+
)
91+
# Fallback: simple concatenation (defensive .get access)
92+
prompt_text = "\n".join(f"{m.get('role', '?')}: {m.get('content', '')}" for m in messages)
93+
else:
94+
# No tokenizer: simple concatenation
95+
prompt_text = "\n".join(f"{m.get('role', '?')}: {m.get('content', '')}" for m in messages)
96+
97+
# Check if vllm_client is VLLMClient (server mode) or LLM (colocate mode)
98+
is_llm_object = hasattr(self.vllm_client, "llm_engine") # LLM has llm_engine
99+
100+
if is_llm_object:
101+
# Colocate mode: use SamplingParams
102+
logger.debug("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams")
103+
from vllm import SamplingParams
104+
105+
sampling_params = SamplingParams(
106+
temperature=self.temperature,
107+
max_tokens=self.max_tokens,
108+
top_p=self.top_p,
109+
top_k=self.top_k,
110+
n=1,
111+
)
112+
113+
logger.debug("[VLLMPolicy] Calling LLM.generate()")
114+
outputs = self.vllm_client.generate([prompt_text], sampling_params=sampling_params, use_tqdm=False)
115+
116+
# Extract from vLLM output format
117+
output = outputs[0]
118+
prompt_ids = output.prompt_token_ids
119+
completion_ids = output.outputs[0].token_ids
120+
response = {
121+
"prompt_ids": [prompt_ids],
122+
"completion_ids": [completion_ids],
123+
}
124+
else:
125+
# Server mode: use VLLMClient with kwargs
126+
logger.debug("[VLLMPolicy] Using VLLMClient (server mode)")
127+
vllm_params = {
128+
"temperature": self.temperature,
129+
"max_tokens": self.max_tokens,
130+
"top_p": self.top_p,
131+
"top_k": self.top_k,
132+
"n": 1,
133+
}
134+
vllm_params.update(self.kwargs)
135+
136+
logger.debug("[VLLMPolicy] Calling vllm_client.generate()")
137+
response = self.vllm_client.generate(
138+
prompts=[prompt_text],
139+
**vllm_params,
140+
)
141+
142+
# Extract first result
143+
prompt_ids = response["prompt_ids"][0]
144+
completion_ids = response["completion_ids"][0]
145+
146+
# Decode completion text if tokenizer available
147+
if self.tokenizer is not None:
148+
try:
149+
completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
150+
logger.debug(
151+
"[VLLMPolicy] Generation result: prompt_tokens=%d, completion_tokens=%d, completion_chars=%d",
152+
len(prompt_ids),
153+
len(completion_ids),
154+
len(completion_text),
155+
)
156+
except Exception as e:
157+
logger.warning(
158+
"[VLLMPolicy] Failed to decode completion: %s",
159+
e,
160+
exc_info=True,
161+
)
162+
completion_text = f"<decoded_error:{len(completion_ids)}_tokens>"
163+
else:
164+
# Fallback: just indicate number of tokens
165+
completion_text = f"<{len(completion_ids)}_tokens>"
166+
167+
# Convert to OpenAI-compatible format for compatibility with OpenEnvRolloutProcessor
168+
# Also include raw token IDs for TRL integration (avoids double encoding)
169+
return {
170+
"choices": [
171+
{
172+
"message": {
173+
"content": completion_text,
174+
"role": "assistant",
175+
}
176+
}
177+
],
178+
"usage": {
179+
"prompt_tokens": len(prompt_ids),
180+
"completion_tokens": len(completion_ids),
181+
"total_tokens": len(prompt_ids) + len(completion_ids),
182+
},
183+
# Include raw token IDs for TRL (avoids re-encoding)
184+
"prompt_ids": prompt_ids,
185+
"completion_ids": completion_ids,
186+
}

eval_protocol/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,14 @@ class ExecutionMetadata(BaseModel):
776776
description="Processing duration in seconds for an entire experiment. Note that includes time it took for retries.",
777777
)
778778

779+
# Generic bag for integration-specific metadata.
780+
# Examples:
781+
# - OpenEnvRolloutProcessor: per-step rewards, token IDs for GRPO / TRL
782+
extra: Optional[Dict[str, Any]] = Field(
783+
default=None,
784+
description="Arbitrary execution metadata for integrations (step rewards, token IDs, debug info, etc.).",
785+
)
786+
779787
finish_reason: Optional[str] = Field(
780788
default=None,
781789
description="finish_reason reported by the completion response for this row.",

0 commit comments

Comments
 (0)