Commit 9e5107f

Author: Shrey Modi (committed)
openenvrolloutprocessor

1 parent 15dd08d commit 9e5107f

11 files changed: +1247, -18 lines changed
Lines changed: 351 additions & 0 deletions
@@ -0,0 +1,351 @@
"""
Lightweight vLLM + OpenEnv Integration

Simplified integration using vLLM for inference with proper multi-turn completion splitting.
No Fireworks inference, no hot reload - just vLLM.
"""

from __future__ import annotations

import asyncio
from typing import Any, Callable, Dict, List, Optional, Type

from eval_protocol.models import EvaluationRow, InputMetadata, Message
from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from eval_protocol.utils.evaluation_row_utils import (
    filter_longest_conversation,
    multi_turn_assistant_to_ground_truth,
    assistant_to_ground_truth,
)
from trl import GRPOConfig


def create_openenv_vllm_rollout_func(
    env_factory: Callable[[], Any] | None,
    prompt_builder: Callable[[Any, int, list[str]], Any],
    action_parser: Callable[[str], Any],
    vllm_base_url: str = "http://localhost:8000",
    max_steps: int = 8,
    *,
    split_mode: str = "multi_turn",  # "multi_turn", "last_turn", "longest", or None
    completion_params: Dict[str, Any] | None = None,
    concurrency: int | None = None,
    processor_cls: Optional[Type[Any]] = OpenEnvRolloutProcessor,
    processor_kwargs: Optional[Dict[str, Any]] = None,
    # Environment configuration
    env_client_cls: Optional[Type[Any]] = None,
    tasks: List[str] | None = None,
    miniwob_url: str | None = None,
    docker_image: str = "browsergym-env:latest",
    env_base_url: Optional[str] = None,
    request_timeout_s: float = 15.0,
    default_headers: Optional[Dict[str, str]] = None,
    provider: Any | None = None,
    docker_port: Optional[int] = None,
    env_vars: Optional[Dict[str, str]] = None,
    benchmark: str = "miniwob",
    headless: bool = True,
    viewport_width: int = 1280,
    viewport_height: int = 720,
    timeout_ms: int = 10000,
):
    """
    Build a TRL-compatible rollout_func using vLLM inference with OpenEnv.

    This is a lightweight version that:
    - Uses the vLLM client directly (no Fireworks, no hot reload)
    - Properly splits completions using evaluation_row_utils helpers
    - Works with TRL's GRPO trainer

    Args:
        env_factory: Callable yielding an OpenEnv HTTPEnvClient instance
        prompt_builder: (observation, step, history) -> content for the LLM
        action_parser: (llm_response: str) -> env action object
        vllm_base_url: Base URL for the vLLM server (e.g., "http://localhost:8000")
        max_steps: Maximum environment steps per rollout
        split_mode: How to split completions:
            - "multi_turn": Split each assistant message into a separate row (multi_turn_assistant_to_ground_truth)
            - "last_turn": Extract the last assistant message as ground truth (assistant_to_ground_truth)
            - "longest": Keep only the longest conversation (filter_longest_conversation)
            - None: No splitting; return all rows as-is
        completion_params: Extra completion parameters (temperature, max_tokens, etc.)
        concurrency: Max concurrent rollouts (defaults to per_device_train_batch_size)
        processor_cls: Rollout processor class (default: OpenEnvRolloutProcessor)
        processor_kwargs: Extra kwargs for the processor
        env_client_cls: Environment client class
        tasks: List of task names to rotate through
        miniwob_url: MiniWoB base URL
        docker_image: Docker image for environments
        env_base_url: Direct HTTP connection to an existing server
        request_timeout_s: HTTP timeout
        default_headers: HTTP headers
        provider: Docker provider
        docker_port: Host port binding
        env_vars: Environment variables for the container
        benchmark: BrowserGym benchmark name
        headless: Headless browser mode
        viewport_width/height: Browser viewport size
        timeout_ms: Action timeout

    Returns:
        rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List]

    Example:
        ```python
        from trl import GRPOConfig, GRPOTrainer
        from trl.extras.vllm_client import VLLMClient
        from envs.browsergym_env import BrowserGymEnv, BrowserGymAction

        # Start the vLLM server first:
        # CUDA_VISIBLE_DEVICES=0,1 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2

        def make_env():
            return BrowserGymEnv.from_docker_image(
                "browsergym-env:latest",
                env_vars={"BROWSERGYM_BENCHMARK": "miniwob"}
            )

        def build_prompt(obs, step, history):
            return f"Step {step}\\nGoal: {obs.goal}\\n{obs.text[:500]}"

        def parse_action(text):
            return BrowserGymAction(action_str=text)

        rollout_func = create_openenv_vllm_rollout_func(
            env_factory=make_env,
            prompt_builder=build_prompt,
            action_parser=parse_action,
            vllm_base_url="http://localhost:8000",
            tasks=["click-test", "click-button", "enter-text"],
            split_mode="multi_turn",  # Split each turn for training
        )

        training_args = GRPOConfig(
            output_dir="outputs/vllm-training",
            per_device_train_batch_size=2,
            num_generations=4,
        )

        trainer = GRPOTrainer(
            model="Qwen/Qwen2.5-7B",
            args=training_args,
            train_dataset=dataset,
            rollout_func=rollout_func,
        )
        ```
    """

    # Import the vLLM client (used for generation)
    try:
        from trl.extras.vllm_client import VLLMClient
    except ImportError:
        raise ImportError(
            "vLLM client not available. Install with: pip install trl[vllm]"
        )

    # Initialize the vLLM client
    vllm_client = VLLMClient(base_url=vllm_base_url)

    def rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List]:
        """
        Execute rollouts and return TRL-compatible results.

        Flow:
        1. Prompts → EvaluationRows (num_generations per prompt)
        2. Execute rollouts via OpenEnvRolloutProcessor
        3. Split completions using evaluation_row_utils
        4. Generate completions via vLLM for each split row
        5. Convert to TRL format
        """
        num_generations = getattr(args, "num_generations", 8)

        # 1) Build evaluation rows (one per generation per prompt)
        evaluation_rows: List[EvaluationRow] = []
        for prompt in prompts:
            for gen_idx in range(num_generations):
                evaluation_rows.append(
                    EvaluationRow(
                        messages=[Message(role="user", content=prompt)],
                        input_metadata=InputMetadata(
                            completion_params={},
                            extra={"generation_idx": gen_idx}
                        ),
                    )
                )

        # 2) Build the processor config
        base_params: Dict[str, Any] = {
            "temperature": getattr(args, "temperature", 1.0),
            "max_tokens": getattr(args, "max_completion_length", 100),
        }
        if completion_params:
            base_params.update(completion_params)

        max_concurrency = concurrency if concurrency is not None else getattr(
            args, "per_device_train_batch_size", 1
        )

        config = RolloutProcessorConfig(
            completion_params=base_params,
            mcp_config_path="",
            semaphore=asyncio.Semaphore(max_concurrency),
            steps=max_steps,
        )

        # 3) Execute rollouts using OpenEnvRolloutProcessor
        Processor = processor_cls or OpenEnvRolloutProcessor
        _kwargs: Dict[str, Any] = dict(processor_kwargs or {})
        _kwargs.setdefault("env_factory", env_factory)
        _kwargs.setdefault("prompt_builder", prompt_builder)
        _kwargs.setdefault("action_parser", action_parser)
        _kwargs.setdefault("env_client_cls", env_client_cls)
        _kwargs.setdefault("tasks", tasks)
        _kwargs.setdefault("miniwob_url", miniwob_url)
        _kwargs.setdefault("docker_image", docker_image)
        _kwargs.setdefault("env_base_url", env_base_url)
        _kwargs.setdefault("request_timeout_s", request_timeout_s)
        _kwargs.setdefault("default_headers", default_headers)
        _kwargs.setdefault("provider", provider)
        _kwargs.setdefault("docker_port", docker_port)
        _kwargs.setdefault("env_vars", env_vars)
        _kwargs.setdefault("benchmark", benchmark)
        _kwargs.setdefault("headless", headless)
        _kwargs.setdefault("viewport_width", viewport_width)
        _kwargs.setdefault("viewport_height", viewport_height)
        _kwargs.setdefault("timeout_ms", timeout_ms)
        _kwargs.setdefault("num_generations", num_generations)

        processor = Processor(**_kwargs)

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            async def _run_all():
                tasks = processor(evaluation_rows, config)
                return await asyncio.gather(*tasks)

            completed_rows = loop.run_until_complete(_run_all())
        finally:
            loop.close()

        # 4) Split completions based on split_mode
        if split_mode == "multi_turn":
            # Split each assistant message into separate rows
            split_rows = multi_turn_assistant_to_ground_truth(completed_rows)
        elif split_mode == "last_turn":
            # Extract the last assistant message as ground truth
            split_rows = assistant_to_ground_truth(completed_rows)
        elif split_mode == "longest":
            # Keep only the longest conversation per rollout_id
            split_rows = filter_longest_conversation(completed_rows)
        elif split_mode is None:
            # No splitting
            split_rows = completed_rows
        else:
            raise ValueError(
                f"Invalid split_mode: {split_mode}. "
                "Must be 'multi_turn', 'last_turn', 'longest', or None"
            )

        print(f"[OpenEnvVLLM] Split {len(completed_rows)} rows → {len(split_rows)} rows (mode={split_mode})")

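        # Illustration (an assumption based on the helper names and the docstring,
        # not on their implementations): under "multi_turn", a rollout with messages
        # [user, assistant_1, user, assistant_2] is split into one row per assistant
        # turn, each carrying that turn as its ground truth, so every intermediate
        # action becomes its own training example.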
        # 5) Generate completions via vLLM for each split row
        # Build messages for the vLLM chat endpoint
        all_messages: List[List[Dict]] = []
        for row in split_rows:
            messages = [{"role": msg.role, "content": msg.content} for msg in row.messages]
            all_messages.append(messages)

        # Call vLLM to generate completions
        # Check whether we have conversational format
        is_conversational = all_messages and isinstance(all_messages[0], list)

        vllm_params = {
            "n": 1,  # One completion per split row
            "temperature": base_params["temperature"],
            "max_tokens": base_params["max_tokens"],
        }

        # Add any extra vLLM parameters from completion_params
        if completion_params:
            for key in ["top_p", "top_k", "min_p", "repetition_penalty"]:
                if key in completion_params:
                    vllm_params[key] = completion_params[key]

        if is_conversational:
            print(f"[OpenEnvVLLM] Calling vLLM chat endpoint with {len(all_messages)} conversations")
            vllm_response = vllm_client.chat(
                messages=all_messages,
                **vllm_params,
            )
        else:
            # Convert messages to prompts for the generate endpoint
            prompts_for_vllm = []
            for msgs in all_messages:
                # Simple concatenation (you may want to use a chat template here)
                prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in msgs)
                prompts_for_vllm.append(prompt_text)

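            # Sketch of an alternative to the plain concatenation above: if
            # `processing_class` is a Hugging Face tokenizer with a chat template
            # (an assumption, since it is unused here), one could render instead with:
            #   prompt_text = processing_class.apply_chat_template(
            #       msgs, tokenize=False, add_generation_prompt=True
            #   )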
print(f"[OpenEnvVLLM] Calling vLLM generate endpoint with {len(prompts_for_vllm)} prompts")
291+
vllm_response = vllm_client.generate(
292+
prompts=prompts_for_vllm,
293+
**vllm_params,
294+
)
295+
296+
# 6) Convert to TRL format
297+
prompt_ids = vllm_response["prompt_ids"]
298+
completion_ids = vllm_response["completion_ids"]
299+
logprobs = vllm_response["logprobs"]
300+
301+
# Extract step rewards from completed rows
302+
step_rewards: List[List[float]] = []
303+
for row in split_rows:
304+
rewards: List[float] = []
305+
306+
# Look for rewards in system messages (sentinel pattern)
307+
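            # Expected sentinel shape, inferred from the parsing below (the producer
            # side lives in the rollout processor and is an assumption here):
            #   Message(role="system", content="__ep_step_rewards__:[0.1, 0.0, 1.0]")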
            for msg in row.messages:
                if msg.role == "system":
                    try:
                        content = msg.content or ""
                        if isinstance(content, str) and content.startswith("__ep_step_rewards__:"):
                            import json
                            payload = content.split(":", 1)[1]
                            rewards = json.loads(payload) or []
                            break
                    except Exception:
                        pass

            # Fall back to execution metadata
            if not rewards and hasattr(row.execution_metadata, "extra"):
                try:
                    rewards = row.execution_metadata.extra.get("step_rewards", []) or []
                except Exception:
                    pass

            step_rewards.append(rewards if rewards else [0.0])

        # Compute statistics
        total_reward = sum(sum(r) for r in step_rewards)
        avg_reward = total_reward / len(step_rewards) if step_rewards else 0.0
        print(f"[OpenEnvVLLM] Total reward: {total_reward:.2f}, Avg: {avg_reward:.2f}")

        # TRL expects prompt_ids at the unique-prompt level (not per-generation)
        # Deduplicate while preserving order
        seen_prompts = set()
        prompt_ids_unique = []
        for p_ids in prompt_ids:
            p_tuple = tuple(p_ids)
            if p_tuple not in seen_prompts:
                seen_prompts.add(p_tuple)
                prompt_ids_unique.append(p_ids)

        return {
            "prompt_ids": prompt_ids_unique,
            "completion_ids": completion_ids,
            "logprobs": logprobs,
            "step_rewards": step_rewards,
        }

    return rollout_func
