Skip to content

Commit 7e71e03

Browse files
author
Shrey Modi
committed
trl integration
1 parent ed93cb0 commit 7e71e03

File tree

2 files changed

+428
-71
lines changed

2 files changed

+428
-71
lines changed
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
"""
2+
Lightweight vLLM + OpenEnv Integration
3+
4+
Minimal integration to use TRL's vLLM server for inference with OpenEnv BrowserGym
5+
environments, wired into GRPO via a custom ``rollout_func``.
6+
7+
- Uses TRL's ``VLLMClient`` (``use_vllm=True, vllm_mode="server"``) for inference
8+
- Uses ``OpenEnvRolloutProcessor`` to drive OpenEnv (BrowserGym-style) environments
9+
- Supports task rotation across MiniWoB tasks
10+
- Returns Wordle-style GRPO data: 2D token lists and 1D per-episode rewards
11+
- No Fireworks, no hot reload, no additional providers
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import asyncio
17+
import sys
18+
from typing import Any, Callable, Dict, List, Optional, Type
19+
20+
from eval_protocol.models import EvaluationRow, InputMetadata, Message
21+
from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
22+
from eval_protocol.pytest.types import RolloutProcessorConfig
23+
24+
25+
def create_openenv_vllm_rollout_func(
    env_factory: Callable[[], Any] | None,
    prompt_builder: Callable[[Any, int, list[str]], Any],
    action_parser: Callable[[str], Any],
    vllm_base_url: str = "http://localhost:8000",
    vllm_model: str = "Qwen/Qwen2.5-7B",
    max_steps: int = 8,
    *,
    completion_params: Dict[str, Any] | None = None,
    concurrency: int | None = None,
    processor_cls: Optional[Type[Any]] = OpenEnvRolloutProcessor,
    processor_kwargs: Optional[Dict[str, Any]] = None,
    # Environment configuration
    env_client_cls: Optional[Type[Any]] = None,
    tasks: List[str] | None = None,
    miniwob_url: str | None = None,
    docker_image: str = "browsergym-env:latest",
    env_base_url: Optional[str] = None,
    request_timeout_s: float = 15.0,
    default_headers: Optional[Dict[str, str]] = None,
    provider: Any | None = None,
    docker_port: Optional[int] = None,
    env_vars: Optional[Dict[str, str]] = None,
    benchmark: str = "miniwob",
    headless: bool = True,
    viewport_width: int = 1280,
    viewport_height: int = 720,
    timeout_ms: int = 10000,
):
    """
    Build a TRL-compatible ``rollout_func`` using vLLM inference with OpenEnv.

    High-level:
    - ``GRPOTrainer`` calls the returned ``rollout_func(prompts, trainer)``
    - For each prompt, we create ``num_generations`` evaluation rows
    - ``OpenEnvRolloutProcessor`` runs BrowserGym-style episodes via Docker
    - ``VLLMPolicy`` formats messages with the chat template and calls TRL's
      vLLM server using ``trainer.vllm_client``
    - We accumulate tokens across all turns of an episode and sum rewards,
      returning Wordle-style GRPO data.

    The environment side is configured via ``env_client_cls`` and the BrowserGym
    parameters (``tasks``, ``miniwob_url``, ``docker_image``, etc.).  Every
    environment keyword can also be overridden per-call through
    ``processor_kwargs`` (values supplied there win over the factory arguments,
    because this function only ``setdefault``s its own values).

    Returns:
        A callable ``rollout_func(prompts, trainer) -> Dict[str, List]`` whose
        result maps ``prompt_ids`` / ``completion_ids`` / ``logprobs`` to
        per-episode token lists (2D) and ``step_rewards`` to one scalar total
        reward per episode (1D).
    """
    print(f"\n{'='*80}", flush=True)
    print(f"[openenv_trl_vllm] create_openenv_vllm_rollout_func() CALLED", flush=True)
    print(f" vllm_base_url: {vllm_base_url}", flush=True)
    print(f" vllm_model: {vllm_model}", flush=True)
    print(f" tasks: {tasks}", flush=True)
    print(f" max_steps: {max_steps}", flush=True)
    print(f"{'='*80}", flush=True)
    sys.stdout.flush()

    # Imported lazily so importing this module does not require the vLLM
    # policy extras to be importable.
    from eval_protocol.mcp.execution.vllm_policy import VLLMPolicy

    # Hoisted out of the per-message loop below (was re-imported on every
    # sentinel system message).
    import json

    # Global-ish task rotation offset across rollout_func calls.
    # This lets us rotate tasks between GRPO steps instead of always
    # starting from tasks[0] when a new OpenEnvRolloutProcessor is created.
    task_cycle_index: int = 0

    def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
        """Execute rollouts via OpenEnv + vLLM and return GRPO-compatible results."""
        nonlocal task_cycle_index

        print("\n[OpenEnvVLLM] rollout_func called", flush=True)

        # Extract training configuration from the live trainer object.
        args = trainer.args
        processing_class = trainer.processing_class

        num_generations = getattr(args, "num_generations", 8)
        print(
            f"[OpenEnvVLLM] Received {len(prompts)} prompts, "
            f"{num_generations} generations each",
            flush=True,
        )

        # 1) Build one evaluation row per (prompt, generation) pair.
        evaluation_rows: List[EvaluationRow] = [
            EvaluationRow(
                messages=[Message(role="user", content=prompt)],
                input_metadata=InputMetadata(
                    completion_params={},
                    extra={"generation_idx": gen_idx},
                ),
            )
            for prompt in prompts
            for gen_idx in range(num_generations)
        ]

        # 2) Build processor config with VLLMPolicy sampling parameters.
        #    ``model`` is a placeholder: inference goes through
        #    ``trainer.vllm_client``, not a provider looked up by name.
        base_params: Dict[str, Any] = {
            "model": "dummy",  # Not used by VLLMPolicy, but needed for config
            "temperature": getattr(args, "temperature", 1.0),
            "max_tokens": getattr(args, "max_completion_length", 100),
        }
        if completion_params:
            base_params.update(completion_params)

        print(
            f"[OpenEnvVLLM] Temperature={base_params['temperature']}, "
            f"max_tokens={base_params['max_tokens']}",
            flush=True,
        )
        print("[OpenEnvVLLM] Using TRL VLLMClient from trainer", flush=True)

        max_concurrency = concurrency if concurrency is not None else getattr(
            args, "per_device_train_batch_size", 1
        )
        print(
            f"[OpenEnvVLLM] Max concurrency={max_concurrency}, "
            f"max_steps={max_steps}",
            flush=True,
        )

        config = RolloutProcessorConfig(
            completion_params=base_params,
            mcp_config_path="",
            semaphore=asyncio.Semaphore(max_concurrency),
            steps=max_steps,
        )

        # 3) Execute rollouts with VLLMPolicy.
        print(
            f"[OpenEnvVLLM] Instantiating processor: "
            f"{processor_cls.__name__ if processor_cls else 'OpenEnvRolloutProcessor'}",
            flush=True,
        )

        def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs):
            """Factory that creates VLLMPolicy using trainer's vllm_client."""
            # Pop top_p/top_k out of kwargs before forwarding **kwargs —
            # passing them both explicitly AND via **kwargs raises
            # ``TypeError: got multiple values for keyword argument``.
            top_p = kwargs.pop("top_p", None)
            top_k = kwargs.pop("top_k", None)
            return VLLMPolicy(
                vllm_client=trainer.vllm_client,  # Use trainer's vLLM client!
                tokenizer=processing_class,  # Pass tokenizer for decoding
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                top_k=top_k,
                **kwargs,
            )

        Processor = processor_cls or OpenEnvRolloutProcessor
        _kwargs: Dict[str, Any] = dict(processor_kwargs or {})
        _kwargs.setdefault("env_factory", env_factory)
        _kwargs.setdefault("prompt_builder", prompt_builder)
        _kwargs.setdefault("action_parser", action_parser)
        _kwargs.setdefault("policy_factory", vllm_policy_factory)  # Pass VLLMPolicy factory!
        _kwargs.setdefault("env_client_cls", env_client_cls)

        # Rotate tasks across rollout_func calls so each GRPO step
        # primarily targets a different task, while keeping all
        # generations within a step on the same task.
        rotated_tasks = tasks
        if tasks:
            offset = task_cycle_index % len(tasks)
            rotated_tasks = tasks[offset:] + tasks[:offset]
            task_cycle_index = (task_cycle_index + 1) % len(tasks)
            print(
                f"[OpenEnvVLLM] Task rotation offset={offset}, rotated={rotated_tasks}",
                flush=True,
            )
        _kwargs.setdefault("tasks", rotated_tasks)

        _kwargs.setdefault("miniwob_url", miniwob_url)
        _kwargs.setdefault("docker_image", docker_image)
        _kwargs.setdefault("env_base_url", env_base_url)
        _kwargs.setdefault("request_timeout_s", request_timeout_s)
        _kwargs.setdefault("default_headers", default_headers)
        _kwargs.setdefault("provider", provider)
        _kwargs.setdefault("docker_port", docker_port)
        _kwargs.setdefault("env_vars", env_vars)
        _kwargs.setdefault("benchmark", benchmark)
        _kwargs.setdefault("headless", headless)
        _kwargs.setdefault("viewport_width", viewport_width)
        _kwargs.setdefault("viewport_height", viewport_height)
        _kwargs.setdefault("timeout_ms", timeout_ms)
        _kwargs.setdefault("num_generations", num_generations)

        processor = Processor(**_kwargs)
        print(f"[OpenEnvVLLM] Processor instantiated successfully", flush=True)

        # Run all episodes on a dedicated event loop.  After closing it, clear
        # the thread's current-loop slot so later asyncio users don't pick up
        # a closed loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            async def _run_all():
                tasks_list = processor(evaluation_rows, config)
                return await asyncio.gather(*tasks_list)

            completed_rows = loop.run_until_complete(_run_all())
            print(
                f"[OpenEnvVLLM] All rollouts completed: {len(completed_rows)} results",
                flush=True,
            )
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        # 4) Convert to Wordle-style format (no splitting)
        # Each completed_row is one rollout with multiple turns
        # We .extend() tokens across turns, then .append() per rollout
        print(
            f"[OpenEnvVLLM] Converting {len(completed_rows)} rollouts to TRL format",
            flush=True,
        )

        # processing_class may itself be the tokenizer, or wrap one.
        tokenizer = getattr(processing_class, "tokenizer", None) or processing_class
        encode_fn = getattr(tokenizer, "encode", None)

        episode_prompt_ids: List[List[int]] = []
        episode_completion_ids: List[List[int]] = []
        episode_logprobs: List[List[float]] = []
        step_rewards_all: List[List[float]] = []

        for row in completed_rows:
            # Accumulate tokens across all turns in this rollout.
            prompt_ids: List[int] = []
            completion_ids: List[int] = []
            logprobs: List[float] = []
            rewards: List[float] = []

            for msg in row.messages:
                if msg.role == "user":
                    tokens = encode_fn(msg.content or "") if encode_fn else []
                    prompt_ids.extend(tokens)  # Accumulate user tokens
                elif msg.role == "assistant":
                    tokens = encode_fn(msg.content or "") if encode_fn else []
                    completion_ids.extend(tokens)  # Accumulate assistant tokens
                    logprobs.extend([0.0] * len(tokens))  # Placeholder logprobs
                elif msg.role == "system":
                    # Step rewards arrive smuggled in a sentinel system message
                    # of the form "__ep_step_rewards__:<json list>".
                    try:
                        content = msg.content or ""
                        if isinstance(content, str) and content.startswith("__ep_step_rewards__:"):
                            payload = content.split(":", 1)[1]
                            rewards = json.loads(payload) or []
                    except Exception:
                        pass  # best-effort: fall back to execution metadata

            # Fallback: rewards recorded in execution metadata instead.
            if not rewards and hasattr(row.execution_metadata, "extra"):
                try:
                    rewards = row.execution_metadata.extra.get("step_rewards", []) or []
                except Exception:
                    pass

            # Append accumulated tokens for this episode; GRPO needs at least
            # one token / one reward per episode, hence the [0] placeholders.
            episode_prompt_ids.append(prompt_ids if prompt_ids else [0])
            episode_completion_ids.append(completion_ids if completion_ids else [0])
            episode_logprobs.append(logprobs if logprobs else [0.0])
            step_rewards_all.append(rewards if rewards else [0.0])

        # One scalar per episode: sum of its step rewards (computed once and
        # reused for the totals below and the returned 1D reward array).
        total_rewards = [sum(r) for r in step_rewards_all]
        total_reward = sum(total_rewards)
        avg_reward = total_reward / len(total_rewards) if total_rewards else 0.0
        print(
            f"[OpenEnvVLLM] Total reward={total_reward:.2f}, Avg reward={avg_reward:.2f}",
            flush=True,
        )
        print(
            f"[OpenEnvVLLM] Returning {len(episode_prompt_ids)} episodes", flush=True
        )
        sys.stdout.flush()

        print(f"[OpenEnvVLLM] Episode rewards: {total_rewards}", flush=True)

        # Return in Wordle format:
        # Tokens: 2D arrays (accumulated across turns, one list per episode)
        # Rewards: 1D array (one scalar per episode)
        return {
            "prompt_ids": episode_prompt_ids,  # List[List[int]] - tokens per episode
            "completion_ids": episode_completion_ids,  # List[List[int]] - tokens per episode
            "logprobs": episode_logprobs,  # List[List[float]] - logprobs per episode
            "step_rewards": total_rewards,  # List[float] - total reward per episode (1D!)
        }

    print(f"[openenv_trl_vllm] Returning rollout_func (type={type(rollout_func)})", flush=True)
    sys.stdout.flush()
    return rollout_func

0 commit comments

Comments
 (0)