From 17e9c253bda6e24cf2bc6fe85002c91af40149be Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 11 Feb 2026 00:45:47 +0530 Subject: [PATCH 01/21] init --- tests/test_cli_agent_env_gateway.py | 254 ++++++++++++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 tests/test_cli_agent_env_gateway.py diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_cli_agent_env_gateway.py new file mode 100644 index 000000000..cbb05a928 --- /dev/null +++ b/tests/test_cli_agent_env_gateway.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import httpx +import pytest +from datasets import Dataset + +import verifiers as vf +import verifiers.envs.experimental.cli_agent_env as cli_agent_env + +pytestmark = [pytest.mark.integration, pytest.mark.environments] + + +class FakeTunnel: + instances: list["FakeTunnel"] = [] + + def __init__(self, local_port: int, log_level: str | None = None): + self.local_port = local_port + self.log_level = log_level + self.url: str | None = None + self.start_calls = 0 + self.stop_calls = 0 + FakeTunnel.instances.append(self) + + async def start(self) -> str: + self.start_calls += 1 + self.url = "https://unit-test.tunnel.prime.ai" + return self.url + + def sync_stop(self) -> None: + self.stop_calls += 1 + + +class GatewayCliAgentEnv(vf.CliAgentEnv): + async def post_rollout(self, state: vf.State): + state["reward"] = 1.0 + state["test_output"] = "ok" + + +def _build_gateway_transport(tracker: dict) -> httpx.MockTransport: + trajectory = [ + { + "prompt": [{"role": "user", "content": "Hello"}], + "completion": [{"role": "assistant", "content": "reply-1"}], + "tokens": { + "prompt_ids": [1, 2], + "prompt_mask": [0, 0], + "completion_ids": [3], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "overlong_prompt": False, + "is_truncated": False, + }, + "reward": None, + "advantage": None, + "is_truncated": False, + "trajectory_id": "traj-1", + "extras": {}, + }, + { + "prompt": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "reply-1"}, + {"role": "user", "content": "Turn 2"}, + ], + "completion": [{"role": "assistant", "content": "reply-2"}], + "tokens": { + "prompt_ids": [1, 2, 3, 4], + "prompt_mask": [0, 0, 0, 0], + "completion_ids": [5], + "completion_mask": [1], + "completion_logprobs": [-0.2], + "overlong_prompt": False, + "is_truncated": False, + }, + "reward": None, + "advantage": None, + "is_truncated": False, + "trajectory_id": "traj-1", + "extras": {}, + }, + { + "prompt": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "reply-1"}, + {"role": "user", "content": "Turn 2"}, + {"role": "assistant", "content": "reply-2"}, + {"role": "user", "content": "Turn 3"}, + ], + "completion": [{"role": "assistant", "content": "reply-3"}], + "tokens": { + "prompt_ids": [1, 2, 3, 4, 5, 6], + "prompt_mask": [0, 0, 0, 0, 0, 0], + "completion_ids": [7], + "completion_mask": [1], + "completion_logprobs": [-0.3], + "overlong_prompt": False, + "is_truncated": False, + }, + "reward": None, + "advantage": None, + "is_truncated": False, + "trajectory_id": "traj-1", + "extras": {}, + }, + ] + + def _handler(request: httpx.Request) -> httpx.Response: + tracker["hosts"].add(request.url.host) + tracker["paths"].append(request.url.path) + path = request.url.path + + if request.method == "POST" and path.endswith("/register"): + payload = json.loads(request.content.decode("utf-8")) + tracker["register_payload"] = payload + tracker["rollout_id"] = path.split("/")[-2] + return httpx.Response(status_code=200, json={"status": "active"}) + + if request.method == "POST" and path.endswith("/unregister"): + tracker["unregister_calls"] += 1 + return httpx.Response(status_code=200, json={"status": "active"}) + + if request.method == "GET" and path.endswith("/trajectory"): + tracker["trajectory_calls"] += 1 + return httpx.Response( + status_code=200, + json={ + "rollout_id": tracker["rollout_id"], + "status": "completed", + "num_turns": 3, + "model": "Qwen/Qwen3-0.6B", + "prompt": trajectory[0]["prompt"], + "completion": [ + {"role": "assistant", "content": "reply-1"}, + {"role": "user", "content": "Turn 2"}, + {"role": "assistant", "content": "reply-2"}, + {"role": "user", "content": "Turn 3"}, + {"role": "assistant", "content": "reply-3"}, + ], + "is_truncated": False, + "trajectory": trajectory, + }, + ) + + return httpx.Response(status_code=404, json={"error": f"Unhandled path {path}"}) + + return httpx.MockTransport(_handler) + + +@pytest.mark.asyncio +async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch): + FakeTunnel.instances.clear() + monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel) + + tracker = { + "paths": [], + "hosts": set(), + "register_payload": None, + "rollout_id": None, + "trajectory_calls": 0, + "unregister_calls": 0, + } + transport = _build_gateway_transport(tracker) + real_async_client = httpx.AsyncClient + + def _client_factory(*args, **kwargs): + kwargs["transport"] = transport + return real_async_client(*args, **kwargs) + + monkeypatch.setattr(cli_agent_env.httpx, "AsyncClient", _client_factory) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + max_turns=10, + timeout_seconds=30.0, + ) + + env.sandbox_client.create = AsyncMock(return_value=SimpleNamespace(id="sb-123")) + env.sandbox_client.wait_for_creation = AsyncMock(return_value=None) + env.sandbox_client.start_background_job = AsyncMock( + return_value=SimpleNamespace(id="job-1") + ) + env.sandbox_client.get_background_job = AsyncMock( + return_value=SimpleNamespace( + completed=True, + exit_code=0, + stdout="agent ok", + stderr="", + ) + ) + env.sandbox_client.delete = AsyncMock(return_value=None) + + rollout_input = { + "prompt": [{"role": "user", "content": "Hello"}], + "answer": "", + "example_id": 0, + "task": "gateway-test", + } + client = cast(Any, SimpleNamespace(base_url="http://gateway.internal:8000/v1/")) + state = await env.rollout( + input=rollout_input, + client=client, + model="Qwen/Qwen3-0.6B", + sampling_args={"temperature": 0.7, "max_completion_tokens": 64}, + ) + + assert state.get("error") is None + assert state["gateway_url"] == "http://gateway.internal:8000" + assert state["tunnel_url"] == "https://unit-test.tunnel.prime.ai" + assert state["rollout_base_url"].startswith( + "https://unit-test.tunnel.prime.ai/v1/rollouts/" + ) + assert len(state["trajectory"]) == 3 + assert state["prompt"] == [{"role": "user", "content": "Hello"}] + assert state["completion"][-1]["content"] == "reply-3" + assert state["reward"] == 1.0 + + create_request = env.sandbox_client.create.await_args.args[0] + assert ( + create_request.environment_vars["OPENAI_BASE_URL"] + == f"https://unit-test.tunnel.prime.ai/v1/rollouts/{state['rollout_id']}" + ) + assert create_request.environment_vars["OPENAI_MODEL"] == "Qwen/Qwen3-0.6B" + + assert tracker["register_payload"]["max_turns"] == 10 + assert tracker["register_payload"]["sampling_params"]["temperature"] == 0.7 + assert tracker["register_payload"]["sampling_params"]["max_completion_tokens"] == 64 + assert tracker["trajectory_calls"] == 1 + assert tracker["unregister_calls"] == 1 + assert tracker["hosts"] == {"gateway.internal"} + + assert len(FakeTunnel.instances) == 1 + tunnel = FakeTunnel.instances[0] + assert tunnel.local_port == 8000 + assert tunnel.start_calls == 1 + assert await env.get_tunnel_url() == "https://unit-test.tunnel.prime.ai" + assert tunnel.start_calls == 1 + + await env.teardown_resources() + assert tunnel.stop_calls == 1 From 1e42cbd0267ef70044fa8458e91b8fba2654db72 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 11 Feb 2026 01:02:22 +0530 Subject: [PATCH 02/21] fix tunnel address --- tests/test_cli_agent_env_gateway.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_cli_agent_env_gateway.py index cbb05a928..3616a5e0a 100644 --- a/tests/test_cli_agent_env_gateway.py +++ b/tests/test_cli_agent_env_gateway.py @@ -18,8 +18,14 @@ class FakeTunnel: instances: list["FakeTunnel"] = [] - def __init__(self, local_port: int, log_level: str | None = None): + def __init__( + self, + local_port: int, + local_addr: str = "127.0.0.1", + log_level: str | None = None, + ): self.local_port = local_port + self.local_addr = local_addr self.log_level = log_level self.url: str | None = None self.start_calls = 0 @@ -246,6 +252,7 @@ def _client_factory(*args, **kwargs): assert len(FakeTunnel.instances) == 1 tunnel = FakeTunnel.instances[0] assert tunnel.local_port == 8000 + assert tunnel.local_addr == "gateway.internal" assert tunnel.start_calls == 1 assert await env.get_tunnel_url() == "https://unit-test.tunnel.prime.ai" assert tunnel.start_calls == 1 From fcc2a839cca60bcf2c0a4e2c6ba5c4ca49578f7a Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 11 Feb 2026 22:07:30 +0530 Subject: [PATCH 03/21] bump health timeout to 120s --- verifiers/envs/experimental/cli_agent_env.py | 24 ++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 5cfb9e42e..834a3b130 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -172,6 +172,30 @@ async def setup_state(self, state: State) -> State: return state + if logger.isEnabledFor(logging.DEBUG): + rollout_id = state.get("rollout_id") + logger.debug( + "rollout=%s fetched trajectory steps=%d truncated=%s", + rollout_id, + len(trajectory), + state["is_truncated"], + ) + for turn_idx, step in enumerate(trajectory): + tokens = step.get("tokens") + prompt_token_count = ( + len(tokens["prompt_ids"]) if tokens is not None else 0 + ) + completion_token_count = ( + len(tokens["completion_ids"]) if tokens is not None else 0 + ) + logger.debug( + "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d", + rollout_id, + turn_idx, + prompt_token_count, + completion_token_count, + ) + async def get_docker_image(self, state: State) -> str: """Get the Docker image for the sandbox. Override for per-task images.""" return self.docker_image From 17fb3bdf78265d917b44d8fced186a5c6a10f8da Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Thu, 12 Feb 2026 01:17:46 +0530 Subject: [PATCH 04/21] fix(env): initialize state["completion"] in Environment.init_state to prevent rubric KeyError on early rollout failures --- verifiers/envs/environment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 40ef74cbe..96126b506 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -622,6 +622,7 @@ async def init_state( state["tool_defs"] = self._normalize_tool_defs(resolved_tool_defs) or [] state["trajectory"] = [] + state["completion"] = None self._get_usage_tracker(state, create_if_missing=True) state["trajectory_id"] = uuid.uuid4().hex state["reward"] = None From 898196c5eaefee906044d61c6aec4f5f701c2265 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Thu, 19 Feb 2026 23:28:36 +0530 Subject: [PATCH 05/21] multiple tunnels --- tests/test_cli_agent_env_gateway.py | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_cli_agent_env_gateway.py index 3616a5e0a..18a17cfc2 100644 --- a/tests/test_cli_agent_env_gateway.py +++ b/tests/test_cli_agent_env_gateway.py @@ -259,3 +259,43 @@ def _client_factory(*args, **kwargs): await env.teardown_resources() assert tunnel.stop_calls == 1 + + +@pytest.mark.asyncio +async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): + FakeTunnel.instances.clear() + monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + ) + + url_a = await env.get_tunnel_url(local_addr="10.20.0.58") + url_b = await env.get_tunnel_url(local_addr="10.20.0.59") + url_a_reuse = await env.get_tunnel_url(local_addr="10.20.0.58") + + assert url_a == "https://unit-test.tunnel.prime.ai" + assert url_b == "https://unit-test.tunnel.prime.ai" + assert url_a_reuse == url_a + + assert len(FakeTunnel.instances) == 2 + assert {t.local_addr for t in FakeTunnel.instances} == {"10.20.0.58", "10.20.0.59"} + assert sum(t.start_calls for t in FakeTunnel.instances) == 2 + + with pytest.raises( + ValueError, match="local_addr is required when multiple tunnels are active" + ): + await env.get_tunnel_url() + + await env.teardown_resources() + assert sum(t.stop_calls for t in FakeTunnel.instances) == 2 From 9e2e7354ca403b5c321e7f0e2a214d75fced828d Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Sun, 22 Feb 2026 02:37:24 +0530 Subject: [PATCH 06/21] rename: `RolloutGatewayEnv` --- verifiers/__init__.py | 3 + .../envs/experimental/rollout_gateway_env.py | 567 ++++++++++++++++++ 2 files changed, 570 insertions(+) create mode 100644 verifiers/envs/experimental/rollout_gateway_env.py diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 2cd3480d8..7c91a5600 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -67,6 +67,7 @@ "ReasoningGymEnv", "GymEnv", "CliAgentEnv", + "RolloutGatewayEnv", "HarborEnv", "MCPEnv", "BrowserEnv", @@ -121,6 +122,7 @@ "PythonEnv": "verifiers.envs.python_env:PythonEnv", "GymEnv": "verifiers.envs.experimental.gym_env:GymEnv", "CliAgentEnv": "verifiers.envs.experimental.cli_agent_env:CliAgentEnv", + "RolloutGatewayEnv": "verifiers.envs.experimental.rollout_gateway_env:RolloutGatewayEnv", "HarborEnv": "verifiers.envs.experimental.harbor_env:HarborEnv", "MCPEnv": "verifiers.envs.experimental.mcp_env:MCPEnv", "ReasoningGymEnv": "verifiers.envs.integrations.reasoninggym_env:ReasoningGymEnv", @@ -160,6 +162,7 @@ def __getattr__(name: str): from typing import Any from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401 + from .envs.experimental.rollout_gateway_env import RolloutGatewayEnv # noqa: F401 from .envs.experimental.gym_env import GymEnv # noqa: F401 from .envs.experimental.harbor_env import HarborEnv # noqa: F401 from .envs.experimental.mcp_env import MCPEnv # noqa: F401 diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py new file mode 100644 index 000000000..b0fc4a08e --- /dev/null +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -0,0 +1,567 @@ +import asyncio +import logging +import time +import uuid +from typing import Any, cast +from urllib.parse import urlparse + +import httpx +from openai import AsyncOpenAI +from prime_sandboxes import ( + AdvancedConfigs, + BackgroundJob, + BackgroundJobStatus, + CreateSandboxRequest, +) +from prime_tunnel import Tunnel + +import verifiers as vf +from verifiers.envs.experimental.sandbox_mixin import SandboxMixin +from verifiers.envs.multiturn_env import MultiTurnMonitorRubric +from verifiers.types import RolloutInput, SamplingArgs, State, TrajectoryStep + +logger = logging.getLogger(__name__) + + +class RolloutGatewayEnv(SandboxMixin, vf.Environment): + """ + Environment for running full agent code inside sandboxes. + + The sandboxed agent talks directly to the rollout gateway running in the vLLM + server through a prime tunnel URL. The environment only handles sandbox + lifecycle, rollout registration, trajectory fetch, and reward computation. + """ + + def __init__( + self, + run_command: str, + gateway_port: int = 8000, + max_turns: int = -1, + timeout_seconds: float = 3600.0, + poll_interval: float = 2.0, + docker_image: str = "python:3.11-slim", + start_command: str = "tail -f /dev/null", + cpu_cores: int = 1, + memory_gb: int = 2, + disk_size_gb: int = 5, + gpu_count: int = 0, + timeout_minutes: int = 60, + environment_vars: dict[str, str] | None = None, + team_id: str | None = None, + advanced_configs: AdvancedConfigs | None = None, + labels: list[str] | None = None, + max_retries: int = 5, + base_delay: float = 0.5, + backoff_factor: float = 2.0, + max_backoff_seconds: float = 30.0, + jitter: float = 1e-3, + sandbox_client_max_workers: int = 10, + sandbox_client_max_connections: int = 100, + sandbox_client_max_keepalive_connections: int = 50, + sandbox_wait_for_creation_max_attempts: int = 120, + **kwargs, + ): + super().__init__(message_type="chat", **kwargs) + self.add_rubric(MultiTurnMonitorRubric()) + + self.init_sandbox_client( + max_retries=max_retries, + base_delay=base_delay, + backoff_factor=backoff_factor, + max_backoff_seconds=max_backoff_seconds, + jitter=jitter, + sandbox_client_max_workers=sandbox_client_max_workers, + sandbox_client_max_connections=sandbox_client_max_connections, + sandbox_client_max_keepalive_connections=sandbox_client_max_keepalive_connections, + sandbox_wait_for_creation_max_attempts=sandbox_wait_for_creation_max_attempts, + ) + + self.run_command = run_command + self.gateway_port = gateway_port + self.max_turns = max_turns + self.poll_interval = poll_interval + self.timeout_seconds = timeout_seconds + self.docker_image = docker_image + self.start_command = start_command + self.cpu_cores = cpu_cores + self.memory_gb = memory_gb + self.disk_size_gb = disk_size_gb + self.gpu_count = gpu_count + self.timeout_minutes = timeout_minutes + self.environment_vars = environment_vars + self.team_id = team_id + self.advanced_configs = advanced_configs + self.labels = labels + + self._tunnels: dict[str, Tunnel] = {} + self._tunnel_lock = asyncio.Lock() + + def _resolve_tunnel_local_addr(self, state: State) -> str: + gateway_url = cast(str, state["gateway_url"]) + parsed = urlparse(gateway_url) + host = parsed.hostname + if host is None: + raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}") + return host + + async def get_tunnel_url(self, local_addr: str | None = None) -> str: + """Get tunnel URL, starting the tunnel if needed.""" + async with self._tunnel_lock: + if local_addr is None: + if len(self._tunnels) == 1: + tunnel = next(iter(self._tunnels.values())) + assert tunnel.url is not None, "Tunnel started but URL is None" + return tunnel.url + if len(self._tunnels) == 0: + raise ValueError("local_addr is required when starting tunnel") + raise ValueError( + "local_addr is required when multiple tunnels are active" + ) + + tunnel = self._tunnels.get(local_addr) + if tunnel is None: + if logger.isEnabledFor(logging.DEBUG): + tunnel = Tunnel( + local_port=self.gateway_port, + local_addr=local_addr, + log_level="debug", + ) + else: + tunnel = Tunnel( + local_port=self.gateway_port, + local_addr=local_addr, + ) + url = await tunnel.start() + self._tunnels[local_addr] = tunnel + logger.debug( + "Prime Tunnel started local_addr=%s url=%s", + local_addr, + url, + ) + return url + + assert tunnel.url is not None, "Tunnel started but URL is None" + return tunnel.url + + def _resolve_gateway_url(self, state: State) -> str: + client = cast(AsyncOpenAI, state["client"]) + gateway_url = str(client.base_url).rstrip("/") + if gateway_url.endswith("/v1"): + gateway_url = gateway_url[:-3] + return gateway_url + + @staticmethod + def _tail_text(value: Any, max_chars: int = 1200) -> str: + if value is None: + return "" + text = str(value) + if len(text) <= max_chars: + return text + return text[-max_chars:] + + def _rollout_endpoint(self, state: State, suffix: str) -> str: + gateway_url = cast(str, state["gateway_url"]) + rollout_id = cast(str, state["rollout_id"]) + return f"{gateway_url}/v1/rollouts/{rollout_id}/{suffix.lstrip('/')}" + + async def _gateway_post( + self, + state: State, + suffix: str, + payload: dict[str, Any] | None = None, + ) -> dict[str, Any]: + timeout = httpx.Timeout(self.timeout_seconds) + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.post( + self._rollout_endpoint(state, suffix), + json=payload, + ) + response.raise_for_status() + if not response.content: + return {} + return cast(dict[str, Any], response.json()) + + async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: + timeout = httpx.Timeout(self.timeout_seconds) + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.get(self._rollout_endpoint(state, suffix)) + response.raise_for_status() + return cast(dict[str, Any], response.json()) + + async def register_rollout(self, state: State) -> None: + sampling_params = dict(state.get("sampling_args") or {}) + payload = { + "model": state["model"], + "sampling_params": sampling_params, + "max_turns": self.max_turns, + "max_seq_len": self.max_seq_len, + } + await self._gateway_post(state, "register", payload) + + async def unregister_rollout(self, state: State) -> None: + await self._gateway_post(state, "unregister") + + async def fetch_trajectory(self, state: State) -> None: + data = await self._gateway_get(state, "trajectory") + raw_trajectory = cast(list[dict[str, Any]], data.get("trajectory", [])) + + trajectory: list[TrajectoryStep] = [] + for raw_step in raw_trajectory: + step = dict(raw_step) + step.setdefault("response", None) + step.setdefault("reward", None) + step.setdefault("advantage", None) + step.setdefault("is_truncated", False) + step.setdefault("trajectory_id", state.get("trajectory_id", "")) + step.setdefault("extras", {}) + trajectory.append(cast(TrajectoryStep, step)) + + state["trajectory"] = trajectory + state["prompt"] = data.get("prompt") + state["completion"] = data.get("completion") + state["is_truncated"] = bool( + data.get("is_truncated", state.get("is_truncated", False)) + ) + + if logger.isEnabledFor(logging.DEBUG): + rollout_id = state.get("rollout_id") + logger.debug( + "rollout=%s fetched trajectory steps=%d truncated=%s", + rollout_id, + len(trajectory), + state["is_truncated"], + ) + for turn_idx, step in enumerate(trajectory): + tokens = step.get("tokens") + prompt_token_count = ( + len(tokens["prompt_ids"]) if tokens is not None else 0 + ) + completion_token_count = ( + len(tokens["completion_ids"]) if tokens is not None else 0 + ) + logger.debug( + "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d", + rollout_id, + turn_idx, + prompt_token_count, + completion_token_count, + ) + + async def get_docker_image(self, state: State) -> str: + """Get the Docker image for the sandbox. Override for per-task images.""" + return self.docker_image + + async def build_env_vars(self, state: State) -> dict[str, str]: + """Build environment variables for the sandbox. Override to add custom vars.""" + env_vars = dict(self.environment_vars) if self.environment_vars else {} + env_vars["OPENAI_BASE_URL"] = cast(str, state["rollout_base_url"]) + env_vars.setdefault("OPENAI_TIMEOUT", "600") + env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600") + env_vars.setdefault("HTTPX_TIMEOUT", "600") + model = state.get("model") + if model: + env_vars["OPENAI_MODEL"] = model + return env_vars + + async def post_sandbox_setup(self, state: State) -> None: + """Hook for post-sandbox setup. Override to upload files, run commands, etc.""" + pass + + async def start_agent(self, state: State) -> None: + """Start the agent command using background job.""" + sandbox_id = cast(str, state["sandbox_id"]) + background_job: BackgroundJob = await self.sandbox_client.start_background_job( + sandbox_id, + self.run_command, + ) + state["background_job"] = background_job + state["agent_start_time"] = time.time() + state["agent_completed"] = False + + async def wait_for_agent_completion(self, state: State) -> None: + """Poll for agent completion using background job API.""" + sandbox_id = state.get("sandbox_id") + background_job = state.get("background_job") + if not sandbox_id or not background_job: + state["agent_completed"] = True + return + + try: + await asyncio.wait_for( + self.poll_job_completion( + state, + cast(str, sandbox_id), + cast(BackgroundJob, background_job), + ), + timeout=self.timeout_seconds, + ) + except asyncio.TimeoutError: + logger.warning( + "rollout=%s sandbox=%s stage=wait_for_agent_completion timed out after %.1fs", + state.get("rollout_id"), + state.get("sandbox_id"), + self.timeout_seconds, + ) + state["agent_timed_out"] = True + finally: + state["agent_completed"] = True + + async def poll_job_completion( + self, + state: State, + sandbox_id: str, + background_job: BackgroundJob, + ) -> None: + """Poll until background job completes, capturing output.""" + while True: + status: BackgroundJobStatus = await self.sandbox_client.get_background_job( + sandbox_id, + background_job, + ) + if status.completed: + state["agent_exit_code"] = status.exit_code + state["agent_stdout"] = status.stdout + state["agent_stderr"] = status.stderr + if status.exit_code not in (None, 0): + logger.warning( + "rollout=%s sandbox=%s stage=agent_completed exit_code=%s stdout_tail=%r stderr_tail=%r", + state.get("rollout_id"), + sandbox_id, + status.exit_code, + self._tail_text(status.stdout), + self._tail_text(status.stderr), + ) + else: + logger.debug( + "rollout=%s sandbox=%s stage=agent_completed exit_code=%s", + state.get("rollout_id"), + sandbox_id, + status.exit_code, + ) + return + await asyncio.sleep(1) + + def _render_timing(self, state: State) -> None: + start_time = cast(float, state["timing"]["start_time"]) + end_time = time.time() + generation_ms = (end_time - start_time) * 1000 + state["timing"]["generation_ms"] = generation_ms + state["timing"]["total_ms"] = generation_ms + + async def rollout( + self, + input: RolloutInput, + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + state = await self.init_state(input, client, model, sampling_args) + state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}" + state["gateway_url"] = self._resolve_gateway_url(state) + rollout_id = cast(str, state["rollout_id"]) + info = cast(dict[str, Any], state.get("info") or {}) + logger.info( + "rollout=%s stage=start model=%s example_id=%s repo=%s", + rollout_id, + model, + info.get("instance_id") or info.get("example_id"), + info.get("repo_name"), + ) + + rollout_registered = False + failure_stage = "register_rollout" + error_stage: str | None = None + try: + failure_stage = "register_rollout" + await self.register_rollout(state) + rollout_registered = True + logger.debug("rollout=%s stage=register_rollout ok", rollout_id) + + failure_stage = "resolve_tunnel_local_addr" + tunnel_local_addr = self._resolve_tunnel_local_addr(state) + state["tunnel_local_addr"] = tunnel_local_addr + logger.debug( + "rollout=%s stage=resolve_tunnel_local_addr addr=%s", + rollout_id, + tunnel_local_addr, + ) + + failure_stage = "start_tunnel" + tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr) + state["tunnel_url"] = tunnel_url + state["rollout_base_url"] = ( + f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}" + ) + logger.debug("rollout=%s stage=start_tunnel url=%s", rollout_id, tunnel_url) + + failure_stage = "build_env_vars" + env_vars = await self.build_env_vars(state) + failure_stage = "get_docker_image" + docker_image = await self.get_docker_image(state) + sandbox_request = CreateSandboxRequest( + name=cast(str, state["rollout_id"]), + docker_image=docker_image, + start_command=self.start_command, + cpu_cores=self.cpu_cores, + memory_gb=self.memory_gb, + disk_size_gb=self.disk_size_gb, + gpu_count=self.gpu_count, + timeout_minutes=self.timeout_minutes, + environment_vars=env_vars, + team_id=self.team_id, + advanced_configs=self.advanced_configs, + labels=self.labels if self.labels else [], + ) + logger.debug( + f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} " + f"docker_image={docker_image}" + ) + failure_stage = "create_sandbox" + await self.create_sandbox(state, sandbox_request) + logger.info( + "rollout=%s stage=create_sandbox ok sandbox_id=%s docker_image=%s", + rollout_id, + state.get("sandbox_id"), + docker_image, + ) + + failure_stage = "start_agent" + await self.start_agent(state) + logger.debug( + "rollout=%s stage=start_agent ok sandbox_id=%s", + rollout_id, + state.get("sandbox_id"), + ) + failure_stage = "wait_for_agent_completion" + await self.wait_for_agent_completion(state) + logger.debug( + "rollout=%s stage=wait_for_agent_completion ok exit_code=%s", + rollout_id, + state.get("agent_exit_code"), + ) + failure_stage = "fetch_trajectory" + await self.fetch_trajectory(state) + trajectory = cast(list[Any], state.get("trajectory") or []) + logger.info( + "rollout=%s stage=fetch_trajectory ok turns=%d truncated=%s", + rollout_id, + len(trajectory), + state.get("is_truncated", False), + ) + if len(trajectory) == 0: + logger.warning( + "rollout=%s stage=fetch_trajectory empty_trajectory agent_exit_code=%s stdout_tail=%r stderr_tail=%r", + rollout_id, + state.get("agent_exit_code"), + self._tail_text(state.get("agent_stdout")), + self._tail_text(state.get("agent_stderr")), + ) + except vf.Error as e: + error_stage = failure_stage + state["error"] = e + logger.exception( + "rollout=%s stage=%s vf_error=%s message=%s", + rollout_id, + failure_stage, + type(e).__name__, + e, + ) + except Exception as e: + error_stage = failure_stage + state["error"] = vf.InfraError(str(e)) + logger.exception( + "rollout=%s stage=%s unhandled_error=%s message=%s", + rollout_id, + failure_stage, + type(e).__name__, + e, + ) + finally: + if rollout_registered: + try: + failure_stage = "unregister_rollout" + await self.unregister_rollout(state) + except Exception as e: + logger.warning( + f"Failed to unregister rollout {state['rollout_id']}: {e}" + ) + if error_stage is None: + error_stage = failure_stage + if state.get("error") is None: + state["error"] = vf.InfraError(str(e)) + + if state.get("sandbox_id"): + try: + failure_stage = "destroy_sandbox" + await self.destroy_sandbox(state) + except Exception as e: + logger.warning( + f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}" + ) + if error_stage is None: + error_stage = failure_stage + if state.get("error") is None: + state["error"] = vf.InfraError(str(e)) + + if state.get("completion") is None: + state["completion"] = [] + state["failure_stage"] = error_stage + if state.get("error") is not None: + if state.get("stop_condition") is None: + state["stop_condition"] = ( + f"{error_stage}_error" if error_stage else "has_error" + ) + elif state.get("agent_timed_out", False): + if state.get("stop_condition") is None: + state["stop_condition"] = "agent_timeout" + else: + if state.get("stop_condition") is None: + state["stop_condition"] = "completed" + state["is_completed"] = True + self._render_timing(state) + logger.info( + "rollout=%s stage=finish stop=%s failure_stage=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s", + rollout_id, + state.get("stop_condition"), + error_stage, + state.get("sandbox_id"), + len(cast(list[Any], state.get("trajectory") or [])), + state.get("agent_exit_code"), + type(state["error"]).__name__ + if state.get("error") is not None + else None, + ) + + return state + + @vf.teardown + async def teardown_resources(self): + """Stop Prime Tunnel.""" + async with self._tunnel_lock: + tunnels = list(self._tunnels.items()) + self._tunnels = {} + for local_addr, tunnel in tunnels: + try: + tunnel.sync_stop() + logger.debug("Prime Tunnel stopped local_addr=%s", local_addr) + except Exception as e: + logger.warning( + "Error stopping Prime Tunnel local_addr=%s: %s", + local_addr, + e, + ) + + async def post_rollout(self, state: State): + """ + Override for custom post-rollout logic. For example, if sandbox state is needed for reward functions, + run computation here and cache the result in state before sandbox is destroyed. + """ + pass + + @vf.cleanup + async def destroy_sandbox(self, state: State): + """Cleanup sandbox after rollout.""" + await self.post_rollout(state) + sandbox_id = state.get("sandbox_id") + if sandbox_id: + await self.delete_sandbox(cast(str, sandbox_id)) From 6dc007835ba5cae1e01dcb6487bcc6d2738b45bf Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Sun, 22 Feb 2026 03:02:10 +0530 Subject: [PATCH 07/21] revert `CliAgentEnv` --- verifiers/envs/experimental/cli_agent_env.py | 24 -------------------- 1 file changed, 24 deletions(-) diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 834a3b130..5cfb9e42e 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -172,30 +172,6 @@ async def setup_state(self, state: State) -> State: return state - if logger.isEnabledFor(logging.DEBUG): - rollout_id = state.get("rollout_id") - logger.debug( - "rollout=%s fetched trajectory steps=%d truncated=%s", - rollout_id, - len(trajectory), - state["is_truncated"], - ) - for turn_idx, step in enumerate(trajectory): - tokens = step.get("tokens") - prompt_token_count = ( - len(tokens["prompt_ids"]) if tokens is not None else 0 - ) - completion_token_count = ( - len(tokens["completion_ids"]) if tokens is not None else 0 - ) - logger.debug( - "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d", - rollout_id, - turn_idx, - prompt_token_count, - completion_token_count, - ) - async def get_docker_image(self, state: State) -> str: """Get the Docker image for the sandbox. Override for per-task images.""" return self.docker_image From 5d446d0bd48ff3a6a5493c63a63c070551afe806 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Sun, 22 Feb 2026 03:49:34 +0530 Subject: [PATCH 08/21] rename test --- ...est_cli_agent_env_gateway.py => test_rollout_gateway_env.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/{test_cli_agent_env_gateway.py => test_rollout_gateway_env.py} (99%) diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_rollout_gateway_env.py similarity index 99% rename from tests/test_cli_agent_env_gateway.py rename to tests/test_rollout_gateway_env.py index 18a17cfc2..01545f82d 100644 --- a/tests/test_cli_agent_env_gateway.py +++ b/tests/test_rollout_gateway_env.py @@ -41,7 +41,7 @@ def sync_stop(self) -> None: self.stop_calls += 1 -class GatewayCliAgentEnv(vf.CliAgentEnv): +class GatewayCliAgentEnv(vf.RolloutGatewayEnv): async def post_rollout(self, state: vf.State): state["reward"] = 1.0 state["test_output"] = "ok" From afa451186069feb4e7534e1cad0048f1d81ce869 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Sun, 22 Feb 2026 04:08:55 +0530 Subject: [PATCH 09/21] fix tests --- tests/test_rollout_gateway_env.py | 16 ++++++++++------ .../envs/experimental/rollout_gateway_env.py | 3 ++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py index 01545f82d..60a7d72b7 100644 --- a/tests/test_rollout_gateway_env.py +++ b/tests/test_rollout_gateway_env.py @@ -2,7 +2,6 @@ import json from types import SimpleNamespace -from typing import Any, cast from unittest.mock import AsyncMock import httpx @@ -10,7 +9,7 @@ from datasets import Dataset import verifiers as vf -import verifiers.envs.experimental.cli_agent_env as cli_agent_env +import verifiers.envs.experimental.rollout_gateway_env as rollout_gateway_env pytestmark = [pytest.mark.integration, pytest.mark.environments] @@ -160,7 +159,7 @@ def _handler(request: httpx.Request) -> httpx.Response: @pytest.mark.asyncio async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch): FakeTunnel.instances.clear() - monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel) + monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel) tracker = { "paths": [], @@ -172,12 +171,18 @@ async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch): } transport = _build_gateway_transport(tracker) real_async_client = httpx.AsyncClient + client = vf.OpenAIChatCompletionsClient( + vf.ClientConfig( + api_key_var="UNIT_TEST_API_KEY", + api_base_url="http://gateway.internal:8000/v1/", + ) + ) def _client_factory(*args, **kwargs): kwargs["transport"] = transport return real_async_client(*args, **kwargs) - monkeypatch.setattr(cli_agent_env.httpx, "AsyncClient", _client_factory) + monkeypatch.setattr(rollout_gateway_env.httpx, "AsyncClient", _client_factory) dataset = Dataset.from_dict( { @@ -216,7 +221,6 @@ def _client_factory(*args, **kwargs): "example_id": 0, "task": "gateway-test", } - client = cast(Any, SimpleNamespace(base_url="http://gateway.internal:8000/v1/")) state = await env.rollout( input=rollout_input, client=client, @@ -264,7 +268,7 @@ def _client_factory(*args, **kwargs): @pytest.mark.asyncio async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): FakeTunnel.instances.clear() - monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel) + monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel) dataset = Dataset.from_dict( { diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py index b0fc4a08e..3d25ae94f 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -144,7 +144,8 @@ async def get_tunnel_url(self, local_addr: str | None = None) -> str: return tunnel.url def _resolve_gateway_url(self, state: State) -> str: - client = cast(AsyncOpenAI, state["client"]) + # `state["client"]` may be a Verifiers wrapper with the raw client on `.client`. + client = getattr(state["client"], "client", state["client"]) gateway_url = str(client.base_url).rstrip("/") if gateway_url.endswith("/v1"): gateway_url = gateway_url[:-3] From 715f6fcfb689601ca6a100b5f1be0df0b24a3278 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Sun, 22 Feb 2026 04:34:59 +0530 Subject: [PATCH 10/21] simplify --- .../envs/experimental/rollout_gateway_env.py | 97 ++++--------------- 1 file changed, 19 insertions(+), 78 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py index 3d25ae94f..c6a12970b 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -120,17 +120,11 @@ async def get_tunnel_url(self, local_addr: str | None = None) -> str: tunnel = self._tunnels.get(local_addr) if tunnel is None: - if logger.isEnabledFor(logging.DEBUG): - tunnel = Tunnel( - local_port=self.gateway_port, - local_addr=local_addr, - log_level="debug", - ) - else: - tunnel = Tunnel( - local_port=self.gateway_port, - local_addr=local_addr, - ) + tunnel = Tunnel( + local_port=self.gateway_port, + local_addr=local_addr, + log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info", + ) url = await tunnel.start() self._tunnels[local_addr] = tunnel logger.debug( @@ -161,9 +155,7 @@ def _tail_text(value: Any, max_chars: int = 1200) -> str: return text[-max_chars:] def _rollout_endpoint(self, state: State, suffix: str) -> str: - gateway_url = cast(str, state["gateway_url"]) - rollout_id = cast(str, state["rollout_id"]) - return f"{gateway_url}/v1/rollouts/{rollout_id}/{suffix.lstrip('/')}" + return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}" async def _gateway_post( self, @@ -190,7 +182,7 @@ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: return cast(dict[str, Any], response.json()) async def register_rollout(self, state: State) -> None: - sampling_params = dict(state.get("sampling_args") or {}) + sampling_params = state.get("sampling_args") or {} payload = { "model": state["model"], "sampling_params": sampling_params, @@ -204,16 +196,16 @@ async def unregister_rollout(self, state: State) -> None: async def fetch_trajectory(self, state: State) -> None: data = await self._gateway_get(state, "trajectory") - raw_trajectory = cast(list[dict[str, Any]], data.get("trajectory", [])) + raw_trajectory = data.get("trajectory", []) trajectory: list[TrajectoryStep] = [] - for raw_step in raw_trajectory: - step = dict(raw_step) + # TODO: Pydantic response schema in gateway + for step in raw_trajectory: step.setdefault("response", None) step.setdefault("reward", None) step.setdefault("advantage", None) step.setdefault("is_truncated", False) - step.setdefault("trajectory_id", state.get("trajectory_id", "")) + step.setdefault("trajectory_id", state["trajectory_id"]) step.setdefault("extras", {}) trajectory.append(cast(TrajectoryStep, step)) @@ -224,30 +216,6 @@ async def fetch_trajectory(self, state: State) -> None: data.get("is_truncated", state.get("is_truncated", False)) ) - if logger.isEnabledFor(logging.DEBUG): - rollout_id = state.get("rollout_id") - logger.debug( - "rollout=%s fetched trajectory steps=%d truncated=%s", - rollout_id, - len(trajectory), - state["is_truncated"], - ) - for turn_idx, step in enumerate(trajectory): - tokens = step.get("tokens") - prompt_token_count = ( - len(tokens["prompt_ids"]) if tokens is not None else 0 - ) - completion_token_count = ( - len(tokens["completion_ids"]) if tokens is not None else 0 - ) - logger.debug( - "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d", - rollout_id, - turn_idx, - prompt_token_count, - completion_token_count, - ) - async def get_docker_image(self, state: State) -> str: """Get the Docker image for the sandbox. Override for per-task images.""" return self.docker_image @@ -343,8 +311,8 @@ async def poll_job_completion( await asyncio.sleep(1) def _render_timing(self, state: State) -> None: - start_time = cast(float, state["timing"]["start_time"]) - end_time = time.time() + start_time = state["timing"]["start_time"] + end_time = time.perf_counter() generation_ms = (end_time - start_time) * 1000 state["timing"]["generation_ms"] = generation_ms state["timing"]["total_ms"] = generation_ms @@ -359,8 +327,8 @@ async def rollout( state = await self.init_state(input, client, model, sampling_args) state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}" state["gateway_url"] = self._resolve_gateway_url(state) - rollout_id = cast(str, state["rollout_id"]) - info = cast(dict[str, Any], state.get("info") or {}) + rollout_id = state["rollout_id"] + info = state.get("info") or {} logger.info( "rollout=%s stage=start model=%s example_id=%s repo=%s", rollout_id, @@ -370,15 +338,11 @@ async def rollout( ) rollout_registered = False - failure_stage = "register_rollout" - error_stage: str | None = None try: - failure_stage = "register_rollout" await self.register_rollout(state) rollout_registered = True logger.debug("rollout=%s stage=register_rollout ok", rollout_id) - failure_stage = "resolve_tunnel_local_addr" tunnel_local_addr = self._resolve_tunnel_local_addr(state) state["tunnel_local_addr"] = tunnel_local_addr logger.debug( @@ -387,7 +351,6 @@ async def rollout( tunnel_local_addr, ) - failure_stage = "start_tunnel" tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr) state["tunnel_url"] = tunnel_url state["rollout_base_url"] = ( @@ -395,9 +358,7 @@ async def rollout( ) logger.debug("rollout=%s stage=start_tunnel url=%s", rollout_id, tunnel_url) - failure_stage = "build_env_vars" env_vars = await self.build_env_vars(state) - failure_stage = "get_docker_image" docker_image = await self.get_docker_image(state) sandbox_request = CreateSandboxRequest( name=cast(str, state["rollout_id"]), @@ -417,7 +378,6 @@ async def rollout( f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} " f"docker_image={docker_image}" ) - failure_stage = "create_sandbox" await self.create_sandbox(state, sandbox_request) logger.info( "rollout=%s stage=create_sandbox ok sandbox_id=%s docker_image=%s", @@ -426,21 +386,18 @@ async def rollout( docker_image, ) - failure_stage = "start_agent" await self.start_agent(state) logger.debug( "rollout=%s stage=start_agent ok sandbox_id=%s", rollout_id, state.get("sandbox_id"), ) - failure_stage = "wait_for_agent_completion" await self.wait_for_agent_completion(state) logger.debug( "rollout=%s stage=wait_for_agent_completion ok exit_code=%s", rollout_id, state.get("agent_exit_code"), ) - failure_stage = "fetch_trajectory" await self.fetch_trajectory(state) trajectory = cast(list[Any], state.get("trajectory") or []) logger.info( @@ -458,60 +415,47 @@ async def rollout( self._tail_text(state.get("agent_stderr")), ) except vf.Error as e: - error_stage = failure_stage state["error"] = e logger.exception( "rollout=%s stage=%s vf_error=%s message=%s", rollout_id, - failure_stage, type(e).__name__, e, ) except Exception as e: - error_stage = failure_stage state["error"] = vf.InfraError(str(e)) logger.exception( "rollout=%s stage=%s unhandled_error=%s message=%s", rollout_id, - failure_stage, type(e).__name__, e, ) finally: if rollout_registered: try: - failure_stage = "unregister_rollout" await self.unregister_rollout(state) except Exception as e: logger.warning( f"Failed to unregister rollout {state['rollout_id']}: {e}" ) - if error_stage is None: - error_stage = failure_stage if state.get("error") is None: state["error"] = vf.InfraError(str(e)) if state.get("sandbox_id"): try: - failure_stage = "destroy_sandbox" await self.destroy_sandbox(state) except Exception as e: logger.warning( f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}" ) - if error_stage is None: - error_stage = failure_stage if state.get("error") is None: state["error"] = vf.InfraError(str(e)) if state.get("completion") is None: state["completion"] = [] - state["failure_stage"] = error_stage if state.get("error") is not None: if state.get("stop_condition") is None: - state["stop_condition"] = ( - f"{error_stage}_error" if error_stage else "has_error" - ) + state["stop_condition"] = "has_error" elif state.get("agent_timed_out", False): if state.get("stop_condition") is None: state["stop_condition"] = "agent_timeout" @@ -521,16 +465,13 @@ async def rollout( state["is_completed"] = True self._render_timing(state) logger.info( - "rollout=%s stage=finish stop=%s failure_stage=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s", + "rollout=%s stage=finish stop=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s", rollout_id, state.get("stop_condition"), - error_stage, state.get("sandbox_id"), - len(cast(list[Any], state.get("trajectory") or [])), + len(state.get("trajectory", [])), state.get("agent_exit_code"), - type(state["error"]).__name__ - if state.get("error") is not None - else None, + type(state["error"]).__name__ if state.get("error") else None, ) return state From b2d42b8e93a8798eb49834e6ed5fb2aecd5cbea0 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Mon, 23 Feb 2026 06:05:34 +0530 Subject: [PATCH 11/21] server side reponse models --- .../envs/experimental/rollout_gateway_env.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py index c6a12970b..754a307dc 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -196,20 +196,7 @@ async def unregister_rollout(self, state: State) -> None: async def fetch_trajectory(self, state: State) -> None: data = await self._gateway_get(state, "trajectory") - raw_trajectory = data.get("trajectory", []) - - trajectory: list[TrajectoryStep] = [] - # TODO: Pydantic response schema in gateway - for step in raw_trajectory: - step.setdefault("response", None) - step.setdefault("reward", None) - step.setdefault("advantage", None) - step.setdefault("is_truncated", False) - step.setdefault("trajectory_id", state["trajectory_id"]) - step.setdefault("extras", {}) - trajectory.append(cast(TrajectoryStep, step)) - - state["trajectory"] = trajectory + state["trajectory"] = cast(list[TrajectoryStep], data.get("trajectory", [])) state["prompt"] = data.get("prompt") state["completion"] = data.get("completion") state["is_truncated"] = bool( From 5d8f422555eb1a86a71e19177c3bd5fe3873eb40 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Mon, 23 Feb 2026 06:05:59 +0530 Subject: [PATCH 12/21] fix slop --- .../envs/experimental/rollout_gateway_env.py | 116 +++++------------- 1 file changed, 32 insertions(+), 84 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py index 754a307dc..4b4df2fe1 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -2,7 +2,7 @@ import logging import time import uuid -from typing import Any, cast +from typing import Any from urllib.parse import urlparse import httpx @@ -18,7 +18,7 @@ import verifiers as vf from verifiers.envs.experimental.sandbox_mixin import SandboxMixin from verifiers.envs.multiturn_env import MultiTurnMonitorRubric -from verifiers.types import RolloutInput, SamplingArgs, State, TrajectoryStep +from verifiers.types import RolloutInput, SamplingArgs, State logger = logging.getLogger(__name__) @@ -97,7 +97,7 @@ def __init__( self._tunnel_lock = asyncio.Lock() def _resolve_tunnel_local_addr(self, state: State) -> str: - gateway_url = cast(str, state["gateway_url"]) + gateway_url = state["gateway_url"] parsed = urlparse(gateway_url) host = parsed.hostname if host is None: @@ -127,11 +127,7 @@ async def get_tunnel_url(self, local_addr: str | None = None) -> str: ) url = await tunnel.start() self._tunnels[local_addr] = tunnel - logger.debug( - "Prime Tunnel started local_addr=%s url=%s", - local_addr, - url, - ) + logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}") return url assert tunnel.url is not None, "Tunnel started but URL is None" @@ -172,14 +168,14 @@ async def _gateway_post( response.raise_for_status() if not response.content: return {} - return cast(dict[str, Any], response.json()) + return response.json() async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: timeout = httpx.Timeout(self.timeout_seconds) async with httpx.AsyncClient(timeout=timeout) as client: response = await client.get(self._rollout_endpoint(state, suffix)) response.raise_for_status() - return cast(dict[str, Any], response.json()) + return response.json() async def register_rollout(self, state: State) -> None: sampling_params = state.get("sampling_args") or {} @@ -196,7 +192,7 @@ async def unregister_rollout(self, state: State) -> None: async def fetch_trajectory(self, state: State) -> None: data = await self._gateway_get(state, "trajectory") - state["trajectory"] = cast(list[TrajectoryStep], data.get("trajectory", [])) + state["trajectory"] = data.get("trajectory", []) state["prompt"] = data.get("prompt") state["completion"] = data.get("completion") state["is_truncated"] = bool( @@ -210,7 +206,7 @@ async def get_docker_image(self, state: State) -> str: async def build_env_vars(self, state: State) -> dict[str, str]: """Build environment variables for the sandbox. Override to add custom vars.""" env_vars = dict(self.environment_vars) if self.environment_vars else {} - env_vars["OPENAI_BASE_URL"] = cast(str, state["rollout_base_url"]) + env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"] env_vars.setdefault("OPENAI_TIMEOUT", "600") env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600") env_vars.setdefault("HTTPX_TIMEOUT", "600") @@ -225,8 +221,8 @@ async def post_sandbox_setup(self, state: State) -> None: async def start_agent(self, state: State) -> None: """Start the agent command using background job.""" - sandbox_id = cast(str, state["sandbox_id"]) - background_job: BackgroundJob = await self.sandbox_client.start_background_job( + sandbox_id = state["sandbox_id"] + background_job = await self.sandbox_client.start_background_job( sandbox_id, self.run_command, ) @@ -244,19 +240,12 @@ async def wait_for_agent_completion(self, state: State) -> None: try: await asyncio.wait_for( - self.poll_job_completion( - state, - cast(str, sandbox_id), - cast(BackgroundJob, background_job), - ), + self.poll_job_completion(state, sandbox_id, background_job), timeout=self.timeout_seconds, ) except asyncio.TimeoutError: logger.warning( - "rollout=%s sandbox=%s stage=wait_for_agent_completion timed out after %.1fs", - state.get("rollout_id"), - state.get("sandbox_id"), - self.timeout_seconds, + f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} stage=wait_for_agent_completion timed out after {self.timeout_seconds:.1f}s" ) state["agent_timed_out"] = True finally: @@ -280,19 +269,11 @@ async def poll_job_completion( state["agent_stderr"] = status.stderr if status.exit_code not in (None, 0): logger.warning( - "rollout=%s sandbox=%s stage=agent_completed exit_code=%s stdout_tail=%r stderr_tail=%r", - state.get("rollout_id"), - sandbox_id, - status.exit_code, - self._tail_text(status.stdout), - self._tail_text(status.stderr), + f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={self._tail_text(status.stdout)!r} stderr_tail={self._tail_text(status.stderr)!r}" ) else: logger.debug( - "rollout=%s sandbox=%s stage=agent_completed exit_code=%s", - state.get("rollout_id"), - sandbox_id, - status.exit_code, + f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code}" ) return await asyncio.sleep(1) @@ -317,25 +298,19 @@ async def rollout( rollout_id = state["rollout_id"] info = state.get("info") or {} logger.info( - "rollout=%s stage=start model=%s example_id=%s repo=%s", - rollout_id, - model, - info.get("instance_id") or info.get("example_id"), - info.get("repo_name"), + f"rollout={rollout_id} stage=start model={model} example_id={info.get('instance_id') or info.get('example_id')} repo={info.get('repo_name')}" ) rollout_registered = False try: await self.register_rollout(state) rollout_registered = True - logger.debug("rollout=%s stage=register_rollout ok", rollout_id) + logger.debug(f"rollout={rollout_id} stage=register_rollout ok") tunnel_local_addr = self._resolve_tunnel_local_addr(state) state["tunnel_local_addr"] = tunnel_local_addr logger.debug( - "rollout=%s stage=resolve_tunnel_local_addr addr=%s", - rollout_id, - tunnel_local_addr, + f"rollout={rollout_id} stage=resolve_tunnel_local_addr addr={tunnel_local_addr}" ) tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr) @@ -343,12 +318,12 @@ async def rollout( state["rollout_base_url"] = ( f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}" ) - logger.debug("rollout=%s stage=start_tunnel url=%s", rollout_id, tunnel_url) + logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}") env_vars = await self.build_env_vars(state) docker_image = await self.get_docker_image(state) sandbox_request = CreateSandboxRequest( - name=cast(str, state["rollout_id"]), + name=state["rollout_id"], docker_image=docker_image, start_command=self.start_command, cpu_cores=self.cpu_cores, @@ -367,55 +342,35 @@ async def rollout( ) await self.create_sandbox(state, sandbox_request) logger.info( - "rollout=%s stage=create_sandbox ok sandbox_id=%s docker_image=%s", - rollout_id, - state.get("sandbox_id"), - docker_image, + f"rollout={rollout_id} stage=create_sandbox ok sandbox_id={state.get('sandbox_id')} docker_image={docker_image}" ) await self.start_agent(state) logger.debug( - "rollout=%s stage=start_agent ok sandbox_id=%s", - rollout_id, - state.get("sandbox_id"), + f"rollout={rollout_id} stage=start_agent ok sandbox_id={state.get('sandbox_id')}" ) await self.wait_for_agent_completion(state) logger.debug( - "rollout=%s stage=wait_for_agent_completion ok exit_code=%s", - rollout_id, - state.get("agent_exit_code"), + f"rollout={rollout_id} stage=wait_for_agent_completion ok exit_code={state.get('agent_exit_code')}" ) await self.fetch_trajectory(state) - trajectory = cast(list[Any], state.get("trajectory") or []) + trajectory = state.get("trajectory") or [] logger.info( - "rollout=%s stage=fetch_trajectory ok turns=%d truncated=%s", - rollout_id, - len(trajectory), - state.get("is_truncated", False), + f"rollout={rollout_id} stage=fetch_trajectory ok turns={len(trajectory)} truncated={state.get('is_truncated', False)}" ) if len(trajectory) == 0: logger.warning( - "rollout=%s stage=fetch_trajectory empty_trajectory agent_exit_code=%s stdout_tail=%r stderr_tail=%r", - rollout_id, - state.get("agent_exit_code"), - self._tail_text(state.get("agent_stdout")), - self._tail_text(state.get("agent_stderr")), + f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={self._tail_text(state.get('agent_stdout'))!r} stderr_tail={self._tail_text(state.get('agent_stderr'))!r}" ) except vf.Error as e: state["error"] = e logger.exception( - "rollout=%s stage=%s vf_error=%s message=%s", - rollout_id, - type(e).__name__, - e, + f"rollout={rollout_id} stage={type(e).__name__} vf_error={e}" ) except Exception as e: state["error"] = vf.InfraError(str(e)) logger.exception( - "rollout=%s stage=%s unhandled_error=%s message=%s", - rollout_id, - type(e).__name__, - e, + f"rollout={rollout_id} stage={type(e).__name__} unhandled_error={e}" ) finally: if rollout_registered: @@ -451,14 +406,9 @@ async def rollout( state["stop_condition"] = "completed" state["is_completed"] = True self._render_timing(state) + error_name = type(state["error"]).__name__ if state.get("error") else None logger.info( - "rollout=%s stage=finish stop=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s", - rollout_id, - state.get("stop_condition"), - state.get("sandbox_id"), - len(state.get("trajectory", [])), - state.get("agent_exit_code"), - type(state["error"]).__name__ if state.get("error") else None, + f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} sandbox_id={state.get('sandbox_id')} turns={len(state.get('trajectory', []))} agent_exit_code={state.get('agent_exit_code')} error={error_name}" ) return state @@ -472,12 +422,10 @@ async def teardown_resources(self): for local_addr, tunnel in tunnels: try: tunnel.sync_stop() - logger.debug("Prime Tunnel stopped local_addr=%s", local_addr) + logger.debug(f"Prime Tunnel stopped local_addr={local_addr}") except Exception as e: logger.warning( - "Error stopping Prime Tunnel local_addr=%s: %s", - local_addr, - e, + f"Error stopping Prime Tunnel local_addr={local_addr}: {e}" ) async def post_rollout(self, state: State): @@ -493,4 +441,4 @@ async def destroy_sandbox(self, state: State): await self.post_rollout(state) sandbox_id = state.get("sandbox_id") if sandbox_id: - await self.delete_sandbox(cast(str, sandbox_id)) + await self.delete_sandbox(sandbox_id) From 899ce24283f72f7e81baf918f288844f90096763 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Tue, 24 Feb 2026 07:07:00 +0530 Subject: [PATCH 13/21] refactor + ty --- .../envs/experimental/rollout_gateway_env.py | 78 +++++++++---------- 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py index 4b4df2fe1..be8c2726b 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -6,7 +6,6 @@ from urllib.parse import urlparse import httpx -from openai import AsyncOpenAI from prime_sandboxes import ( AdvancedConfigs, BackgroundJob, @@ -16,13 +15,21 @@ from prime_tunnel import Tunnel import verifiers as vf +from verifiers.clients import Client from verifiers.envs.experimental.sandbox_mixin import SandboxMixin from verifiers.envs.multiturn_env import MultiTurnMonitorRubric -from verifiers.types import RolloutInput, SamplingArgs, State +from verifiers.types import ClientConfig, RolloutInput, SamplingArgs, State logger = logging.getLogger(__name__) +def _tail_text(value: Any, max_chars: int = 1200) -> str: + if value is None: + return "" + text = str(value) + return text[-max_chars:] if len(text) > max_chars else text + + class RolloutGatewayEnv(SandboxMixin, vf.Environment): """ Environment for running full agent code inside sandboxes. @@ -93,6 +100,9 @@ def __init__( self.advanced_configs = advanced_configs self.labels = labels + self._http_client = httpx.AsyncClient( + timeout=httpx.Timeout(self.timeout_seconds) + ) self._tunnels: dict[str, Tunnel] = {} self._tunnel_lock = asyncio.Lock() @@ -141,15 +151,6 @@ def _resolve_gateway_url(self, state: State) -> str: gateway_url = gateway_url[:-3] return gateway_url - @staticmethod - def _tail_text(value: Any, max_chars: int = 1200) -> str: - if value is None: - return "" - text = str(value) - if len(text) <= max_chars: - return text - return text[-max_chars:] - def _rollout_endpoint(self, state: State, suffix: str) -> str: return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}" @@ -159,23 +160,19 @@ async def _gateway_post( suffix: str, payload: dict[str, Any] | None = None, ) -> dict[str, Any]: - timeout = httpx.Timeout(self.timeout_seconds) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.post( - self._rollout_endpoint(state, suffix), - json=payload, - ) - response.raise_for_status() - if not response.content: - return {} - return response.json() + response = await self._http_client.post( + self._rollout_endpoint(state, suffix), + json=payload, + ) + response.raise_for_status() + if not response.content: + return {} + return response.json() async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: - timeout = httpx.Timeout(self.timeout_seconds) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.get(self._rollout_endpoint(state, suffix)) - response.raise_for_status() - return response.json() + response = await self._http_client.get(self._rollout_endpoint(state, suffix)) + response.raise_for_status() + return response.json() async def register_rollout(self, state: State) -> None: sampling_params = state.get("sampling_args") or {} @@ -269,16 +266,16 @@ async def poll_job_completion( state["agent_stderr"] = status.stderr if status.exit_code not in (None, 0): logger.warning( - f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={self._tail_text(status.stdout)!r} stderr_tail={self._tail_text(status.stderr)!r}" + f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={_tail_text(status.stdout)!r} stderr_tail={_tail_text(status.stderr)!r}" ) else: logger.debug( f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code}" ) return - await asyncio.sleep(1) + await asyncio.sleep(self.poll_interval) - def _render_timing(self, state: State) -> None: + async def _render_timing(self, state: State) -> None: start_time = state["timing"]["start_time"] end_time = time.perf_counter() generation_ms = (end_time - start_time) * 1000 @@ -288,7 +285,7 @@ def _render_timing(self, state: State) -> None: async def rollout( self, input: RolloutInput, - client: AsyncOpenAI, + client: Client | ClientConfig, model: str, sampling_args: SamplingArgs | None = None, ) -> State: @@ -334,7 +331,7 @@ async def rollout( environment_vars=env_vars, team_id=self.team_id, advanced_configs=self.advanced_configs, - labels=self.labels if self.labels else [], + labels=self.labels or [], ) logger.debug( f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} " @@ -360,7 +357,7 @@ async def rollout( ) if len(trajectory) == 0: logger.warning( - f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={self._tail_text(state.get('agent_stdout'))!r} stderr_tail={self._tail_text(state.get('agent_stderr'))!r}" + f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={_tail_text(state.get('agent_stdout'))!r} stderr_tail={_tail_text(state.get('agent_stderr'))!r}" ) except vf.Error as e: state["error"] = e @@ -385,7 +382,7 @@ async def rollout( if state.get("sandbox_id"): try: - await self.destroy_sandbox(state) + await self._cleanup(state) except Exception as e: logger.warning( f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}" @@ -395,17 +392,15 @@ async def rollout( if state.get("completion") is None: state["completion"] = [] - if state.get("error") is not None: - if state.get("stop_condition") is None: + if state.get("stop_condition") is None: + if state.get("error") is not None: state["stop_condition"] = "has_error" - elif state.get("agent_timed_out", False): - if state.get("stop_condition") is None: + elif state.get("agent_timed_out", False): state["stop_condition"] = "agent_timeout" - else: - if state.get("stop_condition") is None: + else: state["stop_condition"] = "completed" state["is_completed"] = True - self._render_timing(state) + await self._render_timing(state) error_name = type(state["error"]).__name__ if state.get("error") else None logger.info( f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} sandbox_id={state.get('sandbox_id')} turns={len(state.get('trajectory', []))} agent_exit_code={state.get('agent_exit_code')} error={error_name}" @@ -415,7 +410,8 @@ async def rollout( @vf.teardown async def teardown_resources(self): - """Stop Prime Tunnel.""" + """Stop Prime Tunnel and close HTTP client.""" + await self._http_client.aclose() async with self._tunnel_lock: tunnels = list(self._tunnels.items()) self._tunnels = {} From d1eb28ffc25b5e42d22ec31bbd743c5200cd4609 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Tue, 24 Feb 2026 08:29:17 +0530 Subject: [PATCH 14/21] fix --- verifiers/envs/experimental/rollout_gateway_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py index be8c2726b..53a264bea 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_env.py @@ -277,7 +277,7 @@ async def poll_job_completion( async def _render_timing(self, state: State) -> None: start_time = state["timing"]["start_time"] - end_time = time.perf_counter() + end_time = time.time() generation_ms = (end_time - start_time) * 1000 state["timing"]["generation_ms"] = generation_ms state["timing"]["total_ms"] = generation_ms From d2301af4420c41df6a31c5cc70db72dfd4411feb Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 01:42:05 +0530 Subject: [PATCH 15/21] refactor: `RolloutGatewayMixin` --- tests/test_rollout_gateway_env.py | 32 +- verifiers/__init__.py | 6 +- verifiers/envs/experimental/__init__.py | 3 +- verifiers/envs/experimental/cli_agent_env.py | 17 +- ...ateway_env.py => rollout_gateway_mixin.py} | 354 +++++++++--------- 5 files changed, 206 insertions(+), 206 deletions(-) rename verifiers/envs/experimental/{rollout_gateway_env.py => rollout_gateway_mixin.py} (64%) diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py index 60a7d72b7..d16da5d35 100644 --- a/tests/test_rollout_gateway_env.py +++ b/tests/test_rollout_gateway_env.py @@ -9,7 +9,7 @@ from datasets import Dataset import verifiers as vf -import verifiers.envs.experimental.rollout_gateway_env as rollout_gateway_env +import verifiers.envs.experimental.rollout_gateway_mixin as rollout_gateway_mixin pytestmark = [pytest.mark.integration, pytest.mark.environments] @@ -40,7 +40,15 @@ def sync_stop(self) -> None: self.stop_calls += 1 -class GatewayCliAgentEnv(vf.RolloutGatewayEnv): +class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv): + def __init__(self, *, gateway_port=8000, use_gateway=True, **kwargs): + super().__init__(**kwargs) + self.use_gateway = use_gateway + if use_gateway: + self.init_gateway( + gateway_port=gateway_port, timeout_seconds=self.timeout_seconds + ) + async def post_rollout(self, state: vf.State): state["reward"] = 1.0 state["test_output"] = "ok" @@ -159,7 +167,7 @@ def _handler(request: httpx.Request) -> httpx.Response: @pytest.mark.asyncio async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch): FakeTunnel.instances.clear() - monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel) + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) tracker = { "paths": [], @@ -182,7 +190,7 @@ def _client_factory(*args, **kwargs): kwargs["transport"] = transport return real_async_client(*args, **kwargs) - monkeypatch.setattr(rollout_gateway_env.httpx, "AsyncClient", _client_factory) + monkeypatch.setattr(rollout_gateway_mixin.httpx, "AsyncClient", _client_factory) dataset = Dataset.from_dict( { @@ -258,17 +266,17 @@ def _client_factory(*args, **kwargs): assert tunnel.local_port == 8000 assert tunnel.local_addr == "gateway.internal" assert tunnel.start_calls == 1 - assert await env.get_tunnel_url() == "https://unit-test.tunnel.prime.ai" + assert await env.get_gateway_tunnel_url() == "https://unit-test.tunnel.prime.ai" assert tunnel.start_calls == 1 - await env.teardown_resources() + await env.teardown_gateway() assert tunnel.stop_calls == 1 @pytest.mark.asyncio async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): FakeTunnel.instances.clear() - monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel) + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) dataset = Dataset.from_dict( { @@ -284,9 +292,9 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): gateway_port=8000, ) - url_a = await env.get_tunnel_url(local_addr="10.20.0.58") - url_b = await env.get_tunnel_url(local_addr="10.20.0.59") - url_a_reuse = await env.get_tunnel_url(local_addr="10.20.0.58") + url_a = await env.get_gateway_tunnel_url(local_addr="10.20.0.58") + url_b = await env.get_gateway_tunnel_url(local_addr="10.20.0.59") + url_a_reuse = await env.get_gateway_tunnel_url(local_addr="10.20.0.58") assert url_a == "https://unit-test.tunnel.prime.ai" assert url_b == "https://unit-test.tunnel.prime.ai" @@ -299,7 +307,7 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): with pytest.raises( ValueError, match="local_addr is required when multiple tunnels are active" ): - await env.get_tunnel_url() + await env.get_gateway_tunnel_url() - await env.teardown_resources() + await env.teardown_gateway() assert sum(t.stop_calls for t in FakeTunnel.instances) == 2 diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 7c91a5600..e7fa269ea 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -67,7 +67,7 @@ "ReasoningGymEnv", "GymEnv", "CliAgentEnv", - "RolloutGatewayEnv", + "RolloutGatewayMixin", "HarborEnv", "MCPEnv", "BrowserEnv", @@ -122,7 +122,7 @@ "PythonEnv": "verifiers.envs.python_env:PythonEnv", "GymEnv": "verifiers.envs.experimental.gym_env:GymEnv", "CliAgentEnv": "verifiers.envs.experimental.cli_agent_env:CliAgentEnv", - "RolloutGatewayEnv": "verifiers.envs.experimental.rollout_gateway_env:RolloutGatewayEnv", + "RolloutGatewayMixin": "verifiers.envs.experimental.rollout_gateway_mixin:RolloutGatewayMixin", "HarborEnv": "verifiers.envs.experimental.harbor_env:HarborEnv", "MCPEnv": "verifiers.envs.experimental.mcp_env:MCPEnv", "ReasoningGymEnv": "verifiers.envs.integrations.reasoninggym_env:ReasoningGymEnv", @@ -162,7 +162,7 @@ def __getattr__(name: str): from typing import Any from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401 - from .envs.experimental.rollout_gateway_env import RolloutGatewayEnv # noqa: F401 + from .envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin # noqa: F401 from .envs.experimental.gym_env import GymEnv # noqa: F401 from .envs.experimental.harbor_env import HarborEnv # noqa: F401 from .envs.experimental.mcp_env import MCPEnv # noqa: F401 diff --git a/verifiers/envs/experimental/__init__.py b/verifiers/envs/experimental/__init__.py index 034cb817f..ff81549e0 100644 --- a/verifiers/envs/experimental/__init__.py +++ b/verifiers/envs/experimental/__init__.py @@ -1,3 +1,4 @@ +from verifiers.envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin from verifiers.envs.experimental.sandbox_mixin import SandboxMixin -__all__ = ["SandboxMixin"] +__all__ = ["RolloutGatewayMixin", "SandboxMixin"] diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 5cfb9e42e..946308f5b 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -84,8 +84,6 @@ def __init__( ) self.run_command = run_command self.poll_interval = poll_interval - self.interception_port = interception_port - self.interception_url = interception_url self.timeout_seconds = timeout_seconds self.docker_image = docker_image self.start_command = start_command @@ -99,7 +97,16 @@ def __init__( self.advanced_configs = advanced_configs self.labels = labels - # Tunnel and interception server + self.init_interception(interception_port, interception_url) + + def init_interception( + self, + interception_port: int = 8765, + interception_url: str | None = None, + ): + """Initialize interception server and tunnel resources. Call from __init__.""" + self.interception_port = interception_port + self.interception_url = interception_url self._tunnel: Tunnel | None = None self._tunnel_lock = asyncio.Lock() self._interception_server = InterceptionServer(port=interception_port) @@ -411,6 +418,8 @@ async def add_model_response( @vf.teardown async def teardown_resources(self): """Stop Prime Tunnel and HTTP interception server.""" + if not hasattr(self, "_tunnel_lock"): + return async with self._tunnel_lock: if self._tunnel is not None: try: @@ -437,7 +446,7 @@ async def cleanup_interception_context(self, state: State): state.pop("background_job", None) rollout_id = state.get("rollout_id") - if rollout_id: + if rollout_id and hasattr(self, "_interception_server"): self._interception_server.unregister_rollout(rollout_id) @vf.stop diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_mixin.py similarity index 64% rename from verifiers/envs/experimental/rollout_gateway_env.py rename to verifiers/envs/experimental/rollout_gateway_mixin.py index 53a264bea..880e452b3 100644 --- a/verifiers/envs/experimental/rollout_gateway_env.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -6,18 +6,11 @@ from urllib.parse import urlparse import httpx -from prime_sandboxes import ( - AdvancedConfigs, - BackgroundJob, - BackgroundJobStatus, - CreateSandboxRequest, -) +from prime_sandboxes import CreateSandboxRequest from prime_tunnel import Tunnel import verifiers as vf from verifiers.clients import Client -from verifiers.envs.experimental.sandbox_mixin import SandboxMixin -from verifiers.envs.multiturn_env import MultiTurnMonitorRubric from verifiers.types import ClientConfig, RolloutInput, SamplingArgs, State logger = logging.getLogger(__name__) @@ -30,137 +23,64 @@ def _tail_text(value: Any, max_chars: int = 1200) -> str: return text[-max_chars:] if len(text) > max_chars else text -class RolloutGatewayEnv(SandboxMixin, vf.Environment): - """ - Environment for running full agent code inside sandboxes. +class RolloutGatewayMixin: + """Opt-in mixin that replaces MultiTurnEnv's interception-based rollout + with a server-side gateway path. Toggle via ``use_gateway`` attribute. + + When gateway is active, the agent talks directly to vLLM's rollout + gateway through a prime tunnel. The env only manages sandbox lifecycle. + When inactive, falls through to CliAgentEnv's interception path. - The sandboxed agent talks directly to the rollout gateway running in the vLLM - server through a prime tunnel URL. The environment only handles sandbox - lifecycle, rollout registration, trajectory fetch, and reward computation. + MRO: ``MyEnv → RolloutGatewayMixin → CliAgentEnv → SandboxMixin → MultiTurnEnv → Environment`` """ - def __init__( + use_gateway: bool = True + + def init_gateway( self, - run_command: str, gateway_port: int = 8000, - max_turns: int = -1, timeout_seconds: float = 3600.0, - poll_interval: float = 2.0, - docker_image: str = "python:3.11-slim", - start_command: str = "tail -f /dev/null", - cpu_cores: int = 1, - memory_gb: int = 2, - disk_size_gb: int = 5, - gpu_count: int = 0, - timeout_minutes: int = 60, - environment_vars: dict[str, str] | None = None, - team_id: str | None = None, - advanced_configs: AdvancedConfigs | None = None, - labels: list[str] | None = None, - max_retries: int = 5, - base_delay: float = 0.5, - backoff_factor: float = 2.0, - max_backoff_seconds: float = 30.0, - jitter: float = 1e-3, - sandbox_client_max_workers: int = 10, - sandbox_client_max_connections: int = 100, - sandbox_client_max_keepalive_connections: int = 50, - sandbox_wait_for_creation_max_attempts: int = 120, - **kwargs, ): - super().__init__(message_type="chat", **kwargs) - self.add_rubric(MultiTurnMonitorRubric()) - - self.init_sandbox_client( - max_retries=max_retries, - base_delay=base_delay, - backoff_factor=backoff_factor, - max_backoff_seconds=max_backoff_seconds, - jitter=jitter, - sandbox_client_max_workers=sandbox_client_max_workers, - sandbox_client_max_connections=sandbox_client_max_connections, - sandbox_client_max_keepalive_connections=sandbox_client_max_keepalive_connections, - sandbox_wait_for_creation_max_attempts=sandbox_wait_for_creation_max_attempts, - ) - - self.run_command = run_command + """Initialize gateway resources. Call in __init__ when use_gateway=True.""" self.gateway_port = gateway_port - self.max_turns = max_turns - self.poll_interval = poll_interval - self.timeout_seconds = timeout_seconds - self.docker_image = docker_image - self.start_command = start_command - self.cpu_cores = cpu_cores - self.memory_gb = memory_gb - self.disk_size_gb = disk_size_gb - self.gpu_count = gpu_count - self.timeout_minutes = timeout_minutes - self.environment_vars = environment_vars - self.team_id = team_id - self.advanced_configs = advanced_configs - self.labels = labels - - self._http_client = httpx.AsyncClient( - timeout=httpx.Timeout(self.timeout_seconds) - ) - self._tunnels: dict[str, Tunnel] = {} - self._tunnel_lock = asyncio.Lock() + self._gw_timeout_seconds = timeout_seconds + self._gw_http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds)) + self._gw_tunnels: dict[str, Tunnel] = {} + self._gw_tunnel_lock = asyncio.Lock() - def _resolve_tunnel_local_addr(self, state: State) -> str: - gateway_url = state["gateway_url"] - parsed = urlparse(gateway_url) - host = parsed.hostname - if host is None: - raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}") - return host - - async def get_tunnel_url(self, local_addr: str | None = None) -> str: - """Get tunnel URL, starting the tunnel if needed.""" - async with self._tunnel_lock: - if local_addr is None: - if len(self._tunnels) == 1: - tunnel = next(iter(self._tunnels.values())) - assert tunnel.url is not None, "Tunnel started but URL is None" - return tunnel.url - if len(self._tunnels) == 0: - raise ValueError("local_addr is required when starting tunnel") - raise ValueError( - "local_addr is required when multiple tunnels are active" - ) - - tunnel = self._tunnels.get(local_addr) - if tunnel is None: - tunnel = Tunnel( - local_port=self.gateway_port, - local_addr=local_addr, - log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info", - ) - url = await tunnel.start() - self._tunnels[local_addr] = tunnel - logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}") - return url - - assert tunnel.url is not None, "Tunnel started but URL is None" - return tunnel.url + # ------------------------------------------------------------------ + # Gateway URL resolution + # ------------------------------------------------------------------ def _resolve_gateway_url(self, state: State) -> str: - # `state["client"]` may be a Verifiers wrapper with the raw client on `.client`. client = getattr(state["client"], "client", state["client"]) gateway_url = str(client.base_url).rstrip("/") if gateway_url.endswith("/v1"): gateway_url = gateway_url[:-3] return gateway_url + def _resolve_tunnel_local_addr(self, state: State) -> str: + gateway_url = state["gateway_url"] + parsed = urlparse(gateway_url) + host = parsed.hostname + if host is None: + raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}") + return host + def _rollout_endpoint(self, state: State, suffix: str) -> str: return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}" + # ------------------------------------------------------------------ + # Gateway HTTP helpers + # ------------------------------------------------------------------ + async def _gateway_post( self, state: State, suffix: str, payload: dict[str, Any] | None = None, ) -> dict[str, Any]: - response = await self._http_client.post( + response = await self._gw_http_client.post( self._rollout_endpoint(state, suffix), json=payload, ) @@ -170,10 +90,32 @@ async def _gateway_post( return response.json() async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: - response = await self._http_client.get(self._rollout_endpoint(state, suffix)) + response = await self._gw_http_client.get(self._rollout_endpoint(state, suffix)) response.raise_for_status() return response.json() + # ------------------------------------------------------------------ + # Env var override for gateway path + # ------------------------------------------------------------------ + + async def build_env_vars(self, state: State) -> dict[str, str]: + """Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode.""" + if not self.use_gateway: + return await super().build_env_vars(state) + env_vars = dict(self.environment_vars) if self.environment_vars else {} + env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"] + env_vars.setdefault("OPENAI_TIMEOUT", "600") + env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600") + env_vars.setdefault("HTTPX_TIMEOUT", "600") + model = state.get("model") + if model: + env_vars["OPENAI_MODEL"] = model + return env_vars + + # ------------------------------------------------------------------ + # Rollout registration & trajectory + # ------------------------------------------------------------------ + async def register_rollout(self, state: State) -> None: sampling_params = state.get("sampling_args") or {} payload = { @@ -196,28 +138,47 @@ async def fetch_trajectory(self, state: State) -> None: data.get("is_truncated", state.get("is_truncated", False)) ) - async def get_docker_image(self, state: State) -> str: - """Get the Docker image for the sandbox. Override for per-task images.""" - return self.docker_image + # ------------------------------------------------------------------ + # Gateway tunnel management + # ------------------------------------------------------------------ - async def build_env_vars(self, state: State) -> dict[str, str]: - """Build environment variables for the sandbox. Override to add custom vars.""" - env_vars = dict(self.environment_vars) if self.environment_vars else {} - env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"] - env_vars.setdefault("OPENAI_TIMEOUT", "600") - env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600") - env_vars.setdefault("HTTPX_TIMEOUT", "600") - model = state.get("model") - if model: - env_vars["OPENAI_MODEL"] = model - return env_vars + async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: + """Get gateway tunnel URL, starting the tunnel if needed.""" + async with self._gw_tunnel_lock: + if local_addr is None: + if len(self._gw_tunnels) == 1: + tunnel = next(iter(self._gw_tunnels.values())) + assert tunnel.url is not None, "Tunnel started but URL is None" + return tunnel.url + if len(self._gw_tunnels) == 0: + raise ValueError("local_addr is required when starting tunnel") + raise ValueError( + "local_addr is required when multiple tunnels are active" + ) - async def post_sandbox_setup(self, state: State) -> None: - """Hook for post-sandbox setup. Override to upload files, run commands, etc.""" - pass + tunnel = self._gw_tunnels.get(local_addr) + if tunnel is None: + tunnel = Tunnel( + local_port=self.gateway_port, + local_addr=local_addr, + log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info", + ) + url = await tunnel.start() + self._gw_tunnels[local_addr] = tunnel + logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}") + return url + + assert tunnel.url is not None, "Tunnel started but URL is None" + return tunnel.url + + # ------------------------------------------------------------------ + # Agent start & completion polling + # ------------------------------------------------------------------ async def start_agent(self, state: State) -> None: - """Start the agent command using background job.""" + """Start the agent command. In gateway mode, skip background completion task.""" + if not self.use_gateway: + return await super().start_agent(state) sandbox_id = state["sandbox_id"] background_job = await self.sandbox_client.start_background_job( sandbox_id, @@ -227,36 +188,17 @@ async def start_agent(self, state: State) -> None: state["agent_start_time"] = time.time() state["agent_completed"] = False - async def wait_for_agent_completion(self, state: State) -> None: - """Poll for agent completion using background job API.""" - sandbox_id = state.get("sandbox_id") - background_job = state.get("background_job") - if not sandbox_id or not background_job: - state["agent_completed"] = True - return - - try: - await asyncio.wait_for( - self.poll_job_completion(state, sandbox_id, background_job), - timeout=self.timeout_seconds, - ) - except asyncio.TimeoutError: - logger.warning( - f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} stage=wait_for_agent_completion timed out after {self.timeout_seconds:.1f}s" - ) - state["agent_timed_out"] = True - finally: - state["agent_completed"] = True - async def poll_job_completion( self, state: State, sandbox_id: str, - background_job: BackgroundJob, + background_job, ) -> None: """Poll until background job completes, capturing output.""" + if not self.use_gateway: + return await super().poll_job_completion(state, sandbox_id, background_job) while True: - status: BackgroundJobStatus = await self.sandbox_client.get_background_job( + status = await self.sandbox_client.get_background_job( sandbox_id, background_job, ) @@ -266,15 +208,45 @@ async def poll_job_completion( state["agent_stderr"] = status.stderr if status.exit_code not in (None, 0): logger.warning( - f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={_tail_text(status.stdout)!r} stderr_tail={_tail_text(status.stderr)!r}" + f"rollout={state.get('rollout_id')} sandbox={sandbox_id} " + f"stage=agent_completed exit_code={status.exit_code} " + f"stdout_tail={_tail_text(status.stdout)!r} " + f"stderr_tail={_tail_text(status.stderr)!r}" ) else: logger.debug( - f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code}" + f"rollout={state.get('rollout_id')} sandbox={sandbox_id} " + f"stage=agent_completed exit_code={status.exit_code}" ) return await asyncio.sleep(self.poll_interval) + async def wait_for_agent_completion(self, state: State) -> None: + """Poll for agent completion using background job API.""" + sandbox_id = state.get("sandbox_id") + background_job = state.get("background_job") + if not sandbox_id or not background_job: + state["agent_completed"] = True + return + + try: + await asyncio.wait_for( + self.poll_job_completion(state, sandbox_id, background_job), + timeout=self._gw_timeout_seconds, + ) + except asyncio.TimeoutError: + logger.warning( + f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} " + f"stage=wait_for_agent_completion timed out after {self._gw_timeout_seconds:.1f}s" + ) + state["agent_timed_out"] = True + finally: + state["agent_completed"] = True + + # ------------------------------------------------------------------ + # Timing + # ------------------------------------------------------------------ + async def _render_timing(self, state: State) -> None: start_time = state["timing"]["start_time"] end_time = time.time() @@ -282,6 +254,10 @@ async def _render_timing(self, state: State) -> None: state["timing"]["generation_ms"] = generation_ms state["timing"]["total_ms"] = generation_ms + # ------------------------------------------------------------------ + # rollout() — gateway path with fallback + # ------------------------------------------------------------------ + async def rollout( self, input: RolloutInput, @@ -289,13 +265,18 @@ async def rollout( model: str, sampling_args: SamplingArgs | None = None, ) -> State: + if not self.use_gateway: + return await super().rollout(input, client, model, sampling_args) + state = await self.init_state(input, client, model, sampling_args) state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}" state["gateway_url"] = self._resolve_gateway_url(state) rollout_id = state["rollout_id"] info = state.get("info") or {} logger.info( - f"rollout={rollout_id} stage=start model={model} example_id={info.get('instance_id') or info.get('example_id')} repo={info.get('repo_name')}" + f"rollout={rollout_id} stage=start model={model} " + f"example_id={info.get('instance_id') or info.get('example_id')} " + f"repo={info.get('repo_name')}" ) rollout_registered = False @@ -310,7 +291,7 @@ async def rollout( f"rollout={rollout_id} stage=resolve_tunnel_local_addr addr={tunnel_local_addr}" ) - tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr) + tunnel_url = await self.get_gateway_tunnel_url(local_addr=tunnel_local_addr) state["tunnel_url"] = tunnel_url state["rollout_base_url"] = ( f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}" @@ -339,25 +320,32 @@ async def rollout( ) await self.create_sandbox(state, sandbox_request) logger.info( - f"rollout={rollout_id} stage=create_sandbox ok sandbox_id={state.get('sandbox_id')} docker_image={docker_image}" + f"rollout={rollout_id} stage=create_sandbox ok " + f"sandbox_id={state.get('sandbox_id')} docker_image={docker_image}" ) await self.start_agent(state) logger.debug( - f"rollout={rollout_id} stage=start_agent ok sandbox_id={state.get('sandbox_id')}" + f"rollout={rollout_id} stage=start_agent ok " + f"sandbox_id={state.get('sandbox_id')}" ) await self.wait_for_agent_completion(state) logger.debug( - f"rollout={rollout_id} stage=wait_for_agent_completion ok exit_code={state.get('agent_exit_code')}" + f"rollout={rollout_id} stage=wait_for_agent_completion ok " + f"exit_code={state.get('agent_exit_code')}" ) await self.fetch_trajectory(state) trajectory = state.get("trajectory") or [] logger.info( - f"rollout={rollout_id} stage=fetch_trajectory ok turns={len(trajectory)} truncated={state.get('is_truncated', False)}" + f"rollout={rollout_id} stage=fetch_trajectory ok " + f"turns={len(trajectory)} truncated={state.get('is_truncated', False)}" ) if len(trajectory) == 0: logger.warning( - f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={_tail_text(state.get('agent_stdout'))!r} stderr_tail={_tail_text(state.get('agent_stderr'))!r}" + f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory " + f"agent_exit_code={state.get('agent_exit_code')} " + f"stdout_tail={_tail_text(state.get('agent_stdout'))!r} " + f"stderr_tail={_tail_text(state.get('agent_stderr'))!r}" ) except vf.Error as e: state["error"] = e @@ -403,18 +391,27 @@ async def rollout( await self._render_timing(state) error_name = type(state["error"]).__name__ if state.get("error") else None logger.info( - f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} sandbox_id={state.get('sandbox_id')} turns={len(state.get('trajectory', []))} agent_exit_code={state.get('agent_exit_code')} error={error_name}" + f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} " + f"sandbox_id={state.get('sandbox_id')} " + f"turns={len(state.get('trajectory', []))} " + f"agent_exit_code={state.get('agent_exit_code')} error={error_name}" ) return state + # ------------------------------------------------------------------ + # Teardown + # ------------------------------------------------------------------ + @vf.teardown - async def teardown_resources(self): - """Stop Prime Tunnel and close HTTP client.""" - await self._http_client.aclose() - async with self._tunnel_lock: - tunnels = list(self._tunnels.items()) - self._tunnels = {} + async def teardown_gateway(self): + """Close gateway HTTP client and stop gateway tunnels.""" + if not hasattr(self, "_gw_http_client"): + return + await self._gw_http_client.aclose() + async with self._gw_tunnel_lock: + tunnels = list(self._gw_tunnels.items()) + self._gw_tunnels = {} for local_addr, tunnel in tunnels: try: tunnel.sync_stop() @@ -423,18 +420,3 @@ async def teardown_resources(self): logger.warning( f"Error stopping Prime Tunnel local_addr={local_addr}: {e}" ) - - async def post_rollout(self, state: State): - """ - Override for custom post-rollout logic. For example, if sandbox state is needed for reward functions, - run computation here and cache the result in state before sandbox is destroyed. - """ - pass - - @vf.cleanup - async def destroy_sandbox(self, state: State): - """Cleanup sandbox after rollout.""" - await self.post_rollout(state) - sandbox_id = state.get("sandbox_id") - if sandbox_id: - await self.delete_sandbox(sandbox_id) From 7a36c41c0322fa972ebf856916d41589b9c6795d Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 02:24:00 +0530 Subject: [PATCH 16/21] cancel properly --- verifiers/envs/experimental/rollout_gateway_mixin.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py index 880e452b3..1d5a78630 100644 --- a/verifiers/envs/experimental/rollout_gateway_mixin.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -347,6 +347,13 @@ async def rollout( f"stdout_tail={_tail_text(state.get('agent_stdout'))!r} " f"stderr_tail={_tail_text(state.get('agent_stderr'))!r}" ) + except asyncio.CancelledError: + if rollout_registered: + try: + await self._gateway_post(state, "cancel") + except Exception: + pass + raise except vf.Error as e: state["error"] = e logger.exception( From d664d31cebb2576306ecb648abfc11422be0ae71 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:08:45 +0530 Subject: [PATCH 17/21] refactor --- tests/test_rollout_gateway_env.py | 39 +++++++++++++++- verifiers/envs/experimental/cli_agent_env.py | 11 +++-- .../experimental/rollout_gateway_mixin.py | 44 +++++++++++-------- 3 files changed, 70 insertions(+), 24 deletions(-) diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py index d16da5d35..87bfe7b07 100644 --- a/tests/test_rollout_gateway_env.py +++ b/tests/test_rollout_gateway_env.py @@ -42,8 +42,8 @@ def sync_stop(self) -> None: class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv): def __init__(self, *, gateway_port=8000, use_gateway=True, **kwargs): - super().__init__(**kwargs) self.use_gateway = use_gateway + super().__init__(**kwargs) if use_gateway: self.init_gateway( gateway_port=gateway_port, timeout_seconds=self.timeout_seconds @@ -311,3 +311,40 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): await env.teardown_gateway() assert sum(t.stop_calls for t in FakeTunnel.instances) == 2 + + +@pytest.mark.asyncio +async def test_use_gateway_false_initializes_interception(monkeypatch): + """With use_gateway=False, interception server is created and gateway is not.""" + FakeTunnel.instances.clear() + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + use_gateway=False, + timeout_seconds=30.0, + ) + + # Interception server should be initialized + assert env._interception_server is not None + assert env._tunnel is None + assert env._tunnel_lock is not None + + # Gateway attributes should not exist (init_gateway was never called) + assert not hasattr(env, "_http_client") + assert not hasattr(env, "_tunnels") + + # Teardowns should be safe no-ops for the inactive path + await env.teardown_gateway() # early return via use_gateway=False + await env.teardown_resources() # stops interception (which was never started) + + assert len(FakeTunnel.instances) == 0 diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 946308f5b..f4f91093b 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -97,6 +97,10 @@ def __init__( self.advanced_configs = advanced_configs self.labels = labels + self._interception_server: InterceptionServer | None = None + self._tunnel: Tunnel | None = None + self._tunnel_lock = asyncio.Lock() + self.init_interception(interception_port, interception_url) def init_interception( @@ -418,8 +422,6 @@ async def add_model_response( @vf.teardown async def teardown_resources(self): """Stop Prime Tunnel and HTTP interception server.""" - if not hasattr(self, "_tunnel_lock"): - return async with self._tunnel_lock: if self._tunnel is not None: try: @@ -429,7 +431,8 @@ async def teardown_resources(self): logger.warning(f"Error stopping Prime Tunnel: {e}") finally: self._tunnel = None - await self._interception_server.stop() + if self._interception_server is not None: + await self._interception_server.stop() @vf.cleanup async def cleanup_interception_context(self, state: State): @@ -446,7 +449,7 @@ async def cleanup_interception_context(self, state: State): state.pop("background_job", None) rollout_id = state.get("rollout_id") - if rollout_id and hasattr(self, "_interception_server"): + if rollout_id and self._interception_server is not None: self._interception_server.unregister_rollout(rollout_id) @vf.stop diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py index 1d5a78630..0e0ca49b1 100644 --- a/verifiers/envs/experimental/rollout_gateway_mixin.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -36,6 +36,10 @@ class RolloutGatewayMixin: use_gateway: bool = True + def init_interception(self, *args, **kwargs): + if not self.use_gateway: + super().init_interception(*args, **kwargs) + def init_gateway( self, gateway_port: int = 8000, @@ -43,10 +47,12 @@ def init_gateway( ): """Initialize gateway resources. Call in __init__ when use_gateway=True.""" self.gateway_port = gateway_port - self._gw_timeout_seconds = timeout_seconds - self._gw_http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds)) - self._gw_tunnels: dict[str, Tunnel] = {} - self._gw_tunnel_lock = asyncio.Lock() + self._gateway_timeout_seconds = timeout_seconds + self._http_client: httpx.AsyncClient | None = httpx.AsyncClient( + timeout=httpx.Timeout(timeout_seconds) + ) + self._tunnels: dict[str, Tunnel] = {} + self._tunnel_lock = asyncio.Lock() # ------------------------------------------------------------------ # Gateway URL resolution @@ -80,7 +86,7 @@ async def _gateway_post( suffix: str, payload: dict[str, Any] | None = None, ) -> dict[str, Any]: - response = await self._gw_http_client.post( + response = await self._http_client.post( self._rollout_endpoint(state, suffix), json=payload, ) @@ -90,7 +96,7 @@ async def _gateway_post( return response.json() async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: - response = await self._gw_http_client.get(self._rollout_endpoint(state, suffix)) + response = await self._http_client.get(self._rollout_endpoint(state, suffix)) response.raise_for_status() return response.json() @@ -144,19 +150,19 @@ async def fetch_trajectory(self, state: State) -> None: async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: """Get gateway tunnel URL, starting the tunnel if needed.""" - async with self._gw_tunnel_lock: + async with self._tunnel_lock: if local_addr is None: - if len(self._gw_tunnels) == 1: - tunnel = next(iter(self._gw_tunnels.values())) + if len(self._tunnels) == 1: + tunnel = next(iter(self._tunnels.values())) assert tunnel.url is not None, "Tunnel started but URL is None" return tunnel.url - if len(self._gw_tunnels) == 0: + if len(self._tunnels) == 0: raise ValueError("local_addr is required when starting tunnel") raise ValueError( "local_addr is required when multiple tunnels are active" ) - tunnel = self._gw_tunnels.get(local_addr) + tunnel = self._tunnels.get(local_addr) if tunnel is None: tunnel = Tunnel( local_port=self.gateway_port, @@ -164,7 +170,7 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info", ) url = await tunnel.start() - self._gw_tunnels[local_addr] = tunnel + self._tunnels[local_addr] = tunnel logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}") return url @@ -232,12 +238,12 @@ async def wait_for_agent_completion(self, state: State) -> None: try: await asyncio.wait_for( self.poll_job_completion(state, sandbox_id, background_job), - timeout=self._gw_timeout_seconds, + timeout=self._gateway_timeout_seconds, ) except asyncio.TimeoutError: logger.warning( f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} " - f"stage=wait_for_agent_completion timed out after {self._gw_timeout_seconds:.1f}s" + f"stage=wait_for_agent_completion timed out after {self._gateway_timeout_seconds:.1f}s" ) state["agent_timed_out"] = True finally: @@ -413,12 +419,12 @@ async def rollout( @vf.teardown async def teardown_gateway(self): """Close gateway HTTP client and stop gateway tunnels.""" - if not hasattr(self, "_gw_http_client"): + if not self.use_gateway: return - await self._gw_http_client.aclose() - async with self._gw_tunnel_lock: - tunnels = list(self._gw_tunnels.items()) - self._gw_tunnels = {} + await self._http_client.aclose() + async with self._tunnel_lock: + tunnels = list(self._tunnels.items()) + self._tunnels = {} for local_addr, tunnel in tunnels: try: tunnel.sync_stop() From d0e58f799528b0e5d541f6fe2162b0b10f604df1 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:13:40 +0530 Subject: [PATCH 18/21] clean up comments --- .../experimental/rollout_gateway_mixin.py | 40 +------------------ 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py index 0e0ca49b1..67f96672c 100644 --- a/verifiers/envs/experimental/rollout_gateway_mixin.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -24,10 +24,10 @@ def _tail_text(value: Any, max_chars: int = 1200) -> str: class RolloutGatewayMixin: - """Opt-in mixin that replaces MultiTurnEnv's interception-based rollout + """Opt-in mixin that replaces CliAgentEnv's interception-based rollout with a server-side gateway path. Toggle via ``use_gateway`` attribute. - When gateway is active, the agent talks directly to vLLM's rollout + When gateway is active, the agent talks directly to prime-rl's rollout gateway through a prime tunnel. The env only manages sandbox lifecycle. When inactive, falls through to CliAgentEnv's interception path. @@ -54,10 +54,6 @@ def init_gateway( self._tunnels: dict[str, Tunnel] = {} self._tunnel_lock = asyncio.Lock() - # ------------------------------------------------------------------ - # Gateway URL resolution - # ------------------------------------------------------------------ - def _resolve_gateway_url(self, state: State) -> str: client = getattr(state["client"], "client", state["client"]) gateway_url = str(client.base_url).rstrip("/") @@ -76,10 +72,6 @@ def _resolve_tunnel_local_addr(self, state: State) -> str: def _rollout_endpoint(self, state: State, suffix: str) -> str: return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}" - # ------------------------------------------------------------------ - # Gateway HTTP helpers - # ------------------------------------------------------------------ - async def _gateway_post( self, state: State, @@ -100,10 +92,6 @@ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: response.raise_for_status() return response.json() - # ------------------------------------------------------------------ - # Env var override for gateway path - # ------------------------------------------------------------------ - async def build_env_vars(self, state: State) -> dict[str, str]: """Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode.""" if not self.use_gateway: @@ -118,10 +106,6 @@ async def build_env_vars(self, state: State) -> dict[str, str]: env_vars["OPENAI_MODEL"] = model return env_vars - # ------------------------------------------------------------------ - # Rollout registration & trajectory - # ------------------------------------------------------------------ - async def register_rollout(self, state: State) -> None: sampling_params = state.get("sampling_args") or {} payload = { @@ -144,10 +128,6 @@ async def fetch_trajectory(self, state: State) -> None: data.get("is_truncated", state.get("is_truncated", False)) ) - # ------------------------------------------------------------------ - # Gateway tunnel management - # ------------------------------------------------------------------ - async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: """Get gateway tunnel URL, starting the tunnel if needed.""" async with self._tunnel_lock: @@ -177,10 +157,6 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: assert tunnel.url is not None, "Tunnel started but URL is None" return tunnel.url - # ------------------------------------------------------------------ - # Agent start & completion polling - # ------------------------------------------------------------------ - async def start_agent(self, state: State) -> None: """Start the agent command. In gateway mode, skip background completion task.""" if not self.use_gateway: @@ -249,10 +225,6 @@ async def wait_for_agent_completion(self, state: State) -> None: finally: state["agent_completed"] = True - # ------------------------------------------------------------------ - # Timing - # ------------------------------------------------------------------ - async def _render_timing(self, state: State) -> None: start_time = state["timing"]["start_time"] end_time = time.time() @@ -260,10 +232,6 @@ async def _render_timing(self, state: State) -> None: state["timing"]["generation_ms"] = generation_ms state["timing"]["total_ms"] = generation_ms - # ------------------------------------------------------------------ - # rollout() — gateway path with fallback - # ------------------------------------------------------------------ - async def rollout( self, input: RolloutInput, @@ -412,10 +380,6 @@ async def rollout( return state - # ------------------------------------------------------------------ - # Teardown - # ------------------------------------------------------------------ - @vf.teardown async def teardown_gateway(self): """Close gateway HTTP client and stop gateway tunnels.""" From d74e447f6a96c6d357274d0d9d052052e8654743 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:23:09 +0530 Subject: [PATCH 19/21] fix --- verifiers/envs/experimental/rollout_gateway_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py index 67f96672c..a30a1f6ff 100644 --- a/verifiers/envs/experimental/rollout_gateway_mixin.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -11,7 +11,7 @@ import verifiers as vf from verifiers.clients import Client -from verifiers.types import ClientConfig, RolloutInput, SamplingArgs, State +from verifiers.types import RolloutInput, SamplingArgs, State logger = logging.getLogger(__name__) @@ -235,7 +235,7 @@ async def _render_timing(self, state: State) -> None: async def rollout( self, input: RolloutInput, - client: Client | ClientConfig, + client: Client, model: str, sampling_args: SamplingArgs | None = None, ) -> State: From 4518df7efa52c94b20437fdcf2ecbd171e1662bd Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:35:47 +0530 Subject: [PATCH 20/21] fix `ty` --- .../experimental/rollout_gateway_mixin.py | 54 +++++++++---------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py index a30a1f6ff..833b40b30 100644 --- a/verifiers/envs/experimental/rollout_gateway_mixin.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -38,7 +38,7 @@ class RolloutGatewayMixin: def init_interception(self, *args, **kwargs): if not self.use_gateway: - super().init_interception(*args, **kwargs) + super().init_interception(*args, **kwargs) # ty: ignore[unresolved-attribute] def init_gateway( self, @@ -48,9 +48,7 @@ def init_gateway( """Initialize gateway resources. Call in __init__ when use_gateway=True.""" self.gateway_port = gateway_port self._gateway_timeout_seconds = timeout_seconds - self._http_client: httpx.AsyncClient | None = httpx.AsyncClient( - timeout=httpx.Timeout(timeout_seconds) - ) + self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds)) self._tunnels: dict[str, Tunnel] = {} self._tunnel_lock = asyncio.Lock() @@ -95,8 +93,8 @@ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: async def build_env_vars(self, state: State) -> dict[str, str]: """Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode.""" if not self.use_gateway: - return await super().build_env_vars(state) - env_vars = dict(self.environment_vars) if self.environment_vars else {} + return await super().build_env_vars(state) # ty: ignore[unresolved-attribute] + env_vars = dict(self.environment_vars) if self.environment_vars else {} # ty: ignore[unresolved-attribute] env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"] env_vars.setdefault("OPENAI_TIMEOUT", "600") env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600") @@ -111,8 +109,8 @@ async def register_rollout(self, state: State) -> None: payload = { "model": state["model"], "sampling_params": sampling_params, - "max_turns": self.max_turns, - "max_seq_len": self.max_seq_len, + "max_turns": self.max_turns, # ty: ignore[unresolved-attribute] + "max_seq_len": self.max_seq_len, # ty: ignore[unresolved-attribute] } await self._gateway_post(state, "register", payload) @@ -160,11 +158,11 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: async def start_agent(self, state: State) -> None: """Start the agent command. In gateway mode, skip background completion task.""" if not self.use_gateway: - return await super().start_agent(state) + return await super().start_agent(state) # ty: ignore[unresolved-attribute] sandbox_id = state["sandbox_id"] - background_job = await self.sandbox_client.start_background_job( + background_job = await self.sandbox_client.start_background_job( # ty: ignore[unresolved-attribute] sandbox_id, - self.run_command, + self.run_command, # ty: ignore[unresolved-attribute] ) state["background_job"] = background_job state["agent_start_time"] = time.time() @@ -178,9 +176,9 @@ async def poll_job_completion( ) -> None: """Poll until background job completes, capturing output.""" if not self.use_gateway: - return await super().poll_job_completion(state, sandbox_id, background_job) + return await super().poll_job_completion(state, sandbox_id, background_job) # ty: ignore[unresolved-attribute] while True: - status = await self.sandbox_client.get_background_job( + status = await self.sandbox_client.get_background_job( # ty: ignore[unresolved-attribute] sandbox_id, background_job, ) @@ -201,7 +199,7 @@ async def poll_job_completion( f"stage=agent_completed exit_code={status.exit_code}" ) return - await asyncio.sleep(self.poll_interval) + await asyncio.sleep(self.poll_interval) # ty: ignore[unresolved-attribute] async def wait_for_agent_completion(self, state: State) -> None: """Poll for agent completion using background job API.""" @@ -240,9 +238,9 @@ async def rollout( sampling_args: SamplingArgs | None = None, ) -> State: if not self.use_gateway: - return await super().rollout(input, client, model, sampling_args) + return await super().rollout(input, client, model, sampling_args) # ty: ignore[unresolved-attribute] - state = await self.init_state(input, client, model, sampling_args) + state = await self.init_state(input, client, model, sampling_args) # ty: ignore[unresolved-attribute] state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}" state["gateway_url"] = self._resolve_gateway_url(state) rollout_id = state["rollout_id"] @@ -273,26 +271,26 @@ async def rollout( logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}") env_vars = await self.build_env_vars(state) - docker_image = await self.get_docker_image(state) + docker_image = await self.get_docker_image(state) # ty: ignore[unresolved-attribute] sandbox_request = CreateSandboxRequest( name=state["rollout_id"], docker_image=docker_image, - start_command=self.start_command, - cpu_cores=self.cpu_cores, - memory_gb=self.memory_gb, - disk_size_gb=self.disk_size_gb, - gpu_count=self.gpu_count, - timeout_minutes=self.timeout_minutes, + start_command=self.start_command, # ty: ignore[unresolved-attribute] + cpu_cores=self.cpu_cores, # ty: ignore[unresolved-attribute] + memory_gb=self.memory_gb, # ty: ignore[unresolved-attribute] + disk_size_gb=self.disk_size_gb, # ty: ignore[unresolved-attribute] + gpu_count=self.gpu_count, # ty: ignore[unresolved-attribute] + timeout_minutes=self.timeout_minutes, # ty: ignore[unresolved-attribute] environment_vars=env_vars, - team_id=self.team_id, - advanced_configs=self.advanced_configs, - labels=self.labels or [], + team_id=self.team_id, # ty: ignore[unresolved-attribute] + advanced_configs=self.advanced_configs, # ty: ignore[unresolved-attribute] + labels=self.labels or [], # ty: ignore[unresolved-attribute] ) logger.debug( f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} " f"docker_image={docker_image}" ) - await self.create_sandbox(state, sandbox_request) + await self.create_sandbox(state, sandbox_request) # ty: ignore[unresolved-attribute] logger.info( f"rollout={rollout_id} stage=create_sandbox ok " f"sandbox_id={state.get('sandbox_id')} docker_image={docker_image}" @@ -351,7 +349,7 @@ async def rollout( if state.get("sandbox_id"): try: - await self._cleanup(state) + await self._cleanup(state) # ty: ignore[unresolved-attribute] except Exception as e: logger.warning( f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}" From 21d45c0460214c951549bca978795bd6cd26cc61 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:37:36 +0530 Subject: [PATCH 21/21] docs --- assets/lab/environments/AGENTS.md | 3 ++- docs/environments.md | 1 + environments/AGENTS.md | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/assets/lab/environments/AGENTS.md b/assets/lab/environments/AGENTS.md index cc7f76d02..98e63b1d7 100644 --- a/assets/lab/environments/AGENTS.md +++ b/assets/lab/environments/AGENTS.md @@ -802,5 +802,6 @@ Newer and more experimental environment classes include: - **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API) - **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin` +- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` fires and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` - **`HarborEnv`** — loads Harbor-format agent benchmark tasks -- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`. +- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls diff --git a/docs/environments.md b/docs/environments.md index 84f884482..4f8e43121 100644 --- a/docs/environments.md +++ b/docs/environments.md @@ -796,5 +796,6 @@ Newer and more experimental environment classes include: - **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API) - **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin` +- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` fires and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` - **`HarborEnv`** — loads Harbor-format agent benchmark tasks - **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls diff --git a/environments/AGENTS.md b/environments/AGENTS.md index da5323579..632fb0ee0 100644 --- a/environments/AGENTS.md +++ b/environments/AGENTS.md @@ -802,5 +802,6 @@ Newer and more experimental environment classes include: - **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API) - **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin` +- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` fires and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` - **`HarborEnv`** — loads Harbor-format agent benchmark tasks -- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`. +- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls