diff --git a/assets/lab/environments/AGENTS.md b/assets/lab/environments/AGENTS.md
index cc7f76d02..98e63b1d7 100644
--- a/assets/lab/environments/AGENTS.md
+++ b/assets/lab/environments/AGENTS.md
@@ -802,5 +802,6 @@ Newer and more experimental environment classes include:
- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`
+- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` takes over, managing gateway registration, tunnel setup, and trajectory fetching; when `False`, it falls through to `CliAgentEnv`'s interception path. Use as `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
-- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`.
+- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
diff --git a/docs/environments.md b/docs/environments.md
index 84f884482..4f8e43121 100644
--- a/docs/environments.md
+++ b/docs/environments.md
@@ -796,5 +796,6 @@ Newer and more experimental environment classes include:
- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`
+- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` takes over, managing gateway registration, tunnel setup, and trajectory fetching; when `False`, it falls through to `CliAgentEnv`'s interception path. Use as `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
diff --git a/environments/AGENTS.md b/environments/AGENTS.md
index da5323579..632fb0ee0 100644
--- a/environments/AGENTS.md
+++ b/environments/AGENTS.md
@@ -802,5 +802,6 @@ Newer and more experimental environment classes include:
- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`
+- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` takes over, managing gateway registration, tunnel setup, and trajectory fetching; when `False`, it falls through to `CliAgentEnv`'s interception path. Use as `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
-- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`.
+- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py
new file mode 100644
index 000000000..87bfe7b07
--- /dev/null
+++ b/tests/test_rollout_gateway_env.py
@@ -0,0 +1,350 @@
+from __future__ import annotations
+
+import json
+from types import SimpleNamespace
+from unittest.mock import AsyncMock
+
+import httpx
+import pytest
+from datasets import Dataset
+
+import verifiers as vf
+import verifiers.envs.experimental.rollout_gateway_mixin as rollout_gateway_mixin
+
+pytestmark = [pytest.mark.integration, pytest.mark.environments]
+
+
+class FakeTunnel:
+ instances: list["FakeTunnel"] = []
+
+ def __init__(
+ self,
+ local_port: int,
+ local_addr: str = "127.0.0.1",
+ log_level: str | None = None,
+ ):
+ self.local_port = local_port
+ self.local_addr = local_addr
+ self.log_level = log_level
+ self.url: str | None = None
+ self.start_calls = 0
+ self.stop_calls = 0
+ FakeTunnel.instances.append(self)
+
+ async def start(self) -> str:
+ self.start_calls += 1
+ self.url = "https://unit-test.tunnel.prime.ai"
+ return self.url
+
+ def sync_stop(self) -> None:
+ self.stop_calls += 1
+
+
+class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):
+ def __init__(self, *, gateway_port=8000, use_gateway=True, **kwargs):
+ self.use_gateway = use_gateway
+ super().__init__(**kwargs)
+ if use_gateway:
+ self.init_gateway(
+ gateway_port=gateway_port, timeout_seconds=self.timeout_seconds
+ )
+
+ async def post_rollout(self, state: vf.State):
+ state["reward"] = 1.0
+ state["test_output"] = "ok"
+
+
+def _build_gateway_transport(tracker: dict) -> httpx.MockTransport:
+ trajectory = [
+ {
+ "prompt": [{"role": "user", "content": "Hello"}],
+ "completion": [{"role": "assistant", "content": "reply-1"}],
+ "tokens": {
+ "prompt_ids": [1, 2],
+ "prompt_mask": [0, 0],
+ "completion_ids": [3],
+ "completion_mask": [1],
+ "completion_logprobs": [-0.1],
+ "overlong_prompt": False,
+ "is_truncated": False,
+ },
+ "reward": None,
+ "advantage": None,
+ "is_truncated": False,
+ "trajectory_id": "traj-1",
+ "extras": {},
+ },
+ {
+ "prompt": [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "reply-1"},
+ {"role": "user", "content": "Turn 2"},
+ ],
+ "completion": [{"role": "assistant", "content": "reply-2"}],
+ "tokens": {
+ "prompt_ids": [1, 2, 3, 4],
+ "prompt_mask": [0, 0, 0, 0],
+ "completion_ids": [5],
+ "completion_mask": [1],
+ "completion_logprobs": [-0.2],
+ "overlong_prompt": False,
+ "is_truncated": False,
+ },
+ "reward": None,
+ "advantage": None,
+ "is_truncated": False,
+ "trajectory_id": "traj-1",
+ "extras": {},
+ },
+ {
+ "prompt": [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "reply-1"},
+ {"role": "user", "content": "Turn 2"},
+ {"role": "assistant", "content": "reply-2"},
+ {"role": "user", "content": "Turn 3"},
+ ],
+ "completion": [{"role": "assistant", "content": "reply-3"}],
+ "tokens": {
+ "prompt_ids": [1, 2, 3, 4, 5, 6],
+ "prompt_mask": [0, 0, 0, 0, 0, 0],
+ "completion_ids": [7],
+ "completion_mask": [1],
+ "completion_logprobs": [-0.3],
+ "overlong_prompt": False,
+ "is_truncated": False,
+ },
+ "reward": None,
+ "advantage": None,
+ "is_truncated": False,
+ "trajectory_id": "traj-1",
+ "extras": {},
+ },
+ ]
+
+ def _handler(request: httpx.Request) -> httpx.Response:
+ tracker["hosts"].add(request.url.host)
+ tracker["paths"].append(request.url.path)
+ path = request.url.path
+
+ if request.method == "POST" and path.endswith("/register"):
+ payload = json.loads(request.content.decode("utf-8"))
+ tracker["register_payload"] = payload
+ tracker["rollout_id"] = path.split("/")[-2]
+ return httpx.Response(status_code=200, json={"status": "active"})
+
+ if request.method == "POST" and path.endswith("/unregister"):
+ tracker["unregister_calls"] += 1
+ return httpx.Response(status_code=200, json={"status": "active"})
+
+ if request.method == "GET" and path.endswith("/trajectory"):
+ tracker["trajectory_calls"] += 1
+ return httpx.Response(
+ status_code=200,
+ json={
+ "rollout_id": tracker["rollout_id"],
+ "status": "completed",
+ "num_turns": 3,
+ "model": "Qwen/Qwen3-0.6B",
+ "prompt": trajectory[0]["prompt"],
+ "completion": [
+ {"role": "assistant", "content": "reply-1"},
+ {"role": "user", "content": "Turn 2"},
+ {"role": "assistant", "content": "reply-2"},
+ {"role": "user", "content": "Turn 3"},
+ {"role": "assistant", "content": "reply-3"},
+ ],
+ "is_truncated": False,
+ "trajectory": trajectory,
+ },
+ )
+
+ return httpx.Response(status_code=404, json={"error": f"Unhandled path {path}"})
+
+ return httpx.MockTransport(_handler)
+
+
+@pytest.mark.asyncio
+async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch):
+ FakeTunnel.instances.clear()
+ monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)
+
+ tracker = {
+ "paths": [],
+ "hosts": set(),
+ "register_payload": None,
+ "rollout_id": None,
+ "trajectory_calls": 0,
+ "unregister_calls": 0,
+ }
+ transport = _build_gateway_transport(tracker)
+ real_async_client = httpx.AsyncClient
+ client = vf.OpenAIChatCompletionsClient(
+ vf.ClientConfig(
+ api_key_var="UNIT_TEST_API_KEY",
+ api_base_url="http://gateway.internal:8000/v1/",
+ )
+ )
+
+ def _client_factory(*args, **kwargs):
+ kwargs["transport"] = transport
+ return real_async_client(*args, **kwargs)
+
+ monkeypatch.setattr(rollout_gateway_mixin.httpx, "AsyncClient", _client_factory)
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [[{"role": "user", "content": "Hello"}]],
+ "answer": [""],
+ "example_id": [0],
+ }
+ )
+ env = GatewayCliAgentEnv(
+ run_command="echo run-agent",
+ dataset=dataset,
+ rubric=vf.Rubric(),
+ gateway_port=8000,
+ max_turns=10,
+ timeout_seconds=30.0,
+ )
+
+ env.sandbox_client.create = AsyncMock(return_value=SimpleNamespace(id="sb-123"))
+ env.sandbox_client.wait_for_creation = AsyncMock(return_value=None)
+ env.sandbox_client.start_background_job = AsyncMock(
+ return_value=SimpleNamespace(id="job-1")
+ )
+ env.sandbox_client.get_background_job = AsyncMock(
+ return_value=SimpleNamespace(
+ completed=True,
+ exit_code=0,
+ stdout="agent ok",
+ stderr="",
+ )
+ )
+ env.sandbox_client.delete = AsyncMock(return_value=None)
+
+ rollout_input = {
+ "prompt": [{"role": "user", "content": "Hello"}],
+ "answer": "",
+ "example_id": 0,
+ "task": "gateway-test",
+ }
+ state = await env.rollout(
+ input=rollout_input,
+ client=client,
+ model="Qwen/Qwen3-0.6B",
+ sampling_args={"temperature": 0.7, "max_completion_tokens": 64},
+ )
+
+ assert state.get("error") is None
+ assert state["gateway_url"] == "http://gateway.internal:8000"
+ assert state["tunnel_url"] == "https://unit-test.tunnel.prime.ai"
+ assert state["rollout_base_url"].startswith(
+ "https://unit-test.tunnel.prime.ai/v1/rollouts/"
+ )
+ assert len(state["trajectory"]) == 3
+ assert state["prompt"] == [{"role": "user", "content": "Hello"}]
+ assert state["completion"][-1]["content"] == "reply-3"
+ assert state["reward"] == 1.0
+
+ create_request = env.sandbox_client.create.await_args.args[0]
+ assert (
+ create_request.environment_vars["OPENAI_BASE_URL"]
+ == f"https://unit-test.tunnel.prime.ai/v1/rollouts/{state['rollout_id']}"
+ )
+ assert create_request.environment_vars["OPENAI_MODEL"] == "Qwen/Qwen3-0.6B"
+
+ assert tracker["register_payload"]["max_turns"] == 10
+ assert tracker["register_payload"]["sampling_params"]["temperature"] == 0.7
+ assert tracker["register_payload"]["sampling_params"]["max_completion_tokens"] == 64
+ assert tracker["trajectory_calls"] == 1
+ assert tracker["unregister_calls"] == 1
+ assert tracker["hosts"] == {"gateway.internal"}
+
+ assert len(FakeTunnel.instances) == 1
+ tunnel = FakeTunnel.instances[0]
+ assert tunnel.local_port == 8000
+ assert tunnel.local_addr == "gateway.internal"
+ assert tunnel.start_calls == 1
+ assert await env.get_gateway_tunnel_url() == "https://unit-test.tunnel.prime.ai"
+ assert tunnel.start_calls == 1
+
+ await env.teardown_gateway()
+ assert tunnel.stop_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
+ FakeTunnel.instances.clear()
+ monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [[{"role": "user", "content": "Hello"}]],
+ "answer": [""],
+ "example_id": [0],
+ }
+ )
+ env = GatewayCliAgentEnv(
+ run_command="echo run-agent",
+ dataset=dataset,
+ rubric=vf.Rubric(),
+ gateway_port=8000,
+ )
+
+ url_a = await env.get_gateway_tunnel_url(local_addr="10.20.0.58")
+ url_b = await env.get_gateway_tunnel_url(local_addr="10.20.0.59")
+ url_a_reuse = await env.get_gateway_tunnel_url(local_addr="10.20.0.58")
+
+ assert url_a == "https://unit-test.tunnel.prime.ai"
+ assert url_b == "https://unit-test.tunnel.prime.ai"
+ assert url_a_reuse == url_a
+
+ assert len(FakeTunnel.instances) == 2
+ assert {t.local_addr for t in FakeTunnel.instances} == {"10.20.0.58", "10.20.0.59"}
+ assert sum(t.start_calls for t in FakeTunnel.instances) == 2
+
+ with pytest.raises(
+ ValueError, match="local_addr is required when multiple tunnels are active"
+ ):
+ await env.get_gateway_tunnel_url()
+
+ await env.teardown_gateway()
+ assert sum(t.stop_calls for t in FakeTunnel.instances) == 2
+
+
+@pytest.mark.asyncio
+async def test_use_gateway_false_initializes_interception(monkeypatch):
+ """With use_gateway=False, interception server is created and gateway is not."""
+ FakeTunnel.instances.clear()
+ monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)
+
+ dataset = Dataset.from_dict(
+ {
+ "prompt": [[{"role": "user", "content": "Hello"}]],
+ "answer": [""],
+ "example_id": [0],
+ }
+ )
+ env = GatewayCliAgentEnv(
+ run_command="echo run-agent",
+ dataset=dataset,
+ rubric=vf.Rubric(),
+ use_gateway=False,
+ timeout_seconds=30.0,
+ )
+
+ # Interception server should be initialized
+ assert env._interception_server is not None
+ assert env._tunnel is None
+ assert env._tunnel_lock is not None
+
+ # Gateway attributes should not exist (init_gateway was never called)
+ assert not hasattr(env, "_http_client")
+ assert not hasattr(env, "_tunnels")
+
+ # Teardowns should be safe no-ops for the inactive path
+ await env.teardown_gateway() # early return via use_gateway=False
+ await env.teardown_resources() # stops interception (which was never started)
+
+ assert len(FakeTunnel.instances) == 0
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index 2cd3480d8..e7fa269ea 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -67,6 +67,7 @@
"ReasoningGymEnv",
"GymEnv",
"CliAgentEnv",
+ "RolloutGatewayMixin",
"HarborEnv",
"MCPEnv",
"BrowserEnv",
@@ -121,6 +122,7 @@
"PythonEnv": "verifiers.envs.python_env:PythonEnv",
"GymEnv": "verifiers.envs.experimental.gym_env:GymEnv",
"CliAgentEnv": "verifiers.envs.experimental.cli_agent_env:CliAgentEnv",
+ "RolloutGatewayMixin": "verifiers.envs.experimental.rollout_gateway_mixin:RolloutGatewayMixin",
"HarborEnv": "verifiers.envs.experimental.harbor_env:HarborEnv",
"MCPEnv": "verifiers.envs.experimental.mcp_env:MCPEnv",
"ReasoningGymEnv": "verifiers.envs.integrations.reasoninggym_env:ReasoningGymEnv",
@@ -160,6 +162,7 @@ def __getattr__(name: str):
from typing import Any
from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401
+ from .envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin # noqa: F401
from .envs.experimental.gym_env import GymEnv # noqa: F401
from .envs.experimental.harbor_env import HarborEnv # noqa: F401
from .envs.experimental.mcp_env import MCPEnv # noqa: F401
diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index 40ef74cbe..96126b506 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -622,6 +622,7 @@ async def init_state(
state["tool_defs"] = self._normalize_tool_defs(resolved_tool_defs) or []
state["trajectory"] = []
+ state["completion"] = None
self._get_usage_tracker(state, create_if_missing=True)
state["trajectory_id"] = uuid.uuid4().hex
state["reward"] = None
diff --git a/verifiers/envs/experimental/__init__.py b/verifiers/envs/experimental/__init__.py
index 034cb817f..ff81549e0 100644
--- a/verifiers/envs/experimental/__init__.py
+++ b/verifiers/envs/experimental/__init__.py
@@ -1,3 +1,4 @@
+from verifiers.envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin
from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
-__all__ = ["SandboxMixin"]
+__all__ = ["RolloutGatewayMixin", "SandboxMixin"]
diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py
index 5cfb9e42e..f4f91093b 100644
--- a/verifiers/envs/experimental/cli_agent_env.py
+++ b/verifiers/envs/experimental/cli_agent_env.py
@@ -84,8 +84,6 @@ def __init__(
)
self.run_command = run_command
self.poll_interval = poll_interval
- self.interception_port = interception_port
- self.interception_url = interception_url
self.timeout_seconds = timeout_seconds
self.docker_image = docker_image
self.start_command = start_command
@@ -99,7 +97,20 @@ def __init__(
self.advanced_configs = advanced_configs
self.labels = labels
- # Tunnel and interception server
+ self._interception_server: InterceptionServer | None = None
+ self._tunnel: Tunnel | None = None
+ self._tunnel_lock = asyncio.Lock()
+
+ self.init_interception(interception_port, interception_url)
+
+ def init_interception(
+ self,
+ interception_port: int = 8765,
+ interception_url: str | None = None,
+ ):
+ """Initialize interception server and tunnel resources. Call from __init__."""
+ self.interception_port = interception_port
+ self.interception_url = interception_url
self._tunnel: Tunnel | None = None
self._tunnel_lock = asyncio.Lock()
self._interception_server = InterceptionServer(port=interception_port)
@@ -420,7 +431,8 @@ async def teardown_resources(self):
logger.warning(f"Error stopping Prime Tunnel: {e}")
finally:
self._tunnel = None
- await self._interception_server.stop()
+ if self._interception_server is not None:
+ await self._interception_server.stop()
@vf.cleanup
async def cleanup_interception_context(self, state: State):
@@ -437,7 +449,7 @@ async def cleanup_interception_context(self, state: State):
state.pop("background_job", None)
rollout_id = state.get("rollout_id")
- if rollout_id:
+ if rollout_id and self._interception_server is not None:
self._interception_server.unregister_rollout(rollout_id)
@vf.stop
diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py
new file mode 100644
index 000000000..833b40b30
--- /dev/null
+++ b/verifiers/envs/experimental/rollout_gateway_mixin.py
@@ -0,0 +1,397 @@
+import asyncio
+import logging
+import time
+import uuid
+from typing import Any
+from urllib.parse import urlparse
+
+import httpx
+from prime_sandboxes import CreateSandboxRequest
+from prime_tunnel import Tunnel
+
+import verifiers as vf
+from verifiers.clients import Client
+from verifiers.types import RolloutInput, SamplingArgs, State
+
+logger = logging.getLogger(__name__)
+
+
+def _tail_text(value: Any, max_chars: int = 1200) -> str:
+ if value is None:
+ return ""
+ text = str(value)
+ return text[-max_chars:] if len(text) > max_chars else text
+
+
+class RolloutGatewayMixin:
+    """Opt-in mixin that replaces CliAgentEnv's interception-based rollout
+    with a server-side gateway path. Toggle via the ``use_gateway`` attribute.
+
+    When the gateway is active, the agent talks directly to prime-rl's rollout
+    gateway through a Prime Tunnel, and the env only manages the sandbox
+    lifecycle. When inactive, rollouts fall through to CliAgentEnv's
+    interception path.
+
+ MRO: ``MyEnv → RolloutGatewayMixin → CliAgentEnv → SandboxMixin → MultiTurnEnv → Environment``
+ """
+
+ use_gateway: bool = True
+
+ def init_interception(self, *args, **kwargs):
+ if not self.use_gateway:
+ super().init_interception(*args, **kwargs) # ty: ignore[unresolved-attribute]
+
+ def init_gateway(
+ self,
+ gateway_port: int = 8000,
+ timeout_seconds: float = 3600.0,
+ ):
+ """Initialize gateway resources. Call in __init__ when use_gateway=True."""
+ self.gateway_port = gateway_port
+ self._gateway_timeout_seconds = timeout_seconds
+ self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds))
+ self._tunnels: dict[str, Tunnel] = {}
+ self._tunnel_lock = asyncio.Lock()
+
+ def _resolve_gateway_url(self, state: State) -> str:
+ client = getattr(state["client"], "client", state["client"])
+ gateway_url = str(client.base_url).rstrip("/")
+ if gateway_url.endswith("/v1"):
+ gateway_url = gateway_url[:-3]
+ return gateway_url
+
+ def _resolve_tunnel_local_addr(self, state: State) -> str:
+ gateway_url = state["gateway_url"]
+ parsed = urlparse(gateway_url)
+ host = parsed.hostname
+ if host is None:
+ raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}")
+ return host
+
+ def _rollout_endpoint(self, state: State, suffix: str) -> str:
+ return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}"
+
+ async def _gateway_post(
+ self,
+ state: State,
+ suffix: str,
+ payload: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ response = await self._http_client.post(
+ self._rollout_endpoint(state, suffix),
+ json=payload,
+ )
+ response.raise_for_status()
+ if not response.content:
+ return {}
+ return response.json()
+
+ async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]:
+ response = await self._http_client.get(self._rollout_endpoint(state, suffix))
+ response.raise_for_status()
+ return response.json()
+
+ async def build_env_vars(self, state: State) -> dict[str, str]:
+ """Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode."""
+ if not self.use_gateway:
+ return await super().build_env_vars(state) # ty: ignore[unresolved-attribute]
+ env_vars = dict(self.environment_vars) if self.environment_vars else {} # ty: ignore[unresolved-attribute]
+ env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"]
+ env_vars.setdefault("OPENAI_TIMEOUT", "600")
+ env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600")
+ env_vars.setdefault("HTTPX_TIMEOUT", "600")
+ model = state.get("model")
+ if model:
+ env_vars["OPENAI_MODEL"] = model
+ return env_vars
+
+ async def register_rollout(self, state: State) -> None:
+ sampling_params = state.get("sampling_args") or {}
+ payload = {
+ "model": state["model"],
+ "sampling_params": sampling_params,
+ "max_turns": self.max_turns, # ty: ignore[unresolved-attribute]
+ "max_seq_len": self.max_seq_len, # ty: ignore[unresolved-attribute]
+ }
+ await self._gateway_post(state, "register", payload)
+
+ async def unregister_rollout(self, state: State) -> None:
+ await self._gateway_post(state, "unregister")
+
+ async def fetch_trajectory(self, state: State) -> None:
+ data = await self._gateway_get(state, "trajectory")
+ state["trajectory"] = data.get("trajectory", [])
+ state["prompt"] = data.get("prompt")
+ state["completion"] = data.get("completion")
+ state["is_truncated"] = bool(
+ data.get("is_truncated", state.get("is_truncated", False))
+ )
+
+ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str:
+ """Get gateway tunnel URL, starting the tunnel if needed."""
+ async with self._tunnel_lock:
+ if local_addr is None:
+ if len(self._tunnels) == 1:
+ tunnel = next(iter(self._tunnels.values()))
+ assert tunnel.url is not None, "Tunnel started but URL is None"
+ return tunnel.url
+ if len(self._tunnels) == 0:
+ raise ValueError("local_addr is required when starting tunnel")
+ raise ValueError(
+ "local_addr is required when multiple tunnels are active"
+ )
+
+ tunnel = self._tunnels.get(local_addr)
+ if tunnel is None:
+ tunnel = Tunnel(
+ local_port=self.gateway_port,
+ local_addr=local_addr,
+ log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info",
+ )
+ url = await tunnel.start()
+ self._tunnels[local_addr] = tunnel
+ logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}")
+ return url
+
+ assert tunnel.url is not None, "Tunnel started but URL is None"
+ return tunnel.url
+
+ async def start_agent(self, state: State) -> None:
+        """Start the agent command. In gateway mode, skip the background completion task."""
+ if not self.use_gateway:
+ return await super().start_agent(state) # ty: ignore[unresolved-attribute]
+ sandbox_id = state["sandbox_id"]
+ background_job = await self.sandbox_client.start_background_job( # ty: ignore[unresolved-attribute]
+ sandbox_id,
+ self.run_command, # ty: ignore[unresolved-attribute]
+ )
+ state["background_job"] = background_job
+ state["agent_start_time"] = time.time()
+ state["agent_completed"] = False
+
+ async def poll_job_completion(
+ self,
+ state: State,
+ sandbox_id: str,
+ background_job,
+ ) -> None:
+ """Poll until background job completes, capturing output."""
+ if not self.use_gateway:
+ return await super().poll_job_completion(state, sandbox_id, background_job) # ty: ignore[unresolved-attribute]
+ while True:
+ status = await self.sandbox_client.get_background_job( # ty: ignore[unresolved-attribute]
+ sandbox_id,
+ background_job,
+ )
+ if status.completed:
+ state["agent_exit_code"] = status.exit_code
+ state["agent_stdout"] = status.stdout
+ state["agent_stderr"] = status.stderr
+ if status.exit_code not in (None, 0):
+ logger.warning(
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} "
+ f"stage=agent_completed exit_code={status.exit_code} "
+ f"stdout_tail={_tail_text(status.stdout)!r} "
+ f"stderr_tail={_tail_text(status.stderr)!r}"
+ )
+ else:
+ logger.debug(
+ f"rollout={state.get('rollout_id')} sandbox={sandbox_id} "
+ f"stage=agent_completed exit_code={status.exit_code}"
+ )
+ return
+ await asyncio.sleep(self.poll_interval) # ty: ignore[unresolved-attribute]
+
+ async def wait_for_agent_completion(self, state: State) -> None:
+ """Poll for agent completion using background job API."""
+ sandbox_id = state.get("sandbox_id")
+ background_job = state.get("background_job")
+ if not sandbox_id or not background_job:
+ state["agent_completed"] = True
+ return
+
+ try:
+ await asyncio.wait_for(
+ self.poll_job_completion(state, sandbox_id, background_job),
+ timeout=self._gateway_timeout_seconds,
+ )
+ except asyncio.TimeoutError:
+ logger.warning(
+ f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} "
+ f"stage=wait_for_agent_completion timed out after {self._gateway_timeout_seconds:.1f}s"
+ )
+ state["agent_timed_out"] = True
+ finally:
+ state["agent_completed"] = True
+
+ async def _render_timing(self, state: State) -> None:
+ start_time = state["timing"]["start_time"]
+ end_time = time.time()
+ generation_ms = (end_time - start_time) * 1000
+ state["timing"]["generation_ms"] = generation_ms
+ state["timing"]["total_ms"] = generation_ms
+
+ async def rollout(
+ self,
+ input: RolloutInput,
+ client: Client,
+ model: str,
+ sampling_args: SamplingArgs | None = None,
+ ) -> State:
+ if not self.use_gateway:
+ return await super().rollout(input, client, model, sampling_args) # ty: ignore[unresolved-attribute]
+
+ state = await self.init_state(input, client, model, sampling_args) # ty: ignore[unresolved-attribute]
+ state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}"
+ state["gateway_url"] = self._resolve_gateway_url(state)
+ rollout_id = state["rollout_id"]
+ info = state.get("info") or {}
+ logger.info(
+ f"rollout={rollout_id} stage=start model={model} "
+ f"example_id={info.get('instance_id') or info.get('example_id')} "
+ f"repo={info.get('repo_name')}"
+ )
+
+ rollout_registered = False
+ try:
+ await self.register_rollout(state)
+ rollout_registered = True
+ logger.debug(f"rollout={rollout_id} stage=register_rollout ok")
+
+ tunnel_local_addr = self._resolve_tunnel_local_addr(state)
+ state["tunnel_local_addr"] = tunnel_local_addr
+ logger.debug(
+ f"rollout={rollout_id} stage=resolve_tunnel_local_addr addr={tunnel_local_addr}"
+ )
+
+ tunnel_url = await self.get_gateway_tunnel_url(local_addr=tunnel_local_addr)
+ state["tunnel_url"] = tunnel_url
+ state["rollout_base_url"] = (
+ f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}"
+ )
+ logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}")
+
+ env_vars = await self.build_env_vars(state)
+ docker_image = await self.get_docker_image(state) # ty: ignore[unresolved-attribute]
+ sandbox_request = CreateSandboxRequest(
+ name=state["rollout_id"],
+ docker_image=docker_image,
+ start_command=self.start_command, # ty: ignore[unresolved-attribute]
+ cpu_cores=self.cpu_cores, # ty: ignore[unresolved-attribute]
+ memory_gb=self.memory_gb, # ty: ignore[unresolved-attribute]
+ disk_size_gb=self.disk_size_gb, # ty: ignore[unresolved-attribute]
+ gpu_count=self.gpu_count, # ty: ignore[unresolved-attribute]
+ timeout_minutes=self.timeout_minutes, # ty: ignore[unresolved-attribute]
+ environment_vars=env_vars,
+ team_id=self.team_id, # ty: ignore[unresolved-attribute]
+ advanced_configs=self.advanced_configs, # ty: ignore[unresolved-attribute]
+ labels=self.labels or [], # ty: ignore[unresolved-attribute]
+ )
+ logger.debug(
+ f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} "
+ f"docker_image={docker_image}"
+ )
+ await self.create_sandbox(state, sandbox_request) # ty: ignore[unresolved-attribute]
+ logger.info(
+ f"rollout={rollout_id} stage=create_sandbox ok "
+ f"sandbox_id={state.get('sandbox_id')} docker_image={docker_image}"
+ )
+
+ await self.start_agent(state)
+ logger.debug(
+ f"rollout={rollout_id} stage=start_agent ok "
+ f"sandbox_id={state.get('sandbox_id')}"
+ )
+ await self.wait_for_agent_completion(state)
+ logger.debug(
+ f"rollout={rollout_id} stage=wait_for_agent_completion ok "
+ f"exit_code={state.get('agent_exit_code')}"
+ )
+ await self.fetch_trajectory(state)
+ trajectory = state.get("trajectory") or []
+ logger.info(
+ f"rollout={rollout_id} stage=fetch_trajectory ok "
+ f"turns={len(trajectory)} truncated={state.get('is_truncated', False)}"
+ )
+ if len(trajectory) == 0:
+ logger.warning(
+ f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory "
+ f"agent_exit_code={state.get('agent_exit_code')} "
+ f"stdout_tail={_tail_text(state.get('agent_stdout'))!r} "
+ f"stderr_tail={_tail_text(state.get('agent_stderr'))!r}"
+ )
+ except asyncio.CancelledError:
+ if rollout_registered:
+ try:
+ # Best-effort: let the gateway know this rollout was cancelled
+ # before propagating the cancellation.
+ await self._gateway_post(state, "cancel")
+ except Exception:
+ pass
+ raise
+ except vf.Error as e:
+ state["error"] = e
+ logger.exception(
+ f"rollout={rollout_id} stage=rollout error_type={type(e).__name__} vf_error={e}"
+ )
+ except Exception as e:
+ state["error"] = vf.InfraError(str(e))
+ logger.exception(
+ f"rollout={rollout_id} stage=rollout error_type={type(e).__name__} unhandled_error={e}"
+ )
+ finally:
+ if rollout_registered:
+ try:
+ await self.unregister_rollout(state)
+ except Exception as e:
+ logger.warning(
+ f"Failed to unregister rollout {state['rollout_id']}: {e}"
+ )
+ if state.get("error") is None:
+ state["error"] = vf.InfraError(str(e))
+
+ if state.get("sandbox_id"):
+ try:
+ await self._cleanup(state) # ty: ignore[unresolved-attribute]
+ except Exception as e:
+ logger.warning(
+ f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}"
+ )
+ if state.get("error") is None:
+ state["error"] = vf.InfraError(str(e))
+
+ if state.get("completion") is None:
+ state["completion"] = []
+ if state.get("stop_condition") is None:
+ if state.get("error") is not None:
+ state["stop_condition"] = "has_error"
+ elif state.get("agent_timed_out", False):
+ state["stop_condition"] = "agent_timeout"
+ else:
+ state["stop_condition"] = "completed"
+ state["is_completed"] = True
+ await self._render_timing(state)
+ error_name = type(state["error"]).__name__ if state.get("error") else None
+ logger.info(
+ f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} "
+ f"sandbox_id={state.get('sandbox_id')} "
+ f"turns={len(state.get('trajectory') or [])} "
+ f"agent_exit_code={state.get('agent_exit_code')} error={error_name}"
+ )
+
+ return state
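The `use_gateway` check at the top of `rollout()` relies on cooperative `super()` dispatch: with `class MyEnv(RolloutGatewayMixin, CliAgentEnv)`, the mixin's `rollout()` runs first and falls through to `CliAgentEnv`'s interception path when the flag is off. A toy sketch of that MRO behavior (synchronous stand-ins, not the real classes):

```python
class BaseEnv:
    # Stand-in for CliAgentEnv's interception-based rollout.
    def rollout(self) -> str:
        return "interception"

class GatewayMixin:
    # Stand-in for RolloutGatewayMixin: dispatch on use_gateway,
    # otherwise defer to the next class in the MRO.
    use_gateway = False

    def rollout(self) -> str:
        if not self.use_gateway:
            return super().rollout()
        return "gateway"

class GatewayEnv(GatewayMixin, BaseEnv):
    use_gateway = True

class LegacyEnv(GatewayMixin, BaseEnv):
    pass

modes = (GatewayEnv().rollout(), LegacyEnv().rollout())
```

Listing the mixin first matters: `BaseEnv` before `GatewayMixin` would make `BaseEnv.rollout` win the MRO lookup and the gateway path would never run.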
+
+ @vf.teardown
+ async def teardown_gateway(self):
+ """Close gateway HTTP client and stop gateway tunnels."""
+ if not self.use_gateway:
+ return
+ await self._http_client.aclose()
+ async with self._tunnel_lock:
+ tunnels = list(self._tunnels.items())
+ self._tunnels = {}
+ for local_addr, tunnel in tunnels:
+ try:
+ tunnel.sync_stop()
+ logger.debug(f"Prime Tunnel stopped local_addr={local_addr}")
+ except Exception as e:
+ logger.warning(
+ f"Error stopping Prime Tunnel local_addr={local_addr}: {e}"
+ )
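`teardown_gateway` snapshots the tunnel map under the lock and swaps in an empty dict, then stops each tunnel outside the lock, so a slow or failing `sync_stop` neither blocks other lock holders nor skips the remaining tunnels. The pattern in miniature (toy classes, not the real tunnel API):

```python
import asyncio

class TunnelRegistry:
    def __init__(self, tunnels: dict) -> None:
        self._tunnels = tunnels
        self._lock = asyncio.Lock()
        self.stopped: list[str] = []

    async def close(self) -> None:
        # Snapshot and clear under the lock; do the slow work outside it.
        async with self._lock:
            tunnels = list(self._tunnels.items())
            self._tunnels = {}
        for local_addr, stop in tunnels:
            try:
                stop()
                self.stopped.append(local_addr)
            except Exception:
                # One failed stop must not prevent stopping the rest.
                pass

def failing_stop() -> None:
    raise RuntimeError("tunnel already gone")

registry = TunnelRegistry({
    "127.0.0.1:8000": lambda: None,
    "127.0.0.1:8001": failing_stop,
    "127.0.0.1:8002": lambda: None,
})
asyncio.run(registry.close())
```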