From 17e9c253bda6e24cf2bc6fe85002c91af40149be Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 11 Feb 2026 00:45:47 +0530
Subject: [PATCH 01/21] init
---
tests/test_cli_agent_env_gateway.py | 254 ++++++++++++++++++++++++++++
1 file changed, 254 insertions(+)
create mode 100644 tests/test_cli_agent_env_gateway.py
diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_cli_agent_env_gateway.py
new file mode 100644
index 000000000..cbb05a928
--- /dev/null
+++ b/tests/test_cli_agent_env_gateway.py
@@ -0,0 +1,254 @@
+from __future__ import annotations
+
+import json
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import AsyncMock
+
+import httpx
+import pytest
+from datasets import Dataset
+
+import verifiers as vf
+import verifiers.envs.experimental.cli_agent_env as cli_agent_env
+
+pytestmark = [pytest.mark.integration, pytest.mark.environments]
+
+
+class FakeTunnel:
+ instances: list["FakeTunnel"] = []
+
+ def __init__(self, local_port: int, log_level: str | None = None):
+ self.local_port = local_port
+ self.log_level = log_level
+ self.url: str | None = None
+ self.start_calls = 0
+ self.stop_calls = 0
+ FakeTunnel.instances.append(self)
+
+ async def start(self) -> str:
+ self.start_calls += 1
+ self.url = "https://unit-test.tunnel.prime.ai"
+ return self.url
+
+ def sync_stop(self) -> None:
+ self.stop_calls += 1
+
+
+class GatewayCliAgentEnv(vf.CliAgentEnv):
+ async def post_rollout(self, state: vf.State):
+ state["reward"] = 1.0
+ state["test_output"] = "ok"
+
+
+def _build_gateway_transport(tracker: dict) -> httpx.MockTransport:
+ trajectory = [
+ {
+ "prompt": [{"role": "user", "content": "Hello"}],
+ "completion": [{"role": "assistant", "content": "reply-1"}],
+ "tokens": {
+ "prompt_ids": [1, 2],
+ "prompt_mask": [0, 0],
+ "completion_ids": [3],
+ "completion_mask": [1],
+ "completion_logprobs": [-0.1],
+ "overlong_prompt": False,
+ "is_truncated": False,
+ },
+ "reward": None,
+ "advantage": None,
+ "is_truncated": False,
+ "trajectory_id": "traj-1",
+ "extras": {},
+ },
+ {
+ "prompt": [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "reply-1"},
+ {"role": "user", "content": "Turn 2"},
+ ],
+ "completion": [{"role": "assistant", "content": "reply-2"}],
+ "tokens": {
+ "prompt_ids": [1, 2, 3, 4],
+ "prompt_mask": [0, 0, 0, 0],
+ "completion_ids": [5],
+ "completion_mask": [1],
+ "completion_logprobs": [-0.2],
+ "overlong_prompt": False,
+ "is_truncated": False,
+ },
+ "reward": None,
+ "advantage": None,
+ "is_truncated": False,
+ "trajectory_id": "traj-1",
+ "extras": {},
+ },
+ {
+ "prompt": [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "reply-1"},
+ {"role": "user", "content": "Turn 2"},
+ {"role": "assistant", "content": "reply-2"},
+ {"role": "user", "content": "Turn 3"},
+ ],
+ "completion": [{"role": "assistant", "content": "reply-3"}],
+ "tokens": {
+ "prompt_ids": [1, 2, 3, 4, 5, 6],
+ "prompt_mask": [0, 0, 0, 0, 0, 0],
+ "completion_ids": [7],
+ "completion_mask": [1],
+ "completion_logprobs": [-0.3],
+ "overlong_prompt": False,
+ "is_truncated": False,
+ },
+ "reward": None,
+ "advantage": None,
+ "is_truncated": False,
+ "trajectory_id": "traj-1",
+ "extras": {},
+ },
+ ]
+
+ def _handler(request: httpx.Request) -> httpx.Response:
+ tracker["hosts"].add(request.url.host)
+ tracker["paths"].append(request.url.path)
+ path = request.url.path
+
+ if request.method == "POST" and path.endswith("/register"):
+ payload = json.loads(request.content.decode("utf-8"))
+ tracker["register_payload"] = payload
+ tracker["rollout_id"] = path.split("/")[-2]
+ return httpx.Response(status_code=200, json={"status": "active"})
+
+ if request.method == "POST" and path.endswith("/unregister"):
+ tracker["unregister_calls"] += 1
+ return httpx.Response(status_code=200, json={"status": "active"})
+
+ if request.method == "GET" and path.endswith("/trajectory"):
+ tracker["trajectory_calls"] += 1
+ return httpx.Response(
+ status_code=200,
+ json={
+ "rollout_id": tracker["rollout_id"],
+ "status": "completed",
+ "num_turns": 3,
+ "model": "Qwen/Qwen3-0.6B",
+ "prompt": trajectory[0]["prompt"],
+ "completion": [
+ {"role": "assistant", "content": "reply-1"},
+ {"role": "user", "content": "Turn 2"},
+ {"role": "assistant", "content": "reply-2"},
+ {"role": "user", "content": "Turn 3"},
+ {"role": "assistant", "content": "reply-3"},
+ ],
+ "is_truncated": False,
+ "trajectory": trajectory,
+ },
+ )
+
+ return httpx.Response(status_code=404, json={"error": f"Unhandled path {path}"})
+
+ return httpx.MockTransport(_handler)
+
+
+@pytest.mark.asyncio
+async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch):
+ FakeTunnel.instances.clear()
+ monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel)
+
+ tracker = {
+ "paths": [],
+ "hosts": set(),
+ "register_payload": None,
+ "rollout_id": None,
+ "trajectory_calls": 0,
+ "unregister_calls": 0,
+ }
+ transport = _build_gateway_transport(tracker)
+ real_async_client = httpx.AsyncClient
+
+ def _client_factory(*args, **kwargs):
+ kwargs["transport"] = transport
+ return real_async_client(*args, **kwargs)
+
+ monkeypatch.setattr(cli_agent_env.httpx, "AsyncClient", _client_factory)
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [[{"role": "user", "content": "Hello"}]],
+ "answer": [""],
+ "example_id": [0],
+ }
+ )
+ env = GatewayCliAgentEnv(
+ run_command="echo run-agent",
+ dataset=dataset,
+ rubric=vf.Rubric(),
+ gateway_port=8000,
+ max_turns=10,
+ timeout_seconds=30.0,
+ )
+
+ env.sandbox_client.create = AsyncMock(return_value=SimpleNamespace(id="sb-123"))
+ env.sandbox_client.wait_for_creation = AsyncMock(return_value=None)
+ env.sandbox_client.start_background_job = AsyncMock(
+ return_value=SimpleNamespace(id="job-1")
+ )
+ env.sandbox_client.get_background_job = AsyncMock(
+ return_value=SimpleNamespace(
+ completed=True,
+ exit_code=0,
+ stdout="agent ok",
+ stderr="",
+ )
+ )
+ env.sandbox_client.delete = AsyncMock(return_value=None)
+
+ rollout_input = {
+ "prompt": [{"role": "user", "content": "Hello"}],
+ "answer": "",
+ "example_id": 0,
+ "task": "gateway-test",
+ }
+ client = cast(Any, SimpleNamespace(base_url="http://gateway.internal:8000/v1/"))
+ state = await env.rollout(
+ input=rollout_input,
+ client=client,
+ model="Qwen/Qwen3-0.6B",
+ sampling_args={"temperature": 0.7, "max_completion_tokens": 64},
+ )
+
+ assert state.get("error") is None
+ assert state["gateway_url"] == "http://gateway.internal:8000"
+ assert state["tunnel_url"] == "https://unit-test.tunnel.prime.ai"
+ assert state["rollout_base_url"].startswith(
+ "https://unit-test.tunnel.prime.ai/v1/rollouts/"
+ )
+ assert len(state["trajectory"]) == 3
+ assert state["prompt"] == [{"role": "user", "content": "Hello"}]
+ assert state["completion"][-1]["content"] == "reply-3"
+ assert state["reward"] == 1.0
+
+ create_request = env.sandbox_client.create.await_args.args[0]
+ assert (
+ create_request.environment_vars["OPENAI_BASE_URL"]
+ == f"https://unit-test.tunnel.prime.ai/v1/rollouts/{state['rollout_id']}"
+ )
+ assert create_request.environment_vars["OPENAI_MODEL"] == "Qwen/Qwen3-0.6B"
+
+ assert tracker["register_payload"]["max_turns"] == 10
+ assert tracker["register_payload"]["sampling_params"]["temperature"] == 0.7
+ assert tracker["register_payload"]["sampling_params"]["max_completion_tokens"] == 64
+ assert tracker["trajectory_calls"] == 1
+ assert tracker["unregister_calls"] == 1
+ assert tracker["hosts"] == {"gateway.internal"}
+
+ assert len(FakeTunnel.instances) == 1
+ tunnel = FakeTunnel.instances[0]
+ assert tunnel.local_port == 8000
+ assert tunnel.start_calls == 1
+ assert await env.get_tunnel_url() == "https://unit-test.tunnel.prime.ai"
+ assert tunnel.start_calls == 1
+
+ await env.teardown_resources()
+ assert tunnel.stop_calls == 1
From 1e42cbd0267ef70044fa8458e91b8fba2654db72 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 11 Feb 2026 01:02:22 +0530
Subject: [PATCH 02/21] fix tunnel address
---
tests/test_cli_agent_env_gateway.py | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_cli_agent_env_gateway.py
index cbb05a928..3616a5e0a 100644
--- a/tests/test_cli_agent_env_gateway.py
+++ b/tests/test_cli_agent_env_gateway.py
@@ -18,8 +18,14 @@
class FakeTunnel:
instances: list["FakeTunnel"] = []
- def __init__(self, local_port: int, log_level: str | None = None):
+ def __init__(
+ self,
+ local_port: int,
+ local_addr: str = "127.0.0.1",
+ log_level: str | None = None,
+ ):
self.local_port = local_port
+ self.local_addr = local_addr
self.log_level = log_level
self.url: str | None = None
self.start_calls = 0
@@ -246,6 +252,7 @@ def _client_factory(*args, **kwargs):
assert len(FakeTunnel.instances) == 1
tunnel = FakeTunnel.instances[0]
assert tunnel.local_port == 8000
+ assert tunnel.local_addr == "gateway.internal"
assert tunnel.start_calls == 1
assert await env.get_tunnel_url() == "https://unit-test.tunnel.prime.ai"
assert tunnel.start_calls == 1
From fcc2a839cca60bcf2c0a4e2c6ba5c4ca49578f7a Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 11 Feb 2026 22:07:30 +0530
Subject: [PATCH 03/21] add debug logging of per-turn trajectory token counts
---
verifiers/envs/experimental/cli_agent_env.py | 24 ++++++++++++++++++++
1 file changed, 24 insertions(+)
diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py
index 5cfb9e42e..834a3b130 100644
--- a/verifiers/envs/experimental/cli_agent_env.py
+++ b/verifiers/envs/experimental/cli_agent_env.py
@@ -172,6 +172,30 @@ async def setup_state(self, state: State) -> State:
return state
+ if logger.isEnabledFor(logging.DEBUG):
+ rollout_id = state.get("rollout_id")
+ logger.debug(
+ "rollout=%s fetched trajectory steps=%d truncated=%s",
+ rollout_id,
+ len(trajectory),
+ state["is_truncated"],
+ )
+ for turn_idx, step in enumerate(trajectory):
+ tokens = step.get("tokens")
+ prompt_token_count = (
+ len(tokens["prompt_ids"]) if tokens is not None else 0
+ )
+ completion_token_count = (
+ len(tokens["completion_ids"]) if tokens is not None else 0
+ )
+ logger.debug(
+ "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d",
+ rollout_id,
+ turn_idx,
+ prompt_token_count,
+ completion_token_count,
+ )
+
async def get_docker_image(self, state: State) -> str:
"""Get the Docker image for the sandbox. Override for per-task images."""
return self.docker_image
From 17fb3bdf78265d917b44d8fced186a5c6a10f8da Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Thu, 12 Feb 2026 01:17:46 +0530
Subject: [PATCH 04/21] fix(env): initialize state["completion"] in
Environment.init_state to prevent rubric KeyError on early rollout failures
---
verifiers/envs/environment.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index 40ef74cbe..96126b506 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -622,6 +622,7 @@ async def init_state(
state["tool_defs"] = self._normalize_tool_defs(resolved_tool_defs) or []
state["trajectory"] = []
+ state["completion"] = None
self._get_usage_tracker(state, create_if_missing=True)
state["trajectory_id"] = uuid.uuid4().hex
state["reward"] = None
From 898196c5eaefee906044d61c6aec4f5f701c2265 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Thu, 19 Feb 2026 23:28:36 +0530
Subject: [PATCH 05/21] multiple tunnels
---
tests/test_cli_agent_env_gateway.py | 40 +++++++++++++++++++++++++++++
1 file changed, 40 insertions(+)
diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_cli_agent_env_gateway.py
index 3616a5e0a..18a17cfc2 100644
--- a/tests/test_cli_agent_env_gateway.py
+++ b/tests/test_cli_agent_env_gateway.py
@@ -259,3 +259,43 @@ def _client_factory(*args, **kwargs):
await env.teardown_resources()
assert tunnel.stop_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
+ FakeTunnel.instances.clear()
+ monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel)
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [[{"role": "user", "content": "Hello"}]],
+ "answer": [""],
+ "example_id": [0],
+ }
+ )
+ env = GatewayCliAgentEnv(
+ run_command="echo run-agent",
+ dataset=dataset,
+ rubric=vf.Rubric(),
+ gateway_port=8000,
+ )
+
+ url_a = await env.get_tunnel_url(local_addr="10.20.0.58")
+ url_b = await env.get_tunnel_url(local_addr="10.20.0.59")
+ url_a_reuse = await env.get_tunnel_url(local_addr="10.20.0.58")
+
+ assert url_a == "https://unit-test.tunnel.prime.ai"
+ assert url_b == "https://unit-test.tunnel.prime.ai"
+ assert url_a_reuse == url_a
+
+ assert len(FakeTunnel.instances) == 2
+ assert {t.local_addr for t in FakeTunnel.instances} == {"10.20.0.58", "10.20.0.59"}
+ assert sum(t.start_calls for t in FakeTunnel.instances) == 2
+
+ with pytest.raises(
+ ValueError, match="local_addr is required when multiple tunnels are active"
+ ):
+ await env.get_tunnel_url()
+
+ await env.teardown_resources()
+ assert sum(t.stop_calls for t in FakeTunnel.instances) == 2
From 9e2e7354ca403b5c321e7f0e2a214d75fced828d Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Sun, 22 Feb 2026 02:37:24 +0530
Subject: [PATCH 06/21] rename: `RolloutGatewayEnv`
---
verifiers/__init__.py | 3 +
.../envs/experimental/rollout_gateway_env.py | 567 ++++++++++++++++++
2 files changed, 570 insertions(+)
create mode 100644 verifiers/envs/experimental/rollout_gateway_env.py
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index 2cd3480d8..7c91a5600 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -67,6 +67,7 @@
"ReasoningGymEnv",
"GymEnv",
"CliAgentEnv",
+ "RolloutGatewayEnv",
"HarborEnv",
"MCPEnv",
"BrowserEnv",
@@ -121,6 +122,7 @@
"PythonEnv": "verifiers.envs.python_env:PythonEnv",
"GymEnv": "verifiers.envs.experimental.gym_env:GymEnv",
"CliAgentEnv": "verifiers.envs.experimental.cli_agent_env:CliAgentEnv",
+ "RolloutGatewayEnv": "verifiers.envs.experimental.rollout_gateway_env:RolloutGatewayEnv",
"HarborEnv": "verifiers.envs.experimental.harbor_env:HarborEnv",
"MCPEnv": "verifiers.envs.experimental.mcp_env:MCPEnv",
"ReasoningGymEnv": "verifiers.envs.integrations.reasoninggym_env:ReasoningGymEnv",
@@ -160,6 +162,7 @@ def __getattr__(name: str):
from typing import Any
from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401
+ from .envs.experimental.rollout_gateway_env import RolloutGatewayEnv # noqa: F401
from .envs.experimental.gym_env import GymEnv # noqa: F401
from .envs.experimental.harbor_env import HarborEnv # noqa: F401
from .envs.experimental.mcp_env import MCPEnv # noqa: F401
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
new file mode 100644
index 000000000..b0fc4a08e
--- /dev/null
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -0,0 +1,567 @@
+import asyncio
+import logging
+import time
+import uuid
+from typing import Any, cast
+from urllib.parse import urlparse
+
+import httpx
+from openai import AsyncOpenAI
+from prime_sandboxes import (
+ AdvancedConfigs,
+ BackgroundJob,
+ BackgroundJobStatus,
+ CreateSandboxRequest,
+)
+from prime_tunnel import Tunnel
+
+import verifiers as vf
+from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
+from verifiers.envs.multiturn_env import MultiTurnMonitorRubric
+from verifiers.types import RolloutInput, SamplingArgs, State, TrajectoryStep
+
+logger = logging.getLogger(__name__)
+
+
+class RolloutGatewayEnv(SandboxMixin, vf.Environment):
+ """
+ Environment for running full agent code inside sandboxes.
+
+ The sandboxed agent talks directly to the rollout gateway running in the vLLM
+ server through a prime tunnel URL. The environment only handles sandbox
+ lifecycle, rollout registration, trajectory fetch, and reward computation.
+ """
+
+ def __init__(
+ self,
+ run_command: str,
+ gateway_port: int = 8000,
+ max_turns: int = -1,
+ timeout_seconds: float = 3600.0,
+ poll_interval: float = 2.0,
+ docker_image: str = "python:3.11-slim",
+ start_command: str = "tail -f /dev/null",
+ cpu_cores: int = 1,
+ memory_gb: int = 2,
+ disk_size_gb: int = 5,
+ gpu_count: int = 0,
+ timeout_minutes: int = 60,
+ environment_vars: dict[str, str] | None = None,
+ team_id: str | None = None,
+ advanced_configs: AdvancedConfigs | None = None,
+ labels: list[str] | None = None,
+ max_retries: int = 5,
+ base_delay: float = 0.5,
+ backoff_factor: float = 2.0,
+ max_backoff_seconds: float = 30.0,
+ jitter: float = 1e-3,
+ sandbox_client_max_workers: int = 10,
+ sandbox_client_max_connections: int = 100,
+ sandbox_client_max_keepalive_connections: int = 50,
+ sandbox_wait_for_creation_max_attempts: int = 120,
+ **kwargs,
+ ):
+ super().__init__(message_type="chat", **kwargs)
+ self.add_rubric(MultiTurnMonitorRubric())
+
+ self.init_sandbox_client(
+ max_retries=max_retries,
+ base_delay=base_delay,
+ backoff_factor=backoff_factor,
+ max_backoff_seconds=max_backoff_seconds,
+ jitter=jitter,
+ sandbox_client_max_workers=sandbox_client_max_workers,
+ sandbox_client_max_connections=sandbox_client_max_connections,
+ sandbox_client_max_keepalive_connections=sandbox_client_max_keepalive_connections,
+ sandbox_wait_for_creation_max_attempts=sandbox_wait_for_creation_max_attempts,
+ )
+
+ self.run_command = run_command
+ self.gateway_port = gateway_port
+ self.max_turns = max_turns
+ self.poll_interval = poll_interval
+ self.timeout_seconds = timeout_seconds
+ self.docker_image = docker_image
+ self.start_command = start_command
+ self.cpu_cores = cpu_cores
+ self.memory_gb = memory_gb
+ self.disk_size_gb = disk_size_gb
+ self.gpu_count = gpu_count
+ self.timeout_minutes = timeout_minutes
+ self.environment_vars = environment_vars
+ self.team_id = team_id
+ self.advanced_configs = advanced_configs
+ self.labels = labels
+
+ self._tunnels: dict[str, Tunnel] = {}
+ self._tunnel_lock = asyncio.Lock()
+
+ def _resolve_tunnel_local_addr(self, state: State) -> str:
+ gateway_url = cast(str, state["gateway_url"])
+ parsed = urlparse(gateway_url)
+ host = parsed.hostname
+ if host is None:
+ raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}")
+ return host
+
+ async def get_tunnel_url(self, local_addr: str | None = None) -> str:
+ """Get tunnel URL, starting the tunnel if needed."""
+ async with self._tunnel_lock:
+ if local_addr is None:
+ if len(self._tunnels) == 1:
+ tunnel = next(iter(self._tunnels.values()))
+ assert tunnel.url is not None, "Tunnel started but URL is None"
+ return tunnel.url
+ if len(self._tunnels) == 0:
+ raise ValueError("local_addr is required when starting tunnel")
+ raise ValueError(
+ "local_addr is required when multiple tunnels are active"
+ )
+
+ tunnel = self._tunnels.get(local_addr)
+ if tunnel is None:
+ if logger.isEnabledFor(logging.DEBUG):
+ tunnel = Tunnel(
+ local_port=self.gateway_port,
+ local_addr=local_addr,
+ log_level="debug",
+ )
+ else:
+ tunnel = Tunnel(
+ local_port=self.gateway_port,
+ local_addr=local_addr,
+ )
+ url = await tunnel.start()
+ self._tunnels[local_addr] = tunnel
+ logger.debug(
+ "Prime Tunnel started local_addr=%s url=%s",
+ local_addr,
+ url,
+ )
+ return url
+
+ assert tunnel.url is not None, "Tunnel started but URL is None"
+ return tunnel.url
+
+ def _resolve_gateway_url(self, state: State) -> str:
+ client = cast(AsyncOpenAI, state["client"])
+ gateway_url = str(client.base_url).rstrip("/")
+ if gateway_url.endswith("/v1"):
+ gateway_url = gateway_url[:-3]
+ return gateway_url
+
+ @staticmethod
+ def _tail_text(value: Any, max_chars: int = 1200) -> str:
+ if value is None:
+ return ""
+ text = str(value)
+ if len(text) <= max_chars:
+ return text
+ return text[-max_chars:]
+
+ def _rollout_endpoint(self, state: State, suffix: str) -> str:
+ gateway_url = cast(str, state["gateway_url"])
+ rollout_id = cast(str, state["rollout_id"])
+ return f"{gateway_url}/v1/rollouts/{rollout_id}/{suffix.lstrip('/')}"
+
+ async def _gateway_post(
+ self,
+ state: State,
+ suffix: str,
+ payload: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ timeout = httpx.Timeout(self.timeout_seconds)
+ async with httpx.AsyncClient(timeout=timeout) as client:
+ response = await client.post(
+ self._rollout_endpoint(state, suffix),
+ json=payload,
+ )
+ response.raise_for_status()
+ if not response.content:
+ return {}
+ return cast(dict[str, Any], response.json())
+
+ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
+ timeout = httpx.Timeout(self.timeout_seconds)
+ async with httpx.AsyncClient(timeout=timeout) as client:
+ response = await client.get(self._rollout_endpoint(state, suffix))
+ response.raise_for_status()
+ return cast(dict[str, Any], response.json())
+
+ async def register_rollout(self, state: State) -> None:
+ sampling_params = dict(state.get("sampling_args") or {})
+ payload = {
+ "model": state["model"],
+ "sampling_params": sampling_params,
+ "max_turns": self.max_turns,
+ "max_seq_len": self.max_seq_len,
+ }
+ await self._gateway_post(state, "register", payload)
+
+ async def unregister_rollout(self, state: State) -> None:
+ await self._gateway_post(state, "unregister")
+
+ async def fetch_trajectory(self, state: State) -> None:
+ data = await self._gateway_get(state, "trajectory")
+ raw_trajectory = cast(list[dict[str, Any]], data.get("trajectory", []))
+
+ trajectory: list[TrajectoryStep] = []
+ for raw_step in raw_trajectory:
+ step = dict(raw_step)
+ step.setdefault("response", None)
+ step.setdefault("reward", None)
+ step.setdefault("advantage", None)
+ step.setdefault("is_truncated", False)
+ step.setdefault("trajectory_id", state.get("trajectory_id", ""))
+ step.setdefault("extras", {})
+ trajectory.append(cast(TrajectoryStep, step))
+
+ state["trajectory"] = trajectory
+ state["prompt"] = data.get("prompt")
+ state["completion"] = data.get("completion")
+ state["is_truncated"] = bool(
+ data.get("is_truncated", state.get("is_truncated", False))
+ )
+
+ if logger.isEnabledFor(logging.DEBUG):
+ rollout_id = state.get("rollout_id")
+ logger.debug(
+ "rollout=%s fetched trajectory steps=%d truncated=%s",
+ rollout_id,
+ len(trajectory),
+ state["is_truncated"],
+ )
+ for turn_idx, step in enumerate(trajectory):
+ tokens = step.get("tokens")
+ prompt_token_count = (
+ len(tokens["prompt_ids"]) if tokens is not None else 0
+ )
+ completion_token_count = (
+ len(tokens["completion_ids"]) if tokens is not None else 0
+ )
+ logger.debug(
+ "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d",
+ rollout_id,
+ turn_idx,
+ prompt_token_count,
+ completion_token_count,
+ )
+
+ async def get_docker_image(self, state: State) -> str:
+ """Get the Docker image for the sandbox. Override for per-task images."""
+ return self.docker_image
+
+ async def build_env_vars(self, state: State) -> dict[str, str]:
+ """Build environment variables for the sandbox. Override to add custom vars."""
+ env_vars = dict(self.environment_vars) if self.environment_vars else {}
+ env_vars["OPENAI_BASE_URL"] = cast(str, state["rollout_base_url"])
+ env_vars.setdefault("OPENAI_TIMEOUT", "600")
+ env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600")
+ env_vars.setdefault("HTTPX_TIMEOUT", "600")
+ model = state.get("model")
+ if model:
+ env_vars["OPENAI_MODEL"] = model
+ return env_vars
+
+ async def post_sandbox_setup(self, state: State) -> None:
+ """Hook for post-sandbox setup. Override to upload files, run commands, etc."""
+ pass
+
+ async def start_agent(self, state: State) -> None:
+ """Start the agent command using background job."""
+ sandbox_id = cast(str, state["sandbox_id"])
+ background_job: BackgroundJob = await self.sandbox_client.start_background_job(
+ sandbox_id,
+ self.run_command,
+ )
+ state["background_job"] = background_job
+ state["agent_start_time"] = time.time()
+ state["agent_completed"] = False
+
+ async def wait_for_agent_completion(self, state: State) -> None:
+ """Poll for agent completion using background job API."""
+ sandbox_id = state.get("sandbox_id")
+ background_job = state.get("background_job")
+ if not sandbox_id or not background_job:
+ state["agent_completed"] = True
+ return
+
+ try:
+ await asyncio.wait_for(
+ self.poll_job_completion(
+ state,
+ cast(str, sandbox_id),
+ cast(BackgroundJob, background_job),
+ ),
+ timeout=self.timeout_seconds,
+ )
+ except asyncio.TimeoutError:
+ logger.warning(
+ "rollout=%s sandbox=%s stage=wait_for_agent_completion timed out after %.1fs",
+ state.get("rollout_id"),
+ state.get("sandbox_id"),
+ self.timeout_seconds,
+ )
+ state["agent_timed_out"] = True
+ finally:
+ state["agent_completed"] = True
+
+ async def poll_job_completion(
+ self,
+ state: State,
+ sandbox_id: str,
+ background_job: BackgroundJob,
+ ) -> None:
+ """Poll until background job completes, capturing output."""
+ while True:
+ status: BackgroundJobStatus = await self.sandbox_client.get_background_job(
+ sandbox_id,
+ background_job,
+ )
+ if status.completed:
+ state["agent_exit_code"] = status.exit_code
+ state["agent_stdout"] = status.stdout
+ state["agent_stderr"] = status.stderr
+ if status.exit_code not in (None, 0):
+ logger.warning(
+ "rollout=%s sandbox=%s stage=agent_completed exit_code=%s stdout_tail=%r stderr_tail=%r",
+ state.get("rollout_id"),
+ sandbox_id,
+ status.exit_code,
+ self._tail_text(status.stdout),
+ self._tail_text(status.stderr),
+ )
+ else:
+ logger.debug(
+ "rollout=%s sandbox=%s stage=agent_completed exit_code=%s",
+ state.get("rollout_id"),
+ sandbox_id,
+ status.exit_code,
+ )
+ return
+ await asyncio.sleep(1)
+
+ def _render_timing(self, state: State) -> None:
+ start_time = cast(float, state["timing"]["start_time"])
+ end_time = time.time()
+ generation_ms = (end_time - start_time) * 1000
+ state["timing"]["generation_ms"] = generation_ms
+ state["timing"]["total_ms"] = generation_ms
+
+ async def rollout(
+ self,
+ input: RolloutInput,
+ client: AsyncOpenAI,
+ model: str,
+ sampling_args: SamplingArgs | None = None,
+ ) -> State:
+ state = await self.init_state(input, client, model, sampling_args)
+ state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}"
+ state["gateway_url"] = self._resolve_gateway_url(state)
+ rollout_id = cast(str, state["rollout_id"])
+ info = cast(dict[str, Any], state.get("info") or {})
+ logger.info(
+ "rollout=%s stage=start model=%s example_id=%s repo=%s",
+ rollout_id,
+ model,
+ info.get("instance_id") or info.get("example_id"),
+ info.get("repo_name"),
+ )
+
+ rollout_registered = False
+ failure_stage = "register_rollout"
+ error_stage: str | None = None
+ try:
+ failure_stage = "register_rollout"
+ await self.register_rollout(state)
+ rollout_registered = True
+ logger.debug("rollout=%s stage=register_rollout ok", rollout_id)
+
+ failure_stage = "resolve_tunnel_local_addr"
+ tunnel_local_addr = self._resolve_tunnel_local_addr(state)
+ state["tunnel_local_addr"] = tunnel_local_addr
+ logger.debug(
+ "rollout=%s stage=resolve_tunnel_local_addr addr=%s",
+ rollout_id,
+ tunnel_local_addr,
+ )
+
+ failure_stage = "start_tunnel"
+ tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr)
+ state["tunnel_url"] = tunnel_url
+ state["rollout_base_url"] = (
+ f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}"
+ )
+ logger.debug("rollout=%s stage=start_tunnel url=%s", rollout_id, tunnel_url)
+
+ failure_stage = "build_env_vars"
+ env_vars = await self.build_env_vars(state)
+ failure_stage = "get_docker_image"
+ docker_image = await self.get_docker_image(state)
+ sandbox_request = CreateSandboxRequest(
+ name=cast(str, state["rollout_id"]),
+ docker_image=docker_image,
+ start_command=self.start_command,
+ cpu_cores=self.cpu_cores,
+ memory_gb=self.memory_gb,
+ disk_size_gb=self.disk_size_gb,
+ gpu_count=self.gpu_count,
+ timeout_minutes=self.timeout_minutes,
+ environment_vars=env_vars,
+ team_id=self.team_id,
+ advanced_configs=self.advanced_configs,
+ labels=self.labels if self.labels else [],
+ )
+ logger.debug(
+ f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} "
+ f"docker_image={docker_image}"
+ )
+ failure_stage = "create_sandbox"
+ await self.create_sandbox(state, sandbox_request)
+ logger.info(
+ "rollout=%s stage=create_sandbox ok sandbox_id=%s docker_image=%s",
+ rollout_id,
+ state.get("sandbox_id"),
+ docker_image,
+ )
+
+ failure_stage = "start_agent"
+ await self.start_agent(state)
+ logger.debug(
+ "rollout=%s stage=start_agent ok sandbox_id=%s",
+ rollout_id,
+ state.get("sandbox_id"),
+ )
+ failure_stage = "wait_for_agent_completion"
+ await self.wait_for_agent_completion(state)
+ logger.debug(
+ "rollout=%s stage=wait_for_agent_completion ok exit_code=%s",
+ rollout_id,
+ state.get("agent_exit_code"),
+ )
+ failure_stage = "fetch_trajectory"
+ await self.fetch_trajectory(state)
+ trajectory = cast(list[Any], state.get("trajectory") or [])
+ logger.info(
+ "rollout=%s stage=fetch_trajectory ok turns=%d truncated=%s",
+ rollout_id,
+ len(trajectory),
+ state.get("is_truncated", False),
+ )
+ if len(trajectory) == 0:
+ logger.warning(
+ "rollout=%s stage=fetch_trajectory empty_trajectory agent_exit_code=%s stdout_tail=%r stderr_tail=%r",
+ rollout_id,
+ state.get("agent_exit_code"),
+ self._tail_text(state.get("agent_stdout")),
+ self._tail_text(state.get("agent_stderr")),
+ )
+ except vf.Error as e:
+ error_stage = failure_stage
+ state["error"] = e
+ logger.exception(
+ "rollout=%s stage=%s vf_error=%s message=%s",
+ rollout_id,
+ failure_stage,
+ type(e).__name__,
+ e,
+ )
+ except Exception as e:
+ error_stage = failure_stage
+ state["error"] = vf.InfraError(str(e))
+ logger.exception(
+ "rollout=%s stage=%s unhandled_error=%s message=%s",
+ rollout_id,
+ failure_stage,
+ type(e).__name__,
+ e,
+ )
+ finally:
+ if rollout_registered:
+ try:
+ failure_stage = "unregister_rollout"
+ await self.unregister_rollout(state)
+ except Exception as e:
+ logger.warning(
+ f"Failed to unregister rollout {state['rollout_id']}: {e}"
+ )
+ if error_stage is None:
+ error_stage = failure_stage
+ if state.get("error") is None:
+ state["error"] = vf.InfraError(str(e))
+
+ if state.get("sandbox_id"):
+ try:
+ failure_stage = "destroy_sandbox"
+ await self.destroy_sandbox(state)
+ except Exception as e:
+ logger.warning(
+ f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}"
+ )
+ if error_stage is None:
+ error_stage = failure_stage
+ if state.get("error") is None:
+ state["error"] = vf.InfraError(str(e))
+
+ if state.get("completion") is None:
+ state["completion"] = []
+ state["failure_stage"] = error_stage
+ if state.get("error") is not None:
+ if state.get("stop_condition") is None:
+ state["stop_condition"] = (
+ f"{error_stage}_error" if error_stage else "has_error"
+ )
+ elif state.get("agent_timed_out", False):
+ if state.get("stop_condition") is None:
+ state["stop_condition"] = "agent_timeout"
+ else:
+ if state.get("stop_condition") is None:
+ state["stop_condition"] = "completed"
+ state["is_completed"] = True
+ self._render_timing(state)
+ logger.info(
+ "rollout=%s stage=finish stop=%s failure_stage=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s",
+ rollout_id,
+ state.get("stop_condition"),
+ error_stage,
+ state.get("sandbox_id"),
+ len(cast(list[Any], state.get("trajectory") or [])),
+ state.get("agent_exit_code"),
+ type(state["error"]).__name__
+ if state.get("error") is not None
+ else None,
+ )
+
+ return state
+
+ @vf.teardown
+ async def teardown_resources(self):
+ """Stop Prime Tunnel."""
+ async with self._tunnel_lock:
+ tunnels = list(self._tunnels.items())
+ self._tunnels = {}
+ for local_addr, tunnel in tunnels:
+ try:
+ tunnel.sync_stop()
+ logger.debug("Prime Tunnel stopped local_addr=%s", local_addr)
+ except Exception as e:
+ logger.warning(
+ "Error stopping Prime Tunnel local_addr=%s: %s",
+ local_addr,
+ e,
+ )
+
+ async def post_rollout(self, state: State):
+ """
+ Override for custom post-rollout logic. For example, if sandbox state is needed for reward functions,
+ run computation here and cache the result in state before sandbox is destroyed.
+ """
+ pass
+
+ @vf.cleanup
+ async def destroy_sandbox(self, state: State):
+ """Cleanup sandbox after rollout."""
+ await self.post_rollout(state)
+ sandbox_id = state.get("sandbox_id")
+ if sandbox_id:
+ await self.delete_sandbox(cast(str, sandbox_id))
From 6dc007835ba5cae1e01dcb6487bcc6d2738b45bf Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Sun, 22 Feb 2026 03:02:10 +0530
Subject: [PATCH 07/21] revert `CliAgentEnv`
---
verifiers/envs/experimental/cli_agent_env.py | 24 --------------------
1 file changed, 24 deletions(-)
diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py
index 834a3b130..5cfb9e42e 100644
--- a/verifiers/envs/experimental/cli_agent_env.py
+++ b/verifiers/envs/experimental/cli_agent_env.py
@@ -172,30 +172,6 @@ async def setup_state(self, state: State) -> State:
return state
- if logger.isEnabledFor(logging.DEBUG):
- rollout_id = state.get("rollout_id")
- logger.debug(
- "rollout=%s fetched trajectory steps=%d truncated=%s",
- rollout_id,
- len(trajectory),
- state["is_truncated"],
- )
- for turn_idx, step in enumerate(trajectory):
- tokens = step.get("tokens")
- prompt_token_count = (
- len(tokens["prompt_ids"]) if tokens is not None else 0
- )
- completion_token_count = (
- len(tokens["completion_ids"]) if tokens is not None else 0
- )
- logger.debug(
- "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d",
- rollout_id,
- turn_idx,
- prompt_token_count,
- completion_token_count,
- )
-
async def get_docker_image(self, state: State) -> str:
"""Get the Docker image for the sandbox. Override for per-task images."""
return self.docker_image
From 5d446d0bd48ff3a6a5493c63a63c070551afe806 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Sun, 22 Feb 2026 03:49:34 +0530
Subject: [PATCH 08/21] rename test
---
...est_cli_agent_env_gateway.py => test_rollout_gateway_env.py} | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
rename tests/{test_cli_agent_env_gateway.py => test_rollout_gateway_env.py} (99%)
diff --git a/tests/test_cli_agent_env_gateway.py b/tests/test_rollout_gateway_env.py
similarity index 99%
rename from tests/test_cli_agent_env_gateway.py
rename to tests/test_rollout_gateway_env.py
index 18a17cfc2..01545f82d 100644
--- a/tests/test_cli_agent_env_gateway.py
+++ b/tests/test_rollout_gateway_env.py
@@ -41,7 +41,7 @@ def sync_stop(self) -> None:
self.stop_calls += 1
-class GatewayCliAgentEnv(vf.CliAgentEnv):
+class GatewayCliAgentEnv(vf.RolloutGatewayEnv):
async def post_rollout(self, state: vf.State):
state["reward"] = 1.0
state["test_output"] = "ok"
From afa451186069feb4e7534e1cad0048f1d81ce869 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Sun, 22 Feb 2026 04:08:55 +0530
Subject: [PATCH 09/21] fix tests
---
tests/test_rollout_gateway_env.py | 16 ++++++++++------
.../envs/experimental/rollout_gateway_env.py | 3 ++-
2 files changed, 12 insertions(+), 7 deletions(-)
diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py
index 01545f82d..60a7d72b7 100644
--- a/tests/test_rollout_gateway_env.py
+++ b/tests/test_rollout_gateway_env.py
@@ -2,7 +2,6 @@
import json
from types import SimpleNamespace
-from typing import Any, cast
from unittest.mock import AsyncMock
import httpx
@@ -10,7 +9,7 @@
from datasets import Dataset
import verifiers as vf
-import verifiers.envs.experimental.cli_agent_env as cli_agent_env
+import verifiers.envs.experimental.rollout_gateway_env as rollout_gateway_env
pytestmark = [pytest.mark.integration, pytest.mark.environments]
@@ -160,7 +159,7 @@ def _handler(request: httpx.Request) -> httpx.Response:
@pytest.mark.asyncio
async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch):
FakeTunnel.instances.clear()
- monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel)
+ monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel)
tracker = {
"paths": [],
@@ -172,12 +171,18 @@ async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch):
}
transport = _build_gateway_transport(tracker)
real_async_client = httpx.AsyncClient
+ client = vf.OpenAIChatCompletionsClient(
+ vf.ClientConfig(
+ api_key_var="UNIT_TEST_API_KEY",
+ api_base_url="http://gateway.internal:8000/v1/",
+ )
+ )
def _client_factory(*args, **kwargs):
kwargs["transport"] = transport
return real_async_client(*args, **kwargs)
- monkeypatch.setattr(cli_agent_env.httpx, "AsyncClient", _client_factory)
+ monkeypatch.setattr(rollout_gateway_env.httpx, "AsyncClient", _client_factory)
dataset = Dataset.from_dict(
{
@@ -216,7 +221,6 @@ def _client_factory(*args, **kwargs):
"example_id": 0,
"task": "gateway-test",
}
- client = cast(Any, SimpleNamespace(base_url="http://gateway.internal:8000/v1/"))
state = await env.rollout(
input=rollout_input,
client=client,
@@ -264,7 +268,7 @@ def _client_factory(*args, **kwargs):
@pytest.mark.asyncio
async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
FakeTunnel.instances.clear()
- monkeypatch.setattr(cli_agent_env, "Tunnel", FakeTunnel)
+ monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel)
dataset = Dataset.from_dict(
{
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
index b0fc4a08e..3d25ae94f 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -144,7 +144,8 @@ async def get_tunnel_url(self, local_addr: str | None = None) -> str:
return tunnel.url
def _resolve_gateway_url(self, state: State) -> str:
- client = cast(AsyncOpenAI, state["client"])
+ # `state["client"]` may be a Verifiers wrapper with the raw client on `.client`.
+ client = getattr(state["client"], "client", state["client"])
gateway_url = str(client.base_url).rstrip("/")
if gateway_url.endswith("/v1"):
gateway_url = gateway_url[:-3]
From 715f6fcfb689601ca6a100b5f1be0df0b24a3278 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Sun, 22 Feb 2026 04:34:59 +0530
Subject: [PATCH 10/21] simplify
---
.../envs/experimental/rollout_gateway_env.py | 97 ++++---------------
1 file changed, 19 insertions(+), 78 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
index 3d25ae94f..c6a12970b 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -120,17 +120,11 @@ async def get_tunnel_url(self, local_addr: str | None = None) -> str:
tunnel = self._tunnels.get(local_addr)
if tunnel is None:
- if logger.isEnabledFor(logging.DEBUG):
- tunnel = Tunnel(
- local_port=self.gateway_port,
- local_addr=local_addr,
- log_level="debug",
- )
- else:
- tunnel = Tunnel(
- local_port=self.gateway_port,
- local_addr=local_addr,
- )
+ tunnel = Tunnel(
+ local_port=self.gateway_port,
+ local_addr=local_addr,
+ log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info",
+ )
url = await tunnel.start()
self._tunnels[local_addr] = tunnel
logger.debug(
@@ -161,9 +155,7 @@ def _tail_text(value: Any, max_chars: int = 1200) -> str:
return text[-max_chars:]
def _rollout_endpoint(self, state: State, suffix: str) -> str:
- gateway_url = cast(str, state["gateway_url"])
- rollout_id = cast(str, state["rollout_id"])
- return f"{gateway_url}/v1/rollouts/{rollout_id}/{suffix.lstrip('/')}"
+ return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}"
async def _gateway_post(
self,
@@ -190,7 +182,7 @@ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
return cast(dict[str, Any], response.json())
async def register_rollout(self, state: State) -> None:
- sampling_params = dict(state.get("sampling_args") or {})
+ sampling_params = state.get("sampling_args") or {}
payload = {
"model": state["model"],
"sampling_params": sampling_params,
@@ -204,16 +196,16 @@ async def unregister_rollout(self, state: State) -> None:
async def fetch_trajectory(self, state: State) -> None:
data = await self._gateway_get(state, "trajectory")
- raw_trajectory = cast(list[dict[str, Any]], data.get("trajectory", []))
+ raw_trajectory = data.get("trajectory", [])
trajectory: list[TrajectoryStep] = []
- for raw_step in raw_trajectory:
- step = dict(raw_step)
+ # TODO: Pydantic response schema in gateway
+ for step in raw_trajectory:
step.setdefault("response", None)
step.setdefault("reward", None)
step.setdefault("advantage", None)
step.setdefault("is_truncated", False)
- step.setdefault("trajectory_id", state.get("trajectory_id", ""))
+ step.setdefault("trajectory_id", state["trajectory_id"])
step.setdefault("extras", {})
trajectory.append(cast(TrajectoryStep, step))
@@ -224,30 +216,6 @@ async def fetch_trajectory(self, state: State) -> None:
data.get("is_truncated", state.get("is_truncated", False))
)
- if logger.isEnabledFor(logging.DEBUG):
- rollout_id = state.get("rollout_id")
- logger.debug(
- "rollout=%s fetched trajectory steps=%d truncated=%s",
- rollout_id,
- len(trajectory),
- state["is_truncated"],
- )
- for turn_idx, step in enumerate(trajectory):
- tokens = step.get("tokens")
- prompt_token_count = (
- len(tokens["prompt_ids"]) if tokens is not None else 0
- )
- completion_token_count = (
- len(tokens["completion_ids"]) if tokens is not None else 0
- )
- logger.debug(
- "rollout=%s turn=%d prompt_tokens=%d completion_tokens=%d",
- rollout_id,
- turn_idx,
- prompt_token_count,
- completion_token_count,
- )
-
async def get_docker_image(self, state: State) -> str:
"""Get the Docker image for the sandbox. Override for per-task images."""
return self.docker_image
@@ -343,8 +311,8 @@ async def poll_job_completion(
await asyncio.sleep(1)
def _render_timing(self, state: State) -> None:
- start_time = cast(float, state["timing"]["start_time"])
- end_time = time.time()
+ start_time = state["timing"]["start_time"]
+ end_time = time.perf_counter()
generation_ms = (end_time - start_time) * 1000
state["timing"]["generation_ms"] = generation_ms
state["timing"]["total_ms"] = generation_ms
@@ -359,8 +327,8 @@ async def rollout(
state = await self.init_state(input, client, model, sampling_args)
state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}"
state["gateway_url"] = self._resolve_gateway_url(state)
- rollout_id = cast(str, state["rollout_id"])
- info = cast(dict[str, Any], state.get("info") or {})
+ rollout_id = state["rollout_id"]
+ info = state.get("info") or {}
logger.info(
"rollout=%s stage=start model=%s example_id=%s repo=%s",
rollout_id,
@@ -370,15 +338,11 @@ async def rollout(
)
rollout_registered = False
- failure_stage = "register_rollout"
- error_stage: str | None = None
try:
- failure_stage = "register_rollout"
await self.register_rollout(state)
rollout_registered = True
logger.debug("rollout=%s stage=register_rollout ok", rollout_id)
- failure_stage = "resolve_tunnel_local_addr"
tunnel_local_addr = self._resolve_tunnel_local_addr(state)
state["tunnel_local_addr"] = tunnel_local_addr
logger.debug(
@@ -387,7 +351,6 @@ async def rollout(
tunnel_local_addr,
)
- failure_stage = "start_tunnel"
tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr)
state["tunnel_url"] = tunnel_url
state["rollout_base_url"] = (
@@ -395,9 +358,7 @@ async def rollout(
)
logger.debug("rollout=%s stage=start_tunnel url=%s", rollout_id, tunnel_url)
- failure_stage = "build_env_vars"
env_vars = await self.build_env_vars(state)
- failure_stage = "get_docker_image"
docker_image = await self.get_docker_image(state)
sandbox_request = CreateSandboxRequest(
name=cast(str, state["rollout_id"]),
@@ -417,7 +378,6 @@ async def rollout(
f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} "
f"docker_image={docker_image}"
)
- failure_stage = "create_sandbox"
await self.create_sandbox(state, sandbox_request)
logger.info(
"rollout=%s stage=create_sandbox ok sandbox_id=%s docker_image=%s",
@@ -426,21 +386,18 @@ async def rollout(
docker_image,
)
- failure_stage = "start_agent"
await self.start_agent(state)
logger.debug(
"rollout=%s stage=start_agent ok sandbox_id=%s",
rollout_id,
state.get("sandbox_id"),
)
- failure_stage = "wait_for_agent_completion"
await self.wait_for_agent_completion(state)
logger.debug(
"rollout=%s stage=wait_for_agent_completion ok exit_code=%s",
rollout_id,
state.get("agent_exit_code"),
)
- failure_stage = "fetch_trajectory"
await self.fetch_trajectory(state)
trajectory = cast(list[Any], state.get("trajectory") or [])
logger.info(
@@ -458,60 +415,47 @@ async def rollout(
self._tail_text(state.get("agent_stderr")),
)
except vf.Error as e:
- error_stage = failure_stage
state["error"] = e
logger.exception(
"rollout=%s stage=%s vf_error=%s message=%s",
rollout_id,
- failure_stage,
type(e).__name__,
e,
)
except Exception as e:
- error_stage = failure_stage
state["error"] = vf.InfraError(str(e))
logger.exception(
"rollout=%s stage=%s unhandled_error=%s message=%s",
rollout_id,
- failure_stage,
type(e).__name__,
e,
)
finally:
if rollout_registered:
try:
- failure_stage = "unregister_rollout"
await self.unregister_rollout(state)
except Exception as e:
logger.warning(
f"Failed to unregister rollout {state['rollout_id']}: {e}"
)
- if error_stage is None:
- error_stage = failure_stage
if state.get("error") is None:
state["error"] = vf.InfraError(str(e))
if state.get("sandbox_id"):
try:
- failure_stage = "destroy_sandbox"
await self.destroy_sandbox(state)
except Exception as e:
logger.warning(
f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}"
)
- if error_stage is None:
- error_stage = failure_stage
if state.get("error") is None:
state["error"] = vf.InfraError(str(e))
if state.get("completion") is None:
state["completion"] = []
- state["failure_stage"] = error_stage
if state.get("error") is not None:
if state.get("stop_condition") is None:
- state["stop_condition"] = (
- f"{error_stage}_error" if error_stage else "has_error"
- )
+ state["stop_condition"] = "has_error"
elif state.get("agent_timed_out", False):
if state.get("stop_condition") is None:
state["stop_condition"] = "agent_timeout"
@@ -521,16 +465,13 @@ async def rollout(
state["is_completed"] = True
self._render_timing(state)
logger.info(
- "rollout=%s stage=finish stop=%s failure_stage=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s",
+ "rollout=%s stage=finish stop=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s",
rollout_id,
state.get("stop_condition"),
- error_stage,
state.get("sandbox_id"),
- len(cast(list[Any], state.get("trajectory") or [])),
+ len(state.get("trajectory", [])),
state.get("agent_exit_code"),
- type(state["error"]).__name__
- if state.get("error") is not None
- else None,
+ type(state["error"]).__name__ if state.get("error") else None,
)
return state
From b2d42b8e93a8798eb49834e6ed5fb2aecd5cbea0 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Mon, 23 Feb 2026 06:05:34 +0530
Subject: [PATCH 11/21] server side response models
---
.../envs/experimental/rollout_gateway_env.py | 15 +--------------
1 file changed, 1 insertion(+), 14 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
index c6a12970b..754a307dc 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -196,20 +196,7 @@ async def unregister_rollout(self, state: State) -> None:
async def fetch_trajectory(self, state: State) -> None:
data = await self._gateway_get(state, "trajectory")
- raw_trajectory = data.get("trajectory", [])
-
- trajectory: list[TrajectoryStep] = []
- # TODO: Pydantic response schema in gateway
- for step in raw_trajectory:
- step.setdefault("response", None)
- step.setdefault("reward", None)
- step.setdefault("advantage", None)
- step.setdefault("is_truncated", False)
- step.setdefault("trajectory_id", state["trajectory_id"])
- step.setdefault("extras", {})
- trajectory.append(cast(TrajectoryStep, step))
-
- state["trajectory"] = trajectory
+ state["trajectory"] = cast(list[TrajectoryStep], data.get("trajectory", []))
state["prompt"] = data.get("prompt")
state["completion"] = data.get("completion")
state["is_truncated"] = bool(
From 5d8f422555eb1a86a71e19177c3bd5fe3873eb40 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Mon, 23 Feb 2026 06:05:59 +0530
Subject: [PATCH 12/21] fix slop
---
.../envs/experimental/rollout_gateway_env.py | 116 +++++-------------
1 file changed, 32 insertions(+), 84 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
index 754a307dc..4b4df2fe1 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -2,7 +2,7 @@
import logging
import time
import uuid
-from typing import Any, cast
+from typing import Any
from urllib.parse import urlparse
import httpx
@@ -18,7 +18,7 @@
import verifiers as vf
from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
from verifiers.envs.multiturn_env import MultiTurnMonitorRubric
-from verifiers.types import RolloutInput, SamplingArgs, State, TrajectoryStep
+from verifiers.types import RolloutInput, SamplingArgs, State
logger = logging.getLogger(__name__)
@@ -97,7 +97,7 @@ def __init__(
self._tunnel_lock = asyncio.Lock()
def _resolve_tunnel_local_addr(self, state: State) -> str:
- gateway_url = cast(str, state["gateway_url"])
+ gateway_url = state["gateway_url"]
parsed = urlparse(gateway_url)
host = parsed.hostname
if host is None:
@@ -127,11 +127,7 @@ async def get_tunnel_url(self, local_addr: str | None = None) -> str:
)
url = await tunnel.start()
self._tunnels[local_addr] = tunnel
- logger.debug(
- "Prime Tunnel started local_addr=%s url=%s",
- local_addr,
- url,
- )
+ logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}")
return url
assert tunnel.url is not None, "Tunnel started but URL is None"
@@ -172,14 +168,14 @@ async def _gateway_post(
response.raise_for_status()
if not response.content:
return {}
- return cast(dict[str, Any], response.json())
+ return response.json()
async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
timeout = httpx.Timeout(self.timeout_seconds)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(self._rollout_endpoint(state, suffix))
response.raise_for_status()
- return cast(dict[str, Any], response.json())
+ return response.json()
async def register_rollout(self, state: State) -> None:
sampling_params = state.get("sampling_args") or {}
@@ -196,7 +192,7 @@ async def unregister_rollout(self, state: State) -> None:
async def fetch_trajectory(self, state: State) -> None:
data = await self._gateway_get(state, "trajectory")
- state["trajectory"] = cast(list[TrajectoryStep], data.get("trajectory", []))
+ state["trajectory"] = data.get("trajectory", [])
state["prompt"] = data.get("prompt")
state["completion"] = data.get("completion")
state["is_truncated"] = bool(
@@ -210,7 +206,7 @@ async def get_docker_image(self, state: State) -> str:
async def build_env_vars(self, state: State) -> dict[str, str]:
"""Build environment variables for the sandbox. Override to add custom vars."""
env_vars = dict(self.environment_vars) if self.environment_vars else {}
- env_vars["OPENAI_BASE_URL"] = cast(str, state["rollout_base_url"])
+ env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"]
env_vars.setdefault("OPENAI_TIMEOUT", "600")
env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600")
env_vars.setdefault("HTTPX_TIMEOUT", "600")
@@ -225,8 +221,8 @@ async def post_sandbox_setup(self, state: State) -> None:
async def start_agent(self, state: State) -> None:
"""Start the agent command using background job."""
- sandbox_id = cast(str, state["sandbox_id"])
- background_job: BackgroundJob = await self.sandbox_client.start_background_job(
+ sandbox_id = state["sandbox_id"]
+ background_job = await self.sandbox_client.start_background_job(
sandbox_id,
self.run_command,
)
@@ -244,19 +240,12 @@ async def wait_for_agent_completion(self, state: State) -> None:
try:
await asyncio.wait_for(
- self.poll_job_completion(
- state,
- cast(str, sandbox_id),
- cast(BackgroundJob, background_job),
- ),
+ self.poll_job_completion(state, sandbox_id, background_job),
timeout=self.timeout_seconds,
)
except asyncio.TimeoutError:
logger.warning(
- "rollout=%s sandbox=%s stage=wait_for_agent_completion timed out after %.1fs",
- state.get("rollout_id"),
- state.get("sandbox_id"),
- self.timeout_seconds,
+ f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} stage=wait_for_agent_completion timed out after {self.timeout_seconds:.1f}s"
)
state["agent_timed_out"] = True
finally:
@@ -280,19 +269,11 @@ async def poll_job_completion(
state["agent_stderr"] = status.stderr
if status.exit_code not in (None, 0):
logger.warning(
- "rollout=%s sandbox=%s stage=agent_completed exit_code=%s stdout_tail=%r stderr_tail=%r",
- state.get("rollout_id"),
- sandbox_id,
- status.exit_code,
- self._tail_text(status.stdout),
- self._tail_text(status.stderr),
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={self._tail_text(status.stdout)!r} stderr_tail={self._tail_text(status.stderr)!r}"
)
else:
logger.debug(
- "rollout=%s sandbox=%s stage=agent_completed exit_code=%s",
- state.get("rollout_id"),
- sandbox_id,
- status.exit_code,
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code}"
)
return
await asyncio.sleep(1)
@@ -317,25 +298,19 @@ async def rollout(
rollout_id = state["rollout_id"]
info = state.get("info") or {}
logger.info(
- "rollout=%s stage=start model=%s example_id=%s repo=%s",
- rollout_id,
- model,
- info.get("instance_id") or info.get("example_id"),
- info.get("repo_name"),
+ f"rollout={rollout_id} stage=start model={model} example_id={info.get('instance_id') or info.get('example_id')} repo={info.get('repo_name')}"
)
rollout_registered = False
try:
await self.register_rollout(state)
rollout_registered = True
- logger.debug("rollout=%s stage=register_rollout ok", rollout_id)
+ logger.debug(f"rollout={rollout_id} stage=register_rollout ok")
tunnel_local_addr = self._resolve_tunnel_local_addr(state)
state["tunnel_local_addr"] = tunnel_local_addr
logger.debug(
- "rollout=%s stage=resolve_tunnel_local_addr addr=%s",
- rollout_id,
- tunnel_local_addr,
+ f"rollout={rollout_id} stage=resolve_tunnel_local_addr addr={tunnel_local_addr}"
)
tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr)
@@ -343,12 +318,12 @@ async def rollout(
state["rollout_base_url"] = (
f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}"
)
- logger.debug("rollout=%s stage=start_tunnel url=%s", rollout_id, tunnel_url)
+ logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}")
env_vars = await self.build_env_vars(state)
docker_image = await self.get_docker_image(state)
sandbox_request = CreateSandboxRequest(
- name=cast(str, state["rollout_id"]),
+ name=state["rollout_id"],
docker_image=docker_image,
start_command=self.start_command,
cpu_cores=self.cpu_cores,
@@ -367,55 +342,35 @@ async def rollout(
)
await self.create_sandbox(state, sandbox_request)
logger.info(
- "rollout=%s stage=create_sandbox ok sandbox_id=%s docker_image=%s",
- rollout_id,
- state.get("sandbox_id"),
- docker_image,
+ f"rollout={rollout_id} stage=create_sandbox ok sandbox_id={state.get('sandbox_id')} docker_image={docker_image}"
)
await self.start_agent(state)
logger.debug(
- "rollout=%s stage=start_agent ok sandbox_id=%s",
- rollout_id,
- state.get("sandbox_id"),
+ f"rollout={rollout_id} stage=start_agent ok sandbox_id={state.get('sandbox_id')}"
)
await self.wait_for_agent_completion(state)
logger.debug(
- "rollout=%s stage=wait_for_agent_completion ok exit_code=%s",
- rollout_id,
- state.get("agent_exit_code"),
+ f"rollout={rollout_id} stage=wait_for_agent_completion ok exit_code={state.get('agent_exit_code')}"
)
await self.fetch_trajectory(state)
- trajectory = cast(list[Any], state.get("trajectory") or [])
+ trajectory = state.get("trajectory") or []
logger.info(
- "rollout=%s stage=fetch_trajectory ok turns=%d truncated=%s",
- rollout_id,
- len(trajectory),
- state.get("is_truncated", False),
+ f"rollout={rollout_id} stage=fetch_trajectory ok turns={len(trajectory)} truncated={state.get('is_truncated', False)}"
)
if len(trajectory) == 0:
logger.warning(
- "rollout=%s stage=fetch_trajectory empty_trajectory agent_exit_code=%s stdout_tail=%r stderr_tail=%r",
- rollout_id,
- state.get("agent_exit_code"),
- self._tail_text(state.get("agent_stdout")),
- self._tail_text(state.get("agent_stderr")),
+ f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={self._tail_text(state.get('agent_stdout'))!r} stderr_tail={self._tail_text(state.get('agent_stderr'))!r}"
)
except vf.Error as e:
state["error"] = e
logger.exception(
- "rollout=%s stage=%s vf_error=%s message=%s",
- rollout_id,
- type(e).__name__,
- e,
+ f"rollout={rollout_id} stage={type(e).__name__} vf_error={e}"
)
except Exception as e:
state["error"] = vf.InfraError(str(e))
logger.exception(
- "rollout=%s stage=%s unhandled_error=%s message=%s",
- rollout_id,
- type(e).__name__,
- e,
+ f"rollout={rollout_id} stage={type(e).__name__} unhandled_error={e}"
)
finally:
if rollout_registered:
@@ -451,14 +406,9 @@ async def rollout(
state["stop_condition"] = "completed"
state["is_completed"] = True
self._render_timing(state)
+ error_name = type(state["error"]).__name__ if state.get("error") else None
logger.info(
- "rollout=%s stage=finish stop=%s sandbox_id=%s turns=%d agent_exit_code=%s error=%s",
- rollout_id,
- state.get("stop_condition"),
- state.get("sandbox_id"),
- len(state.get("trajectory", [])),
- state.get("agent_exit_code"),
- type(state["error"]).__name__ if state.get("error") else None,
+ f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} sandbox_id={state.get('sandbox_id')} turns={len(state.get('trajectory', []))} agent_exit_code={state.get('agent_exit_code')} error={error_name}"
)
return state
@@ -472,12 +422,10 @@ async def teardown_resources(self):
for local_addr, tunnel in tunnels:
try:
tunnel.sync_stop()
- logger.debug("Prime Tunnel stopped local_addr=%s", local_addr)
+ logger.debug(f"Prime Tunnel stopped local_addr={local_addr}")
except Exception as e:
logger.warning(
- "Error stopping Prime Tunnel local_addr=%s: %s",
- local_addr,
- e,
+ f"Error stopping Prime Tunnel local_addr={local_addr}: {e}"
)
async def post_rollout(self, state: State):
@@ -493,4 +441,4 @@ async def destroy_sandbox(self, state: State):
await self.post_rollout(state)
sandbox_id = state.get("sandbox_id")
if sandbox_id:
- await self.delete_sandbox(cast(str, sandbox_id))
+ await self.delete_sandbox(sandbox_id)
From 899ce24283f72f7e81baf918f288844f90096763 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Tue, 24 Feb 2026 07:07:00 +0530
Subject: [PATCH 13/21] refactor + ty
---
.../envs/experimental/rollout_gateway_env.py | 78 +++++++++----------
1 file changed, 37 insertions(+), 41 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
index 4b4df2fe1..be8c2726b 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -6,7 +6,6 @@
from urllib.parse import urlparse
import httpx
-from openai import AsyncOpenAI
from prime_sandboxes import (
AdvancedConfigs,
BackgroundJob,
@@ -16,13 +15,21 @@
from prime_tunnel import Tunnel
import verifiers as vf
+from verifiers.clients import Client
from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
from verifiers.envs.multiturn_env import MultiTurnMonitorRubric
-from verifiers.types import RolloutInput, SamplingArgs, State
+from verifiers.types import ClientConfig, RolloutInput, SamplingArgs, State
logger = logging.getLogger(__name__)
+def _tail_text(value: Any, max_chars: int = 1200) -> str:
+ if value is None:
+ return ""
+ text = str(value)
+ return text[-max_chars:] if len(text) > max_chars else text
+
+
class RolloutGatewayEnv(SandboxMixin, vf.Environment):
"""
Environment for running full agent code inside sandboxes.
@@ -93,6 +100,9 @@ def __init__(
self.advanced_configs = advanced_configs
self.labels = labels
+ self._http_client = httpx.AsyncClient(
+ timeout=httpx.Timeout(self.timeout_seconds)
+ )
self._tunnels: dict[str, Tunnel] = {}
self._tunnel_lock = asyncio.Lock()
@@ -141,15 +151,6 @@ def _resolve_gateway_url(self, state: State) -> str:
gateway_url = gateway_url[:-3]
return gateway_url
- @staticmethod
- def _tail_text(value: Any, max_chars: int = 1200) -> str:
- if value is None:
- return ""
- text = str(value)
- if len(text) <= max_chars:
- return text
- return text[-max_chars:]
-
def _rollout_endpoint(self, state: State, suffix: str) -> str:
return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}"
@@ -159,23 +160,19 @@ async def _gateway_post(
suffix: str,
payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
- timeout = httpx.Timeout(self.timeout_seconds)
- async with httpx.AsyncClient(timeout=timeout) as client:
- response = await client.post(
- self._rollout_endpoint(state, suffix),
- json=payload,
- )
- response.raise_for_status()
- if not response.content:
- return {}
- return response.json()
+ response = await self._http_client.post(
+ self._rollout_endpoint(state, suffix),
+ json=payload,
+ )
+ response.raise_for_status()
+ if not response.content:
+ return {}
+ return response.json()
async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
- timeout = httpx.Timeout(self.timeout_seconds)
- async with httpx.AsyncClient(timeout=timeout) as client:
- response = await client.get(self._rollout_endpoint(state, suffix))
- response.raise_for_status()
- return response.json()
+ response = await self._http_client.get(self._rollout_endpoint(state, suffix))
+ response.raise_for_status()
+ return response.json()
async def register_rollout(self, state: State) -> None:
sampling_params = state.get("sampling_args") or {}
@@ -269,16 +266,16 @@ async def poll_job_completion(
state["agent_stderr"] = status.stderr
if status.exit_code not in (None, 0):
logger.warning(
- f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={self._tail_text(status.stdout)!r} stderr_tail={self._tail_text(status.stderr)!r}"
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={_tail_text(status.stdout)!r} stderr_tail={_tail_text(status.stderr)!r}"
)
else:
logger.debug(
f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code}"
)
return
- await asyncio.sleep(1)
+ await asyncio.sleep(self.poll_interval)
- def _render_timing(self, state: State) -> None:
+ async def _render_timing(self, state: State) -> None:
start_time = state["timing"]["start_time"]
end_time = time.perf_counter()
generation_ms = (end_time - start_time) * 1000
@@ -288,7 +285,7 @@ def _render_timing(self, state: State) -> None:
async def rollout(
self,
input: RolloutInput,
- client: AsyncOpenAI,
+ client: Client | ClientConfig,
model: str,
sampling_args: SamplingArgs | None = None,
) -> State:
@@ -334,7 +331,7 @@ async def rollout(
environment_vars=env_vars,
team_id=self.team_id,
advanced_configs=self.advanced_configs,
- labels=self.labels if self.labels else [],
+ labels=self.labels or [],
)
logger.debug(
f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} "
@@ -360,7 +357,7 @@ async def rollout(
)
if len(trajectory) == 0:
logger.warning(
- f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={self._tail_text(state.get('agent_stdout'))!r} stderr_tail={self._tail_text(state.get('agent_stderr'))!r}"
+ f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={_tail_text(state.get('agent_stdout'))!r} stderr_tail={_tail_text(state.get('agent_stderr'))!r}"
)
except vf.Error as e:
state["error"] = e
@@ -385,7 +382,7 @@ async def rollout(
if state.get("sandbox_id"):
try:
- await self.destroy_sandbox(state)
+ await self._cleanup(state)
except Exception as e:
logger.warning(
f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}"
@@ -395,17 +392,15 @@ async def rollout(
if state.get("completion") is None:
state["completion"] = []
- if state.get("error") is not None:
- if state.get("stop_condition") is None:
+ if state.get("stop_condition") is None:
+ if state.get("error") is not None:
state["stop_condition"] = "has_error"
- elif state.get("agent_timed_out", False):
- if state.get("stop_condition") is None:
+ elif state.get("agent_timed_out", False):
state["stop_condition"] = "agent_timeout"
- else:
- if state.get("stop_condition") is None:
+ else:
state["stop_condition"] = "completed"
state["is_completed"] = True
- self._render_timing(state)
+ await self._render_timing(state)
error_name = type(state["error"]).__name__ if state.get("error") else None
logger.info(
f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} sandbox_id={state.get('sandbox_id')} turns={len(state.get('trajectory', []))} agent_exit_code={state.get('agent_exit_code')} error={error_name}"
@@ -415,7 +410,8 @@ async def rollout(
@vf.teardown
async def teardown_resources(self):
- """Stop Prime Tunnel."""
+ """Stop Prime Tunnel and close HTTP client."""
+ await self._http_client.aclose()
async with self._tunnel_lock:
tunnels = list(self._tunnels.items())
self._tunnels = {}
From d1eb28ffc25b5e42d22ec31bbd743c5200cd4609 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Tue, 24 Feb 2026 08:29:17 +0530
Subject: [PATCH 14/21] fix
---
verifiers/envs/experimental/rollout_gateway_env.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_env.py
index be8c2726b..53a264bea 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_env.py
@@ -277,7 +277,7 @@ async def poll_job_completion(
async def _render_timing(self, state: State) -> None:
start_time = state["timing"]["start_time"]
- end_time = time.perf_counter()
+ end_time = time.time()
generation_ms = (end_time - start_time) * 1000
state["timing"]["generation_ms"] = generation_ms
state["timing"]["total_ms"] = generation_ms
From d2301af4420c41df6a31c5cc70db72dfd4411feb Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 01:42:05 +0530
Subject: [PATCH 15/21] refactor: `RolloutGatewayMixin`
---
tests/test_rollout_gateway_env.py | 32 +-
verifiers/__init__.py | 6 +-
verifiers/envs/experimental/__init__.py | 3 +-
verifiers/envs/experimental/cli_agent_env.py | 17 +-
...ateway_env.py => rollout_gateway_mixin.py} | 354 +++++++++---------
5 files changed, 206 insertions(+), 206 deletions(-)
rename verifiers/envs/experimental/{rollout_gateway_env.py => rollout_gateway_mixin.py} (64%)
diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py
index 60a7d72b7..d16da5d35 100644
--- a/tests/test_rollout_gateway_env.py
+++ b/tests/test_rollout_gateway_env.py
@@ -9,7 +9,7 @@
from datasets import Dataset
import verifiers as vf
-import verifiers.envs.experimental.rollout_gateway_env as rollout_gateway_env
+import verifiers.envs.experimental.rollout_gateway_mixin as rollout_gateway_mixin
pytestmark = [pytest.mark.integration, pytest.mark.environments]
@@ -40,7 +40,15 @@ def sync_stop(self) -> None:
self.stop_calls += 1
-class GatewayCliAgentEnv(vf.RolloutGatewayEnv):
+class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):
+ def __init__(self, *, gateway_port=8000, use_gateway=True, **kwargs):
+ super().__init__(**kwargs)
+ self.use_gateway = use_gateway
+ if use_gateway:
+ self.init_gateway(
+ gateway_port=gateway_port, timeout_seconds=self.timeout_seconds
+ )
+
async def post_rollout(self, state: vf.State):
state["reward"] = 1.0
state["test_output"] = "ok"
@@ -159,7 +167,7 @@ def _handler(request: httpx.Request) -> httpx.Response:
@pytest.mark.asyncio
async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch):
FakeTunnel.instances.clear()
- monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel)
+ monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)
tracker = {
"paths": [],
@@ -182,7 +190,7 @@ def _client_factory(*args, **kwargs):
kwargs["transport"] = transport
return real_async_client(*args, **kwargs)
- monkeypatch.setattr(rollout_gateway_env.httpx, "AsyncClient", _client_factory)
+ monkeypatch.setattr(rollout_gateway_mixin.httpx, "AsyncClient", _client_factory)
dataset = Dataset.from_dict(
{
@@ -258,17 +266,17 @@ def _client_factory(*args, **kwargs):
assert tunnel.local_port == 8000
assert tunnel.local_addr == "gateway.internal"
assert tunnel.start_calls == 1
- assert await env.get_tunnel_url() == "https://unit-test.tunnel.prime.ai"
+ assert await env.get_gateway_tunnel_url() == "https://unit-test.tunnel.prime.ai"
assert tunnel.start_calls == 1
- await env.teardown_resources()
+ await env.teardown_gateway()
assert tunnel.stop_calls == 1
@pytest.mark.asyncio
async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
FakeTunnel.instances.clear()
- monkeypatch.setattr(rollout_gateway_env, "Tunnel", FakeTunnel)
+ monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)
dataset = Dataset.from_dict(
{
@@ -284,9 +292,9 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
gateway_port=8000,
)
- url_a = await env.get_tunnel_url(local_addr="10.20.0.58")
- url_b = await env.get_tunnel_url(local_addr="10.20.0.59")
- url_a_reuse = await env.get_tunnel_url(local_addr="10.20.0.58")
+ url_a = await env.get_gateway_tunnel_url(local_addr="10.20.0.58")
+ url_b = await env.get_gateway_tunnel_url(local_addr="10.20.0.59")
+ url_a_reuse = await env.get_gateway_tunnel_url(local_addr="10.20.0.58")
assert url_a == "https://unit-test.tunnel.prime.ai"
assert url_b == "https://unit-test.tunnel.prime.ai"
@@ -299,7 +307,7 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
with pytest.raises(
ValueError, match="local_addr is required when multiple tunnels are active"
):
- await env.get_tunnel_url()
+ await env.get_gateway_tunnel_url()
- await env.teardown_resources()
+ await env.teardown_gateway()
assert sum(t.stop_calls for t in FakeTunnel.instances) == 2
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index 7c91a5600..e7fa269ea 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -67,7 +67,7 @@
"ReasoningGymEnv",
"GymEnv",
"CliAgentEnv",
- "RolloutGatewayEnv",
+ "RolloutGatewayMixin",
"HarborEnv",
"MCPEnv",
"BrowserEnv",
@@ -122,7 +122,7 @@
"PythonEnv": "verifiers.envs.python_env:PythonEnv",
"GymEnv": "verifiers.envs.experimental.gym_env:GymEnv",
"CliAgentEnv": "verifiers.envs.experimental.cli_agent_env:CliAgentEnv",
- "RolloutGatewayEnv": "verifiers.envs.experimental.rollout_gateway_env:RolloutGatewayEnv",
+ "RolloutGatewayMixin": "verifiers.envs.experimental.rollout_gateway_mixin:RolloutGatewayMixin",
"HarborEnv": "verifiers.envs.experimental.harbor_env:HarborEnv",
"MCPEnv": "verifiers.envs.experimental.mcp_env:MCPEnv",
"ReasoningGymEnv": "verifiers.envs.integrations.reasoninggym_env:ReasoningGymEnv",
@@ -162,7 +162,7 @@ def __getattr__(name: str):
from typing import Any
from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401
- from .envs.experimental.rollout_gateway_env import RolloutGatewayEnv # noqa: F401
+ from .envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin # noqa: F401
from .envs.experimental.gym_env import GymEnv # noqa: F401
from .envs.experimental.harbor_env import HarborEnv # noqa: F401
from .envs.experimental.mcp_env import MCPEnv # noqa: F401
diff --git a/verifiers/envs/experimental/__init__.py b/verifiers/envs/experimental/__init__.py
index 034cb817f..ff81549e0 100644
--- a/verifiers/envs/experimental/__init__.py
+++ b/verifiers/envs/experimental/__init__.py
@@ -1,3 +1,4 @@
+from verifiers.envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin
from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
-__all__ = ["SandboxMixin"]
+__all__ = ["RolloutGatewayMixin", "SandboxMixin"]
diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py
index 5cfb9e42e..946308f5b 100644
--- a/verifiers/envs/experimental/cli_agent_env.py
+++ b/verifiers/envs/experimental/cli_agent_env.py
@@ -84,8 +84,6 @@ def __init__(
)
self.run_command = run_command
self.poll_interval = poll_interval
- self.interception_port = interception_port
- self.interception_url = interception_url
self.timeout_seconds = timeout_seconds
self.docker_image = docker_image
self.start_command = start_command
@@ -99,7 +97,16 @@ def __init__(
self.advanced_configs = advanced_configs
self.labels = labels
- # Tunnel and interception server
+ self.init_interception(interception_port, interception_url)
+
+ def init_interception(
+ self,
+ interception_port: int = 8765,
+ interception_url: str | None = None,
+ ):
+ """Initialize interception server and tunnel resources. Call from __init__."""
+ self.interception_port = interception_port
+ self.interception_url = interception_url
self._tunnel: Tunnel | None = None
self._tunnel_lock = asyncio.Lock()
self._interception_server = InterceptionServer(port=interception_port)
@@ -411,6 +418,8 @@ async def add_model_response(
@vf.teardown
async def teardown_resources(self):
"""Stop Prime Tunnel and HTTP interception server."""
+ if not hasattr(self, "_tunnel_lock"):
+ return
async with self._tunnel_lock:
if self._tunnel is not None:
try:
@@ -437,7 +446,7 @@ async def cleanup_interception_context(self, state: State):
state.pop("background_job", None)
rollout_id = state.get("rollout_id")
- if rollout_id:
+ if rollout_id and hasattr(self, "_interception_server"):
self._interception_server.unregister_rollout(rollout_id)
@vf.stop
diff --git a/verifiers/envs/experimental/rollout_gateway_env.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
similarity index 64%
rename from verifiers/envs/experimental/rollout_gateway_env.py
rename to verifiers/envs/experimental/rollout_gateway_mixin.py
index 53a264bea..880e452b3 100644
--- a/verifiers/envs/experimental/rollout_gateway_env.py
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -6,18 +6,11 @@
from urllib.parse import urlparse
import httpx
-from prime_sandboxes import (
- AdvancedConfigs,
- BackgroundJob,
- BackgroundJobStatus,
- CreateSandboxRequest,
-)
+from prime_sandboxes import CreateSandboxRequest
from prime_tunnel import Tunnel
import verifiers as vf
from verifiers.clients import Client
-from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
-from verifiers.envs.multiturn_env import MultiTurnMonitorRubric
from verifiers.types import ClientConfig, RolloutInput, SamplingArgs, State
logger = logging.getLogger(__name__)
@@ -30,137 +23,64 @@ def _tail_text(value: Any, max_chars: int = 1200) -> str:
return text[-max_chars:] if len(text) > max_chars else text
-class RolloutGatewayEnv(SandboxMixin, vf.Environment):
- """
- Environment for running full agent code inside sandboxes.
+class RolloutGatewayMixin:
+ """Opt-in mixin that replaces MultiTurnEnv's interception-based rollout
+ with a server-side gateway path. Toggle via ``use_gateway`` attribute.
+
+ When gateway is active, the agent talks directly to vLLM's rollout
+ gateway through a prime tunnel. The env only manages sandbox lifecycle.
+ When inactive, falls through to CliAgentEnv's interception path.
- The sandboxed agent talks directly to the rollout gateway running in the vLLM
- server through a prime tunnel URL. The environment only handles sandbox
- lifecycle, rollout registration, trajectory fetch, and reward computation.
+ MRO: ``MyEnv → RolloutGatewayMixin → CliAgentEnv → SandboxMixin → MultiTurnEnv → Environment``
"""
- def __init__(
+ use_gateway: bool = True
+
+ def init_gateway(
self,
- run_command: str,
gateway_port: int = 8000,
- max_turns: int = -1,
timeout_seconds: float = 3600.0,
- poll_interval: float = 2.0,
- docker_image: str = "python:3.11-slim",
- start_command: str = "tail -f /dev/null",
- cpu_cores: int = 1,
- memory_gb: int = 2,
- disk_size_gb: int = 5,
- gpu_count: int = 0,
- timeout_minutes: int = 60,
- environment_vars: dict[str, str] | None = None,
- team_id: str | None = None,
- advanced_configs: AdvancedConfigs | None = None,
- labels: list[str] | None = None,
- max_retries: int = 5,
- base_delay: float = 0.5,
- backoff_factor: float = 2.0,
- max_backoff_seconds: float = 30.0,
- jitter: float = 1e-3,
- sandbox_client_max_workers: int = 10,
- sandbox_client_max_connections: int = 100,
- sandbox_client_max_keepalive_connections: int = 50,
- sandbox_wait_for_creation_max_attempts: int = 120,
- **kwargs,
):
- super().__init__(message_type="chat", **kwargs)
- self.add_rubric(MultiTurnMonitorRubric())
-
- self.init_sandbox_client(
- max_retries=max_retries,
- base_delay=base_delay,
- backoff_factor=backoff_factor,
- max_backoff_seconds=max_backoff_seconds,
- jitter=jitter,
- sandbox_client_max_workers=sandbox_client_max_workers,
- sandbox_client_max_connections=sandbox_client_max_connections,
- sandbox_client_max_keepalive_connections=sandbox_client_max_keepalive_connections,
- sandbox_wait_for_creation_max_attempts=sandbox_wait_for_creation_max_attempts,
- )
-
- self.run_command = run_command
+ """Initialize gateway resources. Call in __init__ when use_gateway=True."""
self.gateway_port = gateway_port
- self.max_turns = max_turns
- self.poll_interval = poll_interval
- self.timeout_seconds = timeout_seconds
- self.docker_image = docker_image
- self.start_command = start_command
- self.cpu_cores = cpu_cores
- self.memory_gb = memory_gb
- self.disk_size_gb = disk_size_gb
- self.gpu_count = gpu_count
- self.timeout_minutes = timeout_minutes
- self.environment_vars = environment_vars
- self.team_id = team_id
- self.advanced_configs = advanced_configs
- self.labels = labels
-
- self._http_client = httpx.AsyncClient(
- timeout=httpx.Timeout(self.timeout_seconds)
- )
- self._tunnels: dict[str, Tunnel] = {}
- self._tunnel_lock = asyncio.Lock()
+ self._gw_timeout_seconds = timeout_seconds
+ self._gw_http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds))
+ self._gw_tunnels: dict[str, Tunnel] = {}
+ self._gw_tunnel_lock = asyncio.Lock()
- def _resolve_tunnel_local_addr(self, state: State) -> str:
- gateway_url = state["gateway_url"]
- parsed = urlparse(gateway_url)
- host = parsed.hostname
- if host is None:
- raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}")
- return host
-
- async def get_tunnel_url(self, local_addr: str | None = None) -> str:
- """Get tunnel URL, starting the tunnel if needed."""
- async with self._tunnel_lock:
- if local_addr is None:
- if len(self._tunnels) == 1:
- tunnel = next(iter(self._tunnels.values()))
- assert tunnel.url is not None, "Tunnel started but URL is None"
- return tunnel.url
- if len(self._tunnels) == 0:
- raise ValueError("local_addr is required when starting tunnel")
- raise ValueError(
- "local_addr is required when multiple tunnels are active"
- )
-
- tunnel = self._tunnels.get(local_addr)
- if tunnel is None:
- tunnel = Tunnel(
- local_port=self.gateway_port,
- local_addr=local_addr,
- log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info",
- )
- url = await tunnel.start()
- self._tunnels[local_addr] = tunnel
- logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}")
- return url
-
- assert tunnel.url is not None, "Tunnel started but URL is None"
- return tunnel.url
+ # ------------------------------------------------------------------
+ # Gateway URL resolution
+ # ------------------------------------------------------------------
def _resolve_gateway_url(self, state: State) -> str:
- # `state["client"]` may be a Verifiers wrapper with the raw client on `.client`.
client = getattr(state["client"], "client", state["client"])
gateway_url = str(client.base_url).rstrip("/")
if gateway_url.endswith("/v1"):
gateway_url = gateway_url[:-3]
return gateway_url
+ def _resolve_tunnel_local_addr(self, state: State) -> str:
+ gateway_url = state["gateway_url"]
+ parsed = urlparse(gateway_url)
+ host = parsed.hostname
+ if host is None:
+ raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}")
+ return host
+
def _rollout_endpoint(self, state: State, suffix: str) -> str:
return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}"
+ # ------------------------------------------------------------------
+ # Gateway HTTP helpers
+ # ------------------------------------------------------------------
+
async def _gateway_post(
self,
state: State,
suffix: str,
payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
- response = await self._http_client.post(
+ response = await self._gw_http_client.post(
self._rollout_endpoint(state, suffix),
json=payload,
)
@@ -170,10 +90,32 @@ async def _gateway_post(
return response.json()
async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
- response = await self._http_client.get(self._rollout_endpoint(state, suffix))
+ response = await self._gw_http_client.get(self._rollout_endpoint(state, suffix))
response.raise_for_status()
return response.json()
+ # ------------------------------------------------------------------
+ # Env var override for gateway path
+ # ------------------------------------------------------------------
+
+ async def build_env_vars(self, state: State) -> dict[str, str]:
+ """Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode."""
+ if not self.use_gateway:
+ return await super().build_env_vars(state)
+ env_vars = dict(self.environment_vars) if self.environment_vars else {}
+ env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"]
+ env_vars.setdefault("OPENAI_TIMEOUT", "600")
+ env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600")
+ env_vars.setdefault("HTTPX_TIMEOUT", "600")
+ model = state.get("model")
+ if model:
+ env_vars["OPENAI_MODEL"] = model
+ return env_vars
+
+ # ------------------------------------------------------------------
+ # Rollout registration & trajectory
+ # ------------------------------------------------------------------
+
async def register_rollout(self, state: State) -> None:
sampling_params = state.get("sampling_args") or {}
payload = {
@@ -196,28 +138,47 @@ async def fetch_trajectory(self, state: State) -> None:
data.get("is_truncated", state.get("is_truncated", False))
)
- async def get_docker_image(self, state: State) -> str:
- """Get the Docker image for the sandbox. Override for per-task images."""
- return self.docker_image
+ # ------------------------------------------------------------------
+ # Gateway tunnel management
+ # ------------------------------------------------------------------
- async def build_env_vars(self, state: State) -> dict[str, str]:
- """Build environment variables for the sandbox. Override to add custom vars."""
- env_vars = dict(self.environment_vars) if self.environment_vars else {}
- env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"]
- env_vars.setdefault("OPENAI_TIMEOUT", "600")
- env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600")
- env_vars.setdefault("HTTPX_TIMEOUT", "600")
- model = state.get("model")
- if model:
- env_vars["OPENAI_MODEL"] = model
- return env_vars
+ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
+ """Get gateway tunnel URL, starting the tunnel if needed."""
+ async with self._gw_tunnel_lock:
+ if local_addr is None:
+ if len(self._gw_tunnels) == 1:
+ tunnel = next(iter(self._gw_tunnels.values()))
+ assert tunnel.url is not None, "Tunnel started but URL is None"
+ return tunnel.url
+ if len(self._gw_tunnels) == 0:
+ raise ValueError("local_addr is required when starting tunnel")
+ raise ValueError(
+ "local_addr is required when multiple tunnels are active"
+ )
- async def post_sandbox_setup(self, state: State) -> None:
- """Hook for post-sandbox setup. Override to upload files, run commands, etc."""
- pass
+ tunnel = self._gw_tunnels.get(local_addr)
+ if tunnel is None:
+ tunnel = Tunnel(
+ local_port=self.gateway_port,
+ local_addr=local_addr,
+ log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info",
+ )
+ url = await tunnel.start()
+ self._gw_tunnels[local_addr] = tunnel
+ logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}")
+ return url
+
+ assert tunnel.url is not None, "Tunnel started but URL is None"
+ return tunnel.url
+
+ # ------------------------------------------------------------------
+ # Agent start & completion polling
+ # ------------------------------------------------------------------
async def start_agent(self, state: State) -> None:
- """Start the agent command using background job."""
+ """Start the agent command. In gateway mode, skip background completion task."""
+ if not self.use_gateway:
+ return await super().start_agent(state)
sandbox_id = state["sandbox_id"]
background_job = await self.sandbox_client.start_background_job(
sandbox_id,
@@ -227,36 +188,17 @@ async def start_agent(self, state: State) -> None:
state["agent_start_time"] = time.time()
state["agent_completed"] = False
- async def wait_for_agent_completion(self, state: State) -> None:
- """Poll for agent completion using background job API."""
- sandbox_id = state.get("sandbox_id")
- background_job = state.get("background_job")
- if not sandbox_id or not background_job:
- state["agent_completed"] = True
- return
-
- try:
- await asyncio.wait_for(
- self.poll_job_completion(state, sandbox_id, background_job),
- timeout=self.timeout_seconds,
- )
- except asyncio.TimeoutError:
- logger.warning(
- f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} stage=wait_for_agent_completion timed out after {self.timeout_seconds:.1f}s"
- )
- state["agent_timed_out"] = True
- finally:
- state["agent_completed"] = True
-
async def poll_job_completion(
self,
state: State,
sandbox_id: str,
- background_job: BackgroundJob,
+ background_job,
) -> None:
"""Poll until background job completes, capturing output."""
+ if not self.use_gateway:
+ return await super().poll_job_completion(state, sandbox_id, background_job)
while True:
- status: BackgroundJobStatus = await self.sandbox_client.get_background_job(
+ status = await self.sandbox_client.get_background_job(
sandbox_id,
background_job,
)
@@ -266,15 +208,45 @@ async def poll_job_completion(
state["agent_stderr"] = status.stderr
if status.exit_code not in (None, 0):
logger.warning(
- f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code} stdout_tail={_tail_text(status.stdout)!r} stderr_tail={_tail_text(status.stderr)!r}"
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} "
+ f"stage=agent_completed exit_code={status.exit_code} "
+ f"stdout_tail={_tail_text(status.stdout)!r} "
+ f"stderr_tail={_tail_text(status.stderr)!r}"
)
else:
logger.debug(
- f"rollout={state.get('rollout_id')} sandbox={sandbox_id} stage=agent_completed exit_code={status.exit_code}"
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} "
+ f"stage=agent_completed exit_code={status.exit_code}"
)
return
await asyncio.sleep(self.poll_interval)
+ async def wait_for_agent_completion(self, state: State) -> None:
+ """Poll for agent completion using background job API."""
+ sandbox_id = state.get("sandbox_id")
+ background_job = state.get("background_job")
+ if not sandbox_id or not background_job:
+ state["agent_completed"] = True
+ return
+
+ try:
+ await asyncio.wait_for(
+ self.poll_job_completion(state, sandbox_id, background_job),
+ timeout=self._gw_timeout_seconds,
+ )
+ except asyncio.TimeoutError:
+ logger.warning(
+ f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} "
+ f"stage=wait_for_agent_completion timed out after {self._gw_timeout_seconds:.1f}s"
+ )
+ state["agent_timed_out"] = True
+ finally:
+ state["agent_completed"] = True
+
+ # ------------------------------------------------------------------
+ # Timing
+ # ------------------------------------------------------------------
+
async def _render_timing(self, state: State) -> None:
start_time = state["timing"]["start_time"]
end_time = time.time()
@@ -282,6 +254,10 @@ async def _render_timing(self, state: State) -> None:
state["timing"]["generation_ms"] = generation_ms
state["timing"]["total_ms"] = generation_ms
+ # ------------------------------------------------------------------
+ # rollout() — gateway path with fallback
+ # ------------------------------------------------------------------
+
async def rollout(
self,
input: RolloutInput,
@@ -289,13 +265,18 @@ async def rollout(
model: str,
sampling_args: SamplingArgs | None = None,
) -> State:
+ if not self.use_gateway:
+ return await super().rollout(input, client, model, sampling_args)
+
state = await self.init_state(input, client, model, sampling_args)
state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}"
state["gateway_url"] = self._resolve_gateway_url(state)
rollout_id = state["rollout_id"]
info = state.get("info") or {}
logger.info(
- f"rollout={rollout_id} stage=start model={model} example_id={info.get('instance_id') or info.get('example_id')} repo={info.get('repo_name')}"
+ f"rollout={rollout_id} stage=start model={model} "
+ f"example_id={info.get('instance_id') or info.get('example_id')} "
+ f"repo={info.get('repo_name')}"
)
rollout_registered = False
@@ -310,7 +291,7 @@ async def rollout(
f"rollout={rollout_id} stage=resolve_tunnel_local_addr addr={tunnel_local_addr}"
)
- tunnel_url = await self.get_tunnel_url(local_addr=tunnel_local_addr)
+ tunnel_url = await self.get_gateway_tunnel_url(local_addr=tunnel_local_addr)
state["tunnel_url"] = tunnel_url
state["rollout_base_url"] = (
f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}"
@@ -339,25 +320,32 @@ async def rollout(
)
await self.create_sandbox(state, sandbox_request)
logger.info(
- f"rollout={rollout_id} stage=create_sandbox ok sandbox_id={state.get('sandbox_id')} docker_image={docker_image}"
+ f"rollout={rollout_id} stage=create_sandbox ok "
+ f"sandbox_id={state.get('sandbox_id')} docker_image={docker_image}"
)
await self.start_agent(state)
logger.debug(
- f"rollout={rollout_id} stage=start_agent ok sandbox_id={state.get('sandbox_id')}"
+ f"rollout={rollout_id} stage=start_agent ok "
+ f"sandbox_id={state.get('sandbox_id')}"
)
await self.wait_for_agent_completion(state)
logger.debug(
- f"rollout={rollout_id} stage=wait_for_agent_completion ok exit_code={state.get('agent_exit_code')}"
+ f"rollout={rollout_id} stage=wait_for_agent_completion ok "
+ f"exit_code={state.get('agent_exit_code')}"
)
await self.fetch_trajectory(state)
trajectory = state.get("trajectory") or []
logger.info(
- f"rollout={rollout_id} stage=fetch_trajectory ok turns={len(trajectory)} truncated={state.get('is_truncated', False)}"
+ f"rollout={rollout_id} stage=fetch_trajectory ok "
+ f"turns={len(trajectory)} truncated={state.get('is_truncated', False)}"
)
if len(trajectory) == 0:
logger.warning(
- f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory agent_exit_code={state.get('agent_exit_code')} stdout_tail={_tail_text(state.get('agent_stdout'))!r} stderr_tail={_tail_text(state.get('agent_stderr'))!r}"
+ f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory "
+ f"agent_exit_code={state.get('agent_exit_code')} "
+ f"stdout_tail={_tail_text(state.get('agent_stdout'))!r} "
+ f"stderr_tail={_tail_text(state.get('agent_stderr'))!r}"
)
except vf.Error as e:
state["error"] = e
@@ -403,18 +391,27 @@ async def rollout(
await self._render_timing(state)
error_name = type(state["error"]).__name__ if state.get("error") else None
logger.info(
- f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} sandbox_id={state.get('sandbox_id')} turns={len(state.get('trajectory', []))} agent_exit_code={state.get('agent_exit_code')} error={error_name}"
+ f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} "
+ f"sandbox_id={state.get('sandbox_id')} "
+ f"turns={len(state.get('trajectory', []))} "
+ f"agent_exit_code={state.get('agent_exit_code')} error={error_name}"
)
return state
+ # ------------------------------------------------------------------
+ # Teardown
+ # ------------------------------------------------------------------
+
@vf.teardown
- async def teardown_resources(self):
- """Stop Prime Tunnel and close HTTP client."""
- await self._http_client.aclose()
- async with self._tunnel_lock:
- tunnels = list(self._tunnels.items())
- self._tunnels = {}
+ async def teardown_gateway(self):
+ """Close gateway HTTP client and stop gateway tunnels."""
+ if not hasattr(self, "_gw_http_client"):
+ return
+ await self._gw_http_client.aclose()
+ async with self._gw_tunnel_lock:
+ tunnels = list(self._gw_tunnels.items())
+ self._gw_tunnels = {}
for local_addr, tunnel in tunnels:
try:
tunnel.sync_stop()
@@ -423,18 +420,3 @@ async def teardown_resources(self):
logger.warning(
f"Error stopping Prime Tunnel local_addr={local_addr}: {e}"
)
-
- async def post_rollout(self, state: State):
- """
- Override for custom post-rollout logic. For example, if sandbox state is needed for reward functions,
- run computation here and cache the result in state before sandbox is destroyed.
- """
- pass
-
- @vf.cleanup
- async def destroy_sandbox(self, state: State):
- """Cleanup sandbox after rollout."""
- await self.post_rollout(state)
- sandbox_id = state.get("sandbox_id")
- if sandbox_id:
- await self.delete_sandbox(sandbox_id)
From 7a36c41c0322fa972ebf856916d41589b9c6795d Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 02:24:00 +0530
Subject: [PATCH 16/21] cancel properly
---
verifiers/envs/experimental/rollout_gateway_mixin.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
index 880e452b3..1d5a78630 100644
--- a/verifiers/envs/experimental/rollout_gateway_mixin.py
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -347,6 +347,13 @@ async def rollout(
f"stdout_tail={_tail_text(state.get('agent_stdout'))!r} "
f"stderr_tail={_tail_text(state.get('agent_stderr'))!r}"
)
+ except asyncio.CancelledError:
+ if rollout_registered:
+ try:
+ await self._gateway_post(state, "cancel")
+ except Exception:
+ pass
+ raise
except vf.Error as e:
state["error"] = e
logger.exception(
From d664d31cebb2576306ecb648abfc11422be0ae71 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 04:08:45 +0530
Subject: [PATCH 17/21] refactor
---
tests/test_rollout_gateway_env.py | 39 +++++++++++++++-
verifiers/envs/experimental/cli_agent_env.py | 11 +++--
.../experimental/rollout_gateway_mixin.py | 44 +++++++++++--------
3 files changed, 70 insertions(+), 24 deletions(-)
diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py
index d16da5d35..87bfe7b07 100644
--- a/tests/test_rollout_gateway_env.py
+++ b/tests/test_rollout_gateway_env.py
@@ -42,8 +42,8 @@ def sync_stop(self) -> None:
class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):
def __init__(self, *, gateway_port=8000, use_gateway=True, **kwargs):
- super().__init__(**kwargs)
self.use_gateway = use_gateway
+ super().__init__(**kwargs)
if use_gateway:
self.init_gateway(
gateway_port=gateway_port, timeout_seconds=self.timeout_seconds
@@ -311,3 +311,40 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
await env.teardown_gateway()
assert sum(t.stop_calls for t in FakeTunnel.instances) == 2
+
+
+@pytest.mark.asyncio
+async def test_use_gateway_false_initializes_interception(monkeypatch):
+ """With use_gateway=False, interception server is created and gateway is not."""
+ FakeTunnel.instances.clear()
+ monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [[{"role": "user", "content": "Hello"}]],
+ "answer": [""],
+ "example_id": [0],
+ }
+ )
+ env = GatewayCliAgentEnv(
+ run_command="echo run-agent",
+ dataset=dataset,
+ rubric=vf.Rubric(),
+ use_gateway=False,
+ timeout_seconds=30.0,
+ )
+
+ # Interception server should be initialized
+ assert env._interception_server is not None
+ assert env._tunnel is None
+ assert env._tunnel_lock is not None
+
+ # Gateway attributes should not exist (init_gateway was never called)
+ assert not hasattr(env, "_http_client")
+ assert not hasattr(env, "_tunnels")
+
+ # Teardowns should be safe no-ops for the inactive path
+ await env.teardown_gateway() # early return via use_gateway=False
+ await env.teardown_resources() # stops interception (which was never started)
+
+ assert len(FakeTunnel.instances) == 0
diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py
index 946308f5b..f4f91093b 100644
--- a/verifiers/envs/experimental/cli_agent_env.py
+++ b/verifiers/envs/experimental/cli_agent_env.py
@@ -97,6 +97,10 @@ def __init__(
self.advanced_configs = advanced_configs
self.labels = labels
+ self._interception_server: InterceptionServer | None = None
+ self._tunnel: Tunnel | None = None
+ self._tunnel_lock = asyncio.Lock()
+
self.init_interception(interception_port, interception_url)
def init_interception(
@@ -418,8 +422,6 @@ async def add_model_response(
@vf.teardown
async def teardown_resources(self):
"""Stop Prime Tunnel and HTTP interception server."""
- if not hasattr(self, "_tunnel_lock"):
- return
async with self._tunnel_lock:
if self._tunnel is not None:
try:
@@ -429,7 +431,8 @@ async def teardown_resources(self):
logger.warning(f"Error stopping Prime Tunnel: {e}")
finally:
self._tunnel = None
- await self._interception_server.stop()
+ if self._interception_server is not None:
+ await self._interception_server.stop()
@vf.cleanup
async def cleanup_interception_context(self, state: State):
@@ -446,7 +449,7 @@ async def cleanup_interception_context(self, state: State):
state.pop("background_job", None)
rollout_id = state.get("rollout_id")
- if rollout_id and hasattr(self, "_interception_server"):
+ if rollout_id and self._interception_server is not None:
self._interception_server.unregister_rollout(rollout_id)
@vf.stop
diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
index 1d5a78630..0e0ca49b1 100644
--- a/verifiers/envs/experimental/rollout_gateway_mixin.py
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -36,6 +36,10 @@ class RolloutGatewayMixin:
use_gateway: bool = True
+ def init_interception(self, *args, **kwargs):
+ if not self.use_gateway:
+ super().init_interception(*args, **kwargs)
+
def init_gateway(
self,
gateway_port: int = 8000,
@@ -43,10 +47,12 @@ def init_gateway(
):
"""Initialize gateway resources. Call in __init__ when use_gateway=True."""
self.gateway_port = gateway_port
- self._gw_timeout_seconds = timeout_seconds
- self._gw_http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds))
- self._gw_tunnels: dict[str, Tunnel] = {}
- self._gw_tunnel_lock = asyncio.Lock()
+ self._gateway_timeout_seconds = timeout_seconds
+ self._http_client: httpx.AsyncClient | None = httpx.AsyncClient(
+ timeout=httpx.Timeout(timeout_seconds)
+ )
+ self._tunnels: dict[str, Tunnel] = {}
+ self._tunnel_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Gateway URL resolution
@@ -80,7 +86,7 @@ async def _gateway_post(
suffix: str,
payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
- response = await self._gw_http_client.post(
+ response = await self._http_client.post(
self._rollout_endpoint(state, suffix),
json=payload,
)
@@ -90,7 +96,7 @@ async def _gateway_post(
return response.json()
async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
- response = await self._gw_http_client.get(self._rollout_endpoint(state, suffix))
+ response = await self._http_client.get(self._rollout_endpoint(state, suffix))
response.raise_for_status()
return response.json()
@@ -144,19 +150,19 @@ async def fetch_trajectory(self, state: State) -> None:
async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
"""Get gateway tunnel URL, starting the tunnel if needed."""
- async with self._gw_tunnel_lock:
+ async with self._tunnel_lock:
if local_addr is None:
- if len(self._gw_tunnels) == 1:
- tunnel = next(iter(self._gw_tunnels.values()))
+ if len(self._tunnels) == 1:
+ tunnel = next(iter(self._tunnels.values()))
assert tunnel.url is not None, "Tunnel started but URL is None"
return tunnel.url
- if len(self._gw_tunnels) == 0:
+ if len(self._tunnels) == 0:
raise ValueError("local_addr is required when starting tunnel")
raise ValueError(
"local_addr is required when multiple tunnels are active"
)
- tunnel = self._gw_tunnels.get(local_addr)
+ tunnel = self._tunnels.get(local_addr)
if tunnel is None:
tunnel = Tunnel(
local_port=self.gateway_port,
@@ -164,7 +170,7 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info",
)
url = await tunnel.start()
- self._gw_tunnels[local_addr] = tunnel
+ self._tunnels[local_addr] = tunnel
logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}")
return url
@@ -232,12 +238,12 @@ async def wait_for_agent_completion(self, state: State) -> None:
try:
await asyncio.wait_for(
self.poll_job_completion(state, sandbox_id, background_job),
- timeout=self._gw_timeout_seconds,
+ timeout=self._gateway_timeout_seconds,
)
except asyncio.TimeoutError:
logger.warning(
f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} "
- f"stage=wait_for_agent_completion timed out after {self._gw_timeout_seconds:.1f}s"
+ f"stage=wait_for_agent_completion timed out after {self._gateway_timeout_seconds:.1f}s"
)
state["agent_timed_out"] = True
finally:
@@ -413,12 +419,12 @@ async def rollout(
@vf.teardown
async def teardown_gateway(self):
"""Close gateway HTTP client and stop gateway tunnels."""
- if not hasattr(self, "_gw_http_client"):
+ if not self.use_gateway:
return
- await self._gw_http_client.aclose()
- async with self._gw_tunnel_lock:
- tunnels = list(self._gw_tunnels.items())
- self._gw_tunnels = {}
+ await self._http_client.aclose()
+ async with self._tunnel_lock:
+ tunnels = list(self._tunnels.items())
+ self._tunnels = {}
for local_addr, tunnel in tunnels:
try:
tunnel.sync_stop()
From d0e58f799528b0e5d541f6fe2162b0b10f604df1 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 04:13:40 +0530
Subject: [PATCH 18/21] clean up comments
---
.../experimental/rollout_gateway_mixin.py | 40 +------------------
1 file changed, 2 insertions(+), 38 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
index 0e0ca49b1..67f96672c 100644
--- a/verifiers/envs/experimental/rollout_gateway_mixin.py
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -24,10 +24,10 @@ def _tail_text(value: Any, max_chars: int = 1200) -> str:
class RolloutGatewayMixin:
- """Opt-in mixin that replaces MultiTurnEnv's interception-based rollout
+ """Opt-in mixin that replaces CliAgentEnv's interception-based rollout
with a server-side gateway path. Toggle via ``use_gateway`` attribute.
- When gateway is active, the agent talks directly to vLLM's rollout
+ When gateway is active, the agent talks directly to prime-rl's rollout
gateway through a prime tunnel. The env only manages sandbox lifecycle.
When inactive, falls through to CliAgentEnv's interception path.
@@ -54,10 +54,6 @@ def init_gateway(
self._tunnels: dict[str, Tunnel] = {}
self._tunnel_lock = asyncio.Lock()
- # ------------------------------------------------------------------
- # Gateway URL resolution
- # ------------------------------------------------------------------
-
def _resolve_gateway_url(self, state: State) -> str:
client = getattr(state["client"], "client", state["client"])
gateway_url = str(client.base_url).rstrip("/")
@@ -76,10 +72,6 @@ def _resolve_tunnel_local_addr(self, state: State) -> str:
def _rollout_endpoint(self, state: State, suffix: str) -> str:
return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}"
- # ------------------------------------------------------------------
- # Gateway HTTP helpers
- # ------------------------------------------------------------------
-
async def _gateway_post(
self,
state: State,
@@ -100,10 +92,6 @@ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
response.raise_for_status()
return response.json()
- # ------------------------------------------------------------------
- # Env var override for gateway path
- # ------------------------------------------------------------------
-
async def build_env_vars(self, state: State) -> dict[str, str]:
"""Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode."""
if not self.use_gateway:
@@ -118,10 +106,6 @@ async def build_env_vars(self, state: State) -> dict[str, str]:
env_vars["OPENAI_MODEL"] = model
return env_vars
- # ------------------------------------------------------------------
- # Rollout registration & trajectory
- # ------------------------------------------------------------------
-
async def register_rollout(self, state: State) -> None:
sampling_params = state.get("sampling_args") or {}
payload = {
@@ -144,10 +128,6 @@ async def fetch_trajectory(self, state: State) -> None:
data.get("is_truncated", state.get("is_truncated", False))
)
- # ------------------------------------------------------------------
- # Gateway tunnel management
- # ------------------------------------------------------------------
-
async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
"""Get gateway tunnel URL, starting the tunnel if needed."""
async with self._tunnel_lock:
@@ -177,10 +157,6 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
assert tunnel.url is not None, "Tunnel started but URL is None"
return tunnel.url
- # ------------------------------------------------------------------
- # Agent start & completion polling
- # ------------------------------------------------------------------
-
async def start_agent(self, state: State) -> None:
"""Start the agent command. In gateway mode, skip background completion task."""
if not self.use_gateway:
@@ -249,10 +225,6 @@ async def wait_for_agent_completion(self, state: State) -> None:
finally:
state["agent_completed"] = True
- # ------------------------------------------------------------------
- # Timing
- # ------------------------------------------------------------------
-
async def _render_timing(self, state: State) -> None:
start_time = state["timing"]["start_time"]
end_time = time.time()
@@ -260,10 +232,6 @@ async def _render_timing(self, state: State) -> None:
state["timing"]["generation_ms"] = generation_ms
state["timing"]["total_ms"] = generation_ms
- # ------------------------------------------------------------------
- # rollout() — gateway path with fallback
- # ------------------------------------------------------------------
-
async def rollout(
self,
input: RolloutInput,
@@ -412,10 +380,6 @@ async def rollout(
return state
- # ------------------------------------------------------------------
- # Teardown
- # ------------------------------------------------------------------
-
@vf.teardown
async def teardown_gateway(self):
"""Close gateway HTTP client and stop gateway tunnels."""
From d74e447f6a96c6d357274d0d9d052052e8654743 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 04:23:09 +0530
Subject: [PATCH 19/21] fix
---
verifiers/envs/experimental/rollout_gateway_mixin.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
index 67f96672c..a30a1f6ff 100644
--- a/verifiers/envs/experimental/rollout_gateway_mixin.py
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -11,7 +11,7 @@
import verifiers as vf
from verifiers.clients import Client
-from verifiers.types import ClientConfig, RolloutInput, SamplingArgs, State
+from verifiers.types import RolloutInput, SamplingArgs, State
logger = logging.getLogger(__name__)
@@ -235,7 +235,7 @@ async def _render_timing(self, state: State) -> None:
async def rollout(
self,
input: RolloutInput,
- client: Client | ClientConfig,
+ client: Client,
model: str,
sampling_args: SamplingArgs | None = None,
) -> State:
From 4518df7efa52c94b20437fdcf2ecbd171e1662bd Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 04:35:47 +0530
Subject: [PATCH 20/21] fix `ty`
---
.../experimental/rollout_gateway_mixin.py | 54 +++++++++----------
1 file changed, 26 insertions(+), 28 deletions(-)
diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
index a30a1f6ff..833b40b30 100644
--- a/verifiers/envs/experimental/rollout_gateway_mixin.py
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -38,7 +38,7 @@ class RolloutGatewayMixin:
def init_interception(self, *args, **kwargs):
if not self.use_gateway:
- super().init_interception(*args, **kwargs)
+ super().init_interception(*args, **kwargs) # ty: ignore[unresolved-attribute]
def init_gateway(
self,
@@ -48,9 +48,7 @@ def init_gateway(
"""Initialize gateway resources. Call in __init__ when use_gateway=True."""
self.gateway_port = gateway_port
self._gateway_timeout_seconds = timeout_seconds
- self._http_client: httpx.AsyncClient | None = httpx.AsyncClient(
- timeout=httpx.Timeout(timeout_seconds)
- )
+ self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds))
self._tunnels: dict[str, Tunnel] = {}
self._tunnel_lock = asyncio.Lock()
@@ -95,8 +93,8 @@ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
async def build_env_vars(self, state: State) -> dict[str, str]:
"""Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode."""
if not self.use_gateway:
- return await super().build_env_vars(state)
- env_vars = dict(self.environment_vars) if self.environment_vars else {}
+ return await super().build_env_vars(state) # ty: ignore[unresolved-attribute]
+ env_vars = dict(self.environment_vars) if self.environment_vars else {} # ty: ignore[unresolved-attribute]
env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"]
env_vars.setdefault("OPENAI_TIMEOUT", "600")
env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600")
@@ -111,8 +109,8 @@ async def register_rollout(self, state: State) -> None:
payload = {
"model": state["model"],
"sampling_params": sampling_params,
- "max_turns": self.max_turns,
- "max_seq_len": self.max_seq_len,
+ "max_turns": self.max_turns, # ty: ignore[unresolved-attribute]
+ "max_seq_len": self.max_seq_len, # ty: ignore[unresolved-attribute]
}
await self._gateway_post(state, "register", payload)
@@ -160,11 +158,11 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
async def start_agent(self, state: State) -> None:
"""Start the agent command. In gateway mode, skip background completion task."""
if not self.use_gateway:
- return await super().start_agent(state)
+ return await super().start_agent(state) # ty: ignore[unresolved-attribute]
sandbox_id = state["sandbox_id"]
- background_job = await self.sandbox_client.start_background_job(
+ background_job = await self.sandbox_client.start_background_job( # ty: ignore[unresolved-attribute]
sandbox_id,
- self.run_command,
+ self.run_command, # ty: ignore[unresolved-attribute]
)
state["background_job"] = background_job
state["agent_start_time"] = time.time()
@@ -178,9 +176,9 @@ async def poll_job_completion(
) -> None:
"""Poll until background job completes, capturing output."""
if not self.use_gateway:
- return await super().poll_job_completion(state, sandbox_id, background_job)
+ return await super().poll_job_completion(state, sandbox_id, background_job) # ty: ignore[unresolved-attribute]
while True:
- status = await self.sandbox_client.get_background_job(
+ status = await self.sandbox_client.get_background_job( # ty: ignore[unresolved-attribute]
sandbox_id,
background_job,
)
@@ -201,7 +199,7 @@ async def poll_job_completion(
f"stage=agent_completed exit_code={status.exit_code}"
)
return
- await asyncio.sleep(self.poll_interval)
+ await asyncio.sleep(self.poll_interval) # ty: ignore[unresolved-attribute]
async def wait_for_agent_completion(self, state: State) -> None:
"""Poll for agent completion using background job API."""
@@ -240,9 +238,9 @@ async def rollout(
sampling_args: SamplingArgs | None = None,
) -> State:
if not self.use_gateway:
- return await super().rollout(input, client, model, sampling_args)
+ return await super().rollout(input, client, model, sampling_args) # ty: ignore[unresolved-attribute]
- state = await self.init_state(input, client, model, sampling_args)
+ state = await self.init_state(input, client, model, sampling_args) # ty: ignore[unresolved-attribute]
state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}"
state["gateway_url"] = self._resolve_gateway_url(state)
rollout_id = state["rollout_id"]
@@ -273,26 +271,26 @@ async def rollout(
logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}")
env_vars = await self.build_env_vars(state)
- docker_image = await self.get_docker_image(state)
+ docker_image = await self.get_docker_image(state) # ty: ignore[unresolved-attribute]
sandbox_request = CreateSandboxRequest(
name=state["rollout_id"],
docker_image=docker_image,
- start_command=self.start_command,
- cpu_cores=self.cpu_cores,
- memory_gb=self.memory_gb,
- disk_size_gb=self.disk_size_gb,
- gpu_count=self.gpu_count,
- timeout_minutes=self.timeout_minutes,
+ start_command=self.start_command, # ty: ignore[unresolved-attribute]
+ cpu_cores=self.cpu_cores, # ty: ignore[unresolved-attribute]
+ memory_gb=self.memory_gb, # ty: ignore[unresolved-attribute]
+ disk_size_gb=self.disk_size_gb, # ty: ignore[unresolved-attribute]
+ gpu_count=self.gpu_count, # ty: ignore[unresolved-attribute]
+ timeout_minutes=self.timeout_minutes, # ty: ignore[unresolved-attribute]
environment_vars=env_vars,
- team_id=self.team_id,
- advanced_configs=self.advanced_configs,
- labels=self.labels or [],
+ team_id=self.team_id, # ty: ignore[unresolved-attribute]
+ advanced_configs=self.advanced_configs, # ty: ignore[unresolved-attribute]
+ labels=self.labels or [], # ty: ignore[unresolved-attribute]
)
logger.debug(
f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} "
f"docker_image={docker_image}"
)
- await self.create_sandbox(state, sandbox_request)
+ await self.create_sandbox(state, sandbox_request) # ty: ignore[unresolved-attribute]
logger.info(
f"rollout={rollout_id} stage=create_sandbox ok "
f"sandbox_id={state.get('sandbox_id')} docker_image={docker_image}"
@@ -351,7 +349,7 @@ async def rollout(
if state.get("sandbox_id"):
try:
- await self._cleanup(state)
+ await self._cleanup(state) # ty: ignore[unresolved-attribute]
except Exception as e:
logger.warning(
f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}"
From 21d45c0460214c951549bca978795bd6cd26cc61 Mon Sep 17 00:00:00 2001
From: rasdani <73563550+rasdani@users.noreply.github.com>
Date: Wed, 25 Feb 2026 04:37:36 +0530
Subject: [PATCH 21/21] docs
---
assets/lab/environments/AGENTS.md | 3 ++-
docs/environments.md | 1 +
environments/AGENTS.md | 3 ++-
3 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/assets/lab/environments/AGENTS.md b/assets/lab/environments/AGENTS.md
index cc7f76d02..98e63b1d7 100644
--- a/assets/lab/environments/AGENTS.md
+++ b/assets/lab/environments/AGENTS.md
@@ -802,5 +802,6 @@ Newer and more experimental environment classes include:
- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
+- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` fires and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
-- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`.
+- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
diff --git a/docs/environments.md b/docs/environments.md
index 84f884482..4f8e43121 100644
--- a/docs/environments.md
+++ b/docs/environments.md
@@ -796,5 +796,6 @@ Newer and more experimental environment classes include:
- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
+- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` fires and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
diff --git a/environments/AGENTS.md b/environments/AGENTS.md
index da5323579..632fb0ee0 100644
--- a/environments/AGENTS.md
+++ b/environments/AGENTS.md
@@ -802,5 +802,6 @@ Newer and more experimental environment classes include:
- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
+- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` fires and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
-- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`.
+- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls