"""
Lightweight vLLM + OpenEnv Integration

Simplified integration using vLLM for inference with proper multi-turn completion splitting.
No Fireworks inference, no hot reload - just vLLM.
"""

from __future__ import annotations

import asyncio
from typing import Any, Callable, Dict, List, Optional, Type

from eval_protocol.models import EvaluationRow, InputMetadata, Message
from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
from eval_protocol.utils.evaluation_row_utils import (
    filter_longest_conversation,
    multi_turn_assistant_to_ground_truth,
    assistant_to_ground_truth,
)
from trl import GRPOConfig


def create_openenv_vllm_rollout_func(
    env_factory: Callable[[], Any] | None,
    prompt_builder: Callable[[Any, int, list[str]], Any],
    action_parser: Callable[[str], Any],
    vllm_base_url: str = "http://localhost:8000",
    max_steps: int = 8,
    *,
    split_mode: str | None = "multi_turn",  # "multi_turn", "last_turn", "longest", or None
    completion_params: Dict[str, Any] | None = None,
    concurrency: int | None = None,
    processor_cls: Optional[Type[Any]] = OpenEnvRolloutProcessor,
    processor_kwargs: Optional[Dict[str, Any]] = None,
    # Environment configuration
    env_client_cls: Optional[Type[Any]] = None,
    tasks: List[str] | None = None,
    miniwob_url: str | None = None,
    docker_image: str = "browsergym-env:latest",
    env_base_url: Optional[str] = None,
    request_timeout_s: float = 15.0,
    default_headers: Optional[Dict[str, str]] = None,
    provider: Any | None = None,
    docker_port: Optional[int] = None,
    env_vars: Optional[Dict[str, str]] = None,
    benchmark: str = "miniwob",
    headless: bool = True,
    viewport_width: int = 1280,
    viewport_height: int = 720,
    timeout_ms: int = 10000,
):
    """
    Build a TRL-compatible rollout_func using vLLM inference with OpenEnv.

    This is a lightweight version that:
    - Uses the vLLM client directly (no Fireworks, no hot reload)
    - Properly splits completions using evaluation_row_utils helpers
    - Works with TRL's GRPO trainer

    Args:
        env_factory: Callable yielding an OpenEnv HTTPEnvClient instance
        prompt_builder: (observation, step, history) -> content for the LLM
        action_parser: (llm_response: str) -> env action object
        vllm_base_url: Base URL of the vLLM server (e.g., "http://localhost:8000")
        max_steps: Maximum environment steps per rollout
        split_mode: How to split completions:
            - "multi_turn": split each assistant turn into a separate row (multi_turn_assistant_to_ground_truth)
            - "last_turn": extract the last assistant message as ground truth (assistant_to_ground_truth)
            - "longest": keep only the longest conversation (filter_longest_conversation)
            - None: no splitting, return all rows as-is
        completion_params: Extra completion parameters (temperature, max_tokens, etc.)
        concurrency: Max concurrent rollouts (defaults to per_device_train_batch_size)
        processor_cls: Rollout processor class (default: OpenEnvRolloutProcessor)
        processor_kwargs: Extra kwargs for the processor
        env_client_cls: Environment client class
        tasks: List of task names to rotate through
        miniwob_url: MiniWoB base URL
        docker_image: Docker image for environments
        env_base_url: Direct HTTP connection to an existing server
        request_timeout_s: HTTP timeout in seconds
        default_headers: HTTP headers
        provider: Docker provider
        docker_port: Host port binding
        env_vars: Environment variables for the container
        benchmark: BrowserGym benchmark name
        headless: Headless browser mode
        viewport_width: Browser viewport width in pixels
        viewport_height: Browser viewport height in pixels
        timeout_ms: Action timeout in milliseconds

    Returns:
        rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List]

    Example:
    ```python
    from trl import GRPOConfig, GRPOTrainer
    from trl.extras.vllm_client import VLLMClient
    from envs.browsergym_env import BrowserGymEnv, BrowserGymAction

    # Start vLLM server first:
    # CUDA_VISIBLE_DEVICES=0,1 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2

    def make_env():
        return BrowserGymEnv.from_docker_image(
            "browsergym-env:latest",
            env_vars={"BROWSERGYM_BENCHMARK": "miniwob"}
        )

    def build_prompt(obs, step, history):
        return f"Step {step}\\nGoal: {obs.goal}\\n{obs.text[:500]}"

    def parse_action(text):
        return BrowserGymAction(action_str=text)

    rollout_func = create_openenv_vllm_rollout_func(
        env_factory=make_env,
        prompt_builder=build_prompt,
        action_parser=parse_action,
        vllm_base_url="http://localhost:8000",
        tasks=["click-test", "click-button", "enter-text"],
        split_mode="multi_turn",  # Split each turn for training
    )

    training_args = GRPOConfig(
        output_dir="outputs/vllm-training",
        per_device_train_batch_size=2,
        num_generations=4,
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-7B",
        args=training_args,
        train_dataset=dataset,
        rollout_func=rollout_func,
    )
    ```
    """

    # Import the vLLM client (used for generation)
    try:
        from trl.extras.vllm_client import VLLMClient
    except ImportError as exc:
        raise ImportError(
            "vLLM client not available. Install with: pip install trl[vllm]"
        ) from exc

    # Initialize the vLLM client
    vllm_client = VLLMClient(base_url=vllm_base_url)

    def rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List]:
        """
        Execute rollouts and return TRL-compatible results.

        Flow:
        1. Prompts → EvaluationRows (num_generations per prompt)
        2. Execute rollouts via OpenEnvRolloutProcessor
        3. Split completions using evaluation_row_utils helpers
        4. Generate completions via vLLM for each split row
        5. Convert to TRL format
        """
        num_generations = getattr(args, "num_generations", 8)

        # 1) Build evaluation rows (one per generation per prompt)
        evaluation_rows: List[EvaluationRow] = []
        for prompt in prompts:
            for gen_idx in range(num_generations):
                evaluation_rows.append(
                    EvaluationRow(
                        messages=[Message(role="user", content=prompt)],
                        input_metadata=InputMetadata(
                            completion_params={},
                            extra={"generation_idx": gen_idx}
                        ),
                    )
                )

        # 2) Build processor config
        base_params: Dict[str, Any] = {
            "temperature": getattr(args, "temperature", 1.0),
            "max_tokens": getattr(args, "max_completion_length", 100),
        }
        if completion_params:
            base_params.update(completion_params)

        max_concurrency = concurrency if concurrency is not None else getattr(
            args, "per_device_train_batch_size", 1
        )

        config = RolloutProcessorConfig(
            completion_params=base_params,
            mcp_config_path="",
            semaphore=asyncio.Semaphore(max_concurrency),
            steps=max_steps,
        )

        # 3) Execute rollouts using OpenEnvRolloutProcessor
        Processor = processor_cls or OpenEnvRolloutProcessor
        _kwargs: Dict[str, Any] = dict(processor_kwargs or {})
        _kwargs.setdefault("env_factory", env_factory)
        _kwargs.setdefault("prompt_builder", prompt_builder)
        _kwargs.setdefault("action_parser", action_parser)
        _kwargs.setdefault("env_client_cls", env_client_cls)
        _kwargs.setdefault("tasks", tasks)
        _kwargs.setdefault("miniwob_url", miniwob_url)
        _kwargs.setdefault("docker_image", docker_image)
        _kwargs.setdefault("env_base_url", env_base_url)
        _kwargs.setdefault("request_timeout_s", request_timeout_s)
        _kwargs.setdefault("default_headers", default_headers)
        _kwargs.setdefault("provider", provider)
        _kwargs.setdefault("docker_port", docker_port)
        _kwargs.setdefault("env_vars", env_vars)
        _kwargs.setdefault("benchmark", benchmark)
        _kwargs.setdefault("headless", headless)
        _kwargs.setdefault("viewport_width", viewport_width)
        _kwargs.setdefault("viewport_height", viewport_height)
        _kwargs.setdefault("timeout_ms", timeout_ms)
        _kwargs.setdefault("num_generations", num_generations)

        processor = Processor(**_kwargs)

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            async def _run_all():
                # Named rollout_tasks to avoid shadowing the `tasks` argument
                rollout_tasks = processor(evaluation_rows, config)
                return await asyncio.gather(*rollout_tasks)

            completed_rows = loop.run_until_complete(_run_all())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

        # 4) Split completions based on split_mode
        if split_mode == "multi_turn":
            # Split each assistant message into separate rows
            split_rows = multi_turn_assistant_to_ground_truth(completed_rows)
        elif split_mode == "last_turn":
            # Extract last assistant message as ground truth
            split_rows = assistant_to_ground_truth(completed_rows)
        elif split_mode == "longest":
            # Keep only longest conversation per rollout_id
            split_rows = filter_longest_conversation(completed_rows)
        elif split_mode is None:
            # No splitting
            split_rows = completed_rows
        else:
            raise ValueError(
                f"Invalid split_mode: {split_mode}. "
                "Must be 'multi_turn', 'last_turn', 'longest', or None"
            )

        print(f"[OpenEnvVLLM] Split {len(completed_rows)} rows → {len(split_rows)} rows (mode={split_mode})")
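        # Illustration of the assumed splitting semantics: with split_mode="multi_turn",
        # a rollout containing three assistant turns becomes three training rows (one per
        # turn), while "last_turn" and "longest" yield a single row per rollout.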

        # 5) Generate completions via vLLM for each split row
        # Build messages for vLLM chat endpoint
        all_messages: List[List[Dict]] = []
        for row in split_rows:
            messages = [{"role": msg.role, "content": msg.content} for msg in row.messages]
            all_messages.append(messages)

        # Call vLLM to generate completions.
        # Rows are built in chat format above, so this effectively checks that there is
        # at least one row; the generate branch below is kept as a plain-prompt fallback.
        is_conversational = all_messages and isinstance(all_messages[0], list)

        vllm_params = {
            "n": 1,  # One completion per split row
            "temperature": base_params["temperature"],
            "max_tokens": base_params["max_tokens"],
        }

        # Add any extra vLLM parameters from completion_params
        if completion_params:
            for key in ["top_p", "top_k", "min_p", "repetition_penalty"]:
                if key in completion_params:
                    vllm_params[key] = completion_params[key]

        if is_conversational:
            print(f"[OpenEnvVLLM] Calling vLLM chat endpoint with {len(all_messages)} conversations")
            vllm_response = vllm_client.chat(
                messages=all_messages,
                **vllm_params,
            )
        else:
            # Convert messages to prompts for the generate endpoint
            prompts_for_vllm = []
            for msgs in all_messages:
                # Simple concatenation (you may want to use a chat template here)
                prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in msgs)
                prompts_for_vllm.append(prompt_text)

            print(f"[OpenEnvVLLM] Calling vLLM generate endpoint with {len(prompts_for_vllm)} prompts")
            vllm_response = vllm_client.generate(
                prompts=prompts_for_vllm,
                **vllm_params,
            )

        # 6) Convert to TRL format
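        # The VLLMClient response is assumed to be a dict with "prompt_ids",
        # "completion_ids", and per-token "logprobs" (one entry per request);
        # adjust the keys below if your TRL version returns a different shape.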
        prompt_ids = vllm_response["prompt_ids"]
        completion_ids = vllm_response["completion_ids"]
        logprobs = vllm_response["logprobs"]

        # Extract step rewards from completed rows
        step_rewards: List[List[float]] = []
        for row in split_rows:
            rewards: List[float] = []

            # Look for rewards in system messages (sentinel pattern)
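            # A matching sentinel message parsed below looks like:
            #   "__ep_step_rewards__:[0.0, 0.5, 1.0]"  ->  rewards = [0.0, 0.5, 1.0]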
            for msg in row.messages:
                if msg.role == "system":
                    try:
                        content = msg.content or ""
                        if isinstance(content, str) and content.startswith("__ep_step_rewards__:"):
                            import json
                            payload = content.split(":", 1)[1]
                            rewards = json.loads(payload) or []
                            break
                    except Exception:
                        pass

            # Fallback to execution metadata
            if not rewards and hasattr(row.execution_metadata, "extra"):
                try:
                    rewards = row.execution_metadata.extra.get("step_rewards", []) or []
                except Exception:
                    pass

            step_rewards.append(rewards if rewards else [0.0])

        # Compute statistics
        total_reward = sum(sum(r) for r in step_rewards)
        avg_reward = total_reward / len(step_rewards) if step_rewards else 0.0
        print(f"[OpenEnvVLLM] Total reward: {total_reward:.2f}, Avg: {avg_reward:.2f}")

        # TRL expects prompt_ids at unique-prompt level (not per-generation)
        # Deduplicate while preserving order
        seen_prompts = set()
        prompt_ids_unique = []
        for p_ids in prompt_ids:
            p_tuple = tuple(p_ids)
            if p_tuple not in seen_prompts:
                seen_prompts.add(p_tuple)
                prompt_ids_unique.append(p_ids)

        return {
            "prompt_ids": prompt_ids_unique,
            "completion_ids": completion_ids,
            "logprobs": logprobs,
            "step_rewards": step_rewards,
        }

    return rollout_func
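

# --- Usage sketch (illustrative, not required by the module) -----------------
# TRL is assumed to forward extra keys returned by rollout_func (here,
# "step_rewards") to reward functions as keyword arguments; adapt this to the
# contract of your TRL version.
def env_step_reward(completions, step_rewards=None, **kwargs):
    """Sum the per-step environment rewards for each completion (assumed contract)."""
    if not step_rewards:
        return [0.0] * len(completions)
    return [float(sum(rewards)) for rewards in step_rewards]


# Wire it into GRPOTrainer alongside the rollout_func built above, e.g.:
#     trainer = GRPOTrainer(
#         model="Qwen/Qwen2.5-7B",
#         args=training_args,
#         train_dataset=dataset,
#         reward_funcs=[env_step_reward],
#         rollout_func=rollout_func,
#     )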