|
| 1 | +""" |
| 2 | +Lightweight vLLM + OpenEnv Integration |
| 3 | +
|
| 4 | +Minimal integration to use TRL's vLLM server for inference with OpenEnv BrowserGym |
| 5 | +environments, wired into GRPO via a custom ``rollout_func``. |
| 6 | +
|
| 7 | +- Uses TRL's ``VLLMClient`` (``use_vllm=True, vllm_mode="server"``) for inference |
| 8 | +- Uses ``OpenEnvRolloutProcessor`` to drive OpenEnv (BrowserGym-style) environments |
| 9 | +- Supports task rotation across MiniWoB tasks |
| 10 | +- Returns Wordle-style GRPO data: 2D token lists and 1D per-episode rewards |
| 11 | +- No Fireworks, no hot reload, no additional providers |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +import asyncio |
| 17 | +import sys |
| 18 | +from typing import Any, Callable, Dict, List, Optional, Type |
| 19 | + |
| 20 | +from eval_protocol.models import EvaluationRow, InputMetadata, Message |
| 21 | +from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor |
| 22 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
| 23 | + |
| 24 | + |
def create_openenv_vllm_rollout_func(
    env_factory: Callable[[], Any] | None,
    prompt_builder: Callable[[Any, int, list[str]], Any],
    action_parser: Callable[[str], Any],
    vllm_base_url: str = "http://localhost:8000",
    vllm_model: str = "Qwen/Qwen2.5-7B",
    max_steps: int = 8,
    *,
    completion_params: Dict[str, Any] | None = None,
    concurrency: int | None = None,
    processor_cls: Optional[Type[Any]] = OpenEnvRolloutProcessor,
    processor_kwargs: Optional[Dict[str, Any]] = None,
    # Environment configuration
    env_client_cls: Optional[Type[Any]] = None,
    tasks: List[str] | None = None,
    miniwob_url: str | None = None,
    docker_image: str = "browsergym-env:latest",
    env_base_url: Optional[str] = None,
    request_timeout_s: float = 15.0,
    default_headers: Optional[Dict[str, str]] = None,
    provider: Any | None = None,
    docker_port: Optional[int] = None,
    env_vars: Optional[Dict[str, str]] = None,
    benchmark: str = "miniwob",
    headless: bool = True,
    viewport_width: int = 1280,
    viewport_height: int = 720,
    timeout_ms: int = 10000,
):
    """
    Build a TRL-compatible ``rollout_func`` using vLLM inference with OpenEnv.

    High-level:
    - ``GRPOTrainer`` calls the returned ``rollout_func(prompts, trainer)``
    - For each prompt, we create ``num_generations`` evaluation rows
    - ``OpenEnvRolloutProcessor`` runs BrowserGym-style episodes via Docker
    - ``VLLMPolicy`` formats messages with the chat template and calls TRL's
      vLLM server using ``trainer.vllm_client``
    - We accumulate tokens across all turns of an episode and sum rewards,
      returning Wordle-style GRPO data.

    The environment side is configured via ``env_client_cls`` and the BrowserGym
    parameters (``tasks``, ``miniwob_url``, ``docker_image``, etc.).

    Returns:
        ``rollout_func(prompts, trainer) -> dict`` with keys ``prompt_ids``,
        ``completion_ids``, ``logprobs`` (2D lists, one inner list per episode)
        and ``step_rewards`` (1D: one scalar total reward per episode).
    """
    # Hoisted out of the per-message loop below: ``json`` decodes the step-reward
    # sentinel message; importing it once here avoids a repeated import statement
    # inside the hottest loop of the conversion step.
    import json

    print(f"\n{'='*80}", flush=True)
    print("[openenv_trl_vllm] create_openenv_vllm_rollout_func() CALLED", flush=True)
    print(f"  vllm_base_url: {vllm_base_url}", flush=True)
    print(f"  vllm_model: {vllm_model}", flush=True)
    print(f"  tasks: {tasks}", flush=True)
    print(f"  max_steps: {max_steps}", flush=True)
    print(f"{'='*80}", flush=True)
    sys.stdout.flush()

    # Function-scope import keeps eval_protocol's vLLM policy optional at module
    # import time (matches the file's existing lazy-import pattern).
    from eval_protocol.mcp.execution.vllm_policy import VLLMPolicy

    # Task rotation offset shared across rollout_func calls via closure.
    # This lets us rotate tasks between GRPO steps instead of always
    # starting from tasks[0] when a new OpenEnvRolloutProcessor is created.
    task_cycle_index: int = 0

    def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
        """Execute rollouts via OpenEnv + vLLM and return GRPO-compatible results."""
        nonlocal task_cycle_index

        print("\n[OpenEnvVLLM] rollout_func called", flush=True)

        # Pull generation settings and the tokenizer off the trainer.
        args = trainer.args
        processing_class = trainer.processing_class

        num_generations = getattr(args, "num_generations", 8)
        print(
            f"[OpenEnvVLLM] Received {len(prompts)} prompts, "
            f"{num_generations} generations each",
            flush=True,
        )

        # 1) Build evaluation rows: one per (prompt, generation) pair.
        evaluation_rows: List[EvaluationRow] = []
        for prompt in prompts:
            for gen_idx in range(num_generations):
                evaluation_rows.append(
                    EvaluationRow(
                        messages=[Message(role="user", content=prompt)],
                        input_metadata=InputMetadata(
                            completion_params={},
                            extra={"generation_idx": gen_idx},
                        ),
                    )
                )

        # 2) Build processor config with VLLMPolicy.
        # Caller-supplied completion_params override the trainer's defaults.
        base_params: Dict[str, Any] = {
            "model": "dummy",  # Not used by VLLMPolicy, but needed for config
            "temperature": getattr(args, "temperature", 1.0),
            "max_tokens": getattr(args, "max_completion_length", 100),
        }
        if completion_params:
            base_params.update(completion_params)

        print(
            f"[OpenEnvVLLM] Temperature={base_params['temperature']}, "
            f"max_tokens={base_params['max_tokens']}",
            flush=True,
        )
        print("[OpenEnvVLLM] Using TRL VLLMClient from trainer", flush=True)

        # Default concurrency to the per-device batch size when not given.
        max_concurrency = concurrency if concurrency is not None else getattr(
            args, "per_device_train_batch_size", 1
        )
        print(
            f"[OpenEnvVLLM] Max concurrency={max_concurrency}, "
            f"max_steps={max_steps}",
            flush=True,
        )

        config = RolloutProcessorConfig(
            completion_params=base_params,
            mcp_config_path="",
            semaphore=asyncio.Semaphore(max_concurrency),
            steps=max_steps,
        )

        # 3) Execute rollouts with VLLMPolicy.
        print(
            f"[OpenEnvVLLM] Instantiating processor: "
            f"{processor_cls.__name__ if processor_cls else 'OpenEnvRolloutProcessor'}",
            flush=True,
        )

        def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs):
            """Factory that creates VLLMPolicy using trainer's vllm_client."""
            # ``model``/``base_url`` are accepted for signature compatibility but
            # unused: inference always goes through the trainer's client.
            return VLLMPolicy(
                vllm_client=trainer.vllm_client,  # Use trainer's vLLM client!
                tokenizer=processing_class,  # Pass tokenizer for decoding
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=kwargs.get("top_p"),
                top_k=kwargs.get("top_k"),
                **kwargs,
            )

        # Assemble processor kwargs; explicit processor_kwargs win over the
        # factory arguments via setdefault.
        Processor = processor_cls or OpenEnvRolloutProcessor
        _kwargs: Dict[str, Any] = dict(processor_kwargs or {})
        _kwargs.setdefault("env_factory", env_factory)
        _kwargs.setdefault("prompt_builder", prompt_builder)
        _kwargs.setdefault("action_parser", action_parser)
        _kwargs.setdefault("policy_factory", vllm_policy_factory)  # Pass VLLMPolicy factory!
        _kwargs.setdefault("env_client_cls", env_client_cls)

        # Rotate tasks across rollout_func calls so each GRPO step
        # primarily targets a different task, while keeping all
        # generations within a step on the same task.
        rotated_tasks = tasks
        if tasks:
            offset = task_cycle_index % len(tasks)
            rotated_tasks = tasks[offset:] + tasks[:offset]
            task_cycle_index = (task_cycle_index + 1) % len(tasks)
            print(
                f"[OpenEnvVLLM] Task rotation offset={offset}, rotated={rotated_tasks}",
                flush=True,
            )
        _kwargs.setdefault("tasks", rotated_tasks)

        _kwargs.setdefault("miniwob_url", miniwob_url)
        _kwargs.setdefault("docker_image", docker_image)
        _kwargs.setdefault("env_base_url", env_base_url)
        _kwargs.setdefault("request_timeout_s", request_timeout_s)
        _kwargs.setdefault("default_headers", default_headers)
        _kwargs.setdefault("provider", provider)
        _kwargs.setdefault("docker_port", docker_port)
        _kwargs.setdefault("env_vars", env_vars)
        _kwargs.setdefault("benchmark", benchmark)
        _kwargs.setdefault("headless", headless)
        _kwargs.setdefault("viewport_width", viewport_width)
        _kwargs.setdefault("viewport_height", viewport_height)
        _kwargs.setdefault("timeout_ms", timeout_ms)
        _kwargs.setdefault("num_generations", num_generations)

        processor = Processor(**_kwargs)
        print("[OpenEnvVLLM] Processor instantiated successfully", flush=True)

        # GRPOTrainer calls rollout_func synchronously, so drive the async
        # rollouts on a private event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            async def _run_all():
                pending = processor(evaluation_rows, config)
                return await asyncio.gather(*pending)

            completed_rows = loop.run_until_complete(_run_all())
            print(
                f"[OpenEnvVLLM] All rollouts completed: {len(completed_rows)} results",
                flush=True,
            )
        finally:
            loop.close()
            # Don't leave a *closed* loop installed as this thread's current
            # event loop — that would break any later asyncio use in the thread.
            asyncio.set_event_loop(None)

        # 4) Convert to Wordle-style format (no splitting).
        # Each completed_row is one rollout with multiple turns; tokens are
        # .extend()ed across turns, then .append()ed once per rollout.
        print(
            f"[OpenEnvVLLM] Converting {len(completed_rows)} rollouts to TRL format",
            flush=True,
        )

        # Some processing classes wrap the tokenizer; fall back to the class
        # itself when there is no .tokenizer attribute.
        tokenizer = getattr(processing_class, "tokenizer", None) or processing_class
        encode_fn = getattr(tokenizer, "encode", None)

        def _extract_episode(row):
            """Accumulate tokens across all turns of one rollout.

            Returns (prompt_ids, completion_ids, logprobs, step_rewards); any of
            them may be empty when the rollout produced no matching messages.
            """
            prompt_ids: List[int] = []
            completion_ids: List[int] = []
            logprobs: List[float] = []
            rewards: List[float] = []

            for msg in row.messages:
                if msg.role == "user":
                    prompt_ids.extend(encode_fn(msg.content or "") if encode_fn else [])
                elif msg.role == "assistant":
                    tokens = encode_fn(msg.content or "") if encode_fn else []
                    completion_ids.extend(tokens)
                    # Placeholder logprobs — real per-token logprobs are not
                    # available from the decoded transcript.
                    logprobs.extend([0.0] * len(tokens))
                elif msg.role == "system":
                    # Step rewards are smuggled through a sentinel system message.
                    content = msg.content or ""
                    if isinstance(content, str) and content.startswith("__ep_step_rewards__:"):
                        try:
                            rewards = json.loads(content.split(":", 1)[1]) or []
                        except Exception:
                            # Malformed payload: fall through to the metadata fallback.
                            pass

            # Fallback: some processors stash step rewards in execution metadata.
            if not rewards and hasattr(row.execution_metadata, "extra"):
                try:
                    rewards = row.execution_metadata.extra.get("step_rewards", []) or []
                except Exception:
                    pass

            return prompt_ids, completion_ids, logprobs, rewards

        episode_prompt_ids: List[List[int]] = []
        episode_completion_ids: List[List[int]] = []
        episode_logprobs: List[List[float]] = []
        step_rewards_all: List[List[float]] = []

        for row in completed_rows:
            prompt_ids, completion_ids, logprobs, rewards = _extract_episode(row)
            # Never emit empty lists — TRL expects at least one entry per episode.
            episode_prompt_ids.append(prompt_ids if prompt_ids else [0])
            episode_completion_ids.append(completion_ids if completion_ids else [0])
            episode_logprobs.append(logprobs if logprobs else [0.0])
            step_rewards_all.append(rewards if rewards else [0.0])

        total_reward = sum(sum(r) for r in step_rewards_all)
        avg_reward = total_reward / len(step_rewards_all) if step_rewards_all else 0.0
        print(
            f"[OpenEnvVLLM] Total reward={total_reward:.2f}, Avg reward={avg_reward:.2f}",
            flush=True,
        )
        print(
            f"[OpenEnvVLLM] Returning {len(episode_prompt_ids)} episodes", flush=True
        )
        sys.stdout.flush()

        # Return in Wordle format:
        # Tokens: 2D arrays (accumulated across turns, one list per episode)
        # Rewards: 1D array (one scalar per episode — step rewards summed)
        total_rewards = [sum(r) for r in step_rewards_all]

        print(f"[OpenEnvVLLM] Episode rewards: {total_rewards}", flush=True)

        return {
            "prompt_ids": episode_prompt_ids,  # List[List[int]] - tokens per episode
            "completion_ids": episode_completion_ids,  # List[List[int]] - tokens per episode
            "logprobs": episode_logprobs,  # List[List[float]] - logprobs per episode
            "step_rewards": total_rewards,  # List[float] - total reward per episode (1D!)
        }

    print(f"[openenv_trl_vllm] Returning rollout_func (type={type(rollout_func)})", flush=True)
    sys.stdout.flush()
    return rollout_func
| 308 | + |
0 commit comments