diff --git a/assets/lab/environments/AGENTS.md b/assets/lab/environments/AGENTS.md index cc7f76d02..98e63b1d7 100644 --- a/assets/lab/environments/AGENTS.md +++ b/assets/lab/environments/AGENTS.md @@ -802,5 +802,6 @@ Newer and more experimental environment classes include: - **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API) - **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin` +- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` takes over and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, it falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` - **`HarborEnv`** — loads Harbor-format agent benchmark tasks -- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. 
The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`. 
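The `use_gateway` toggle on the `RolloutGatewayMixin` described above works through ordinary cooperative MRO dispatch. A minimal, self-contained sketch of that pattern (stand-in classes for illustration only, not the real `vf.RolloutGatewayMixin` / `vf.CliAgentEnv`) could look like:

```python
class CliAgentEnvStub:
    """Stand-in for CliAgentEnv's interception-based rollout path."""

    def rollout(self) -> str:
        return "interception"


class RolloutGatewayMixinStub:
    """Stand-in for the mixin: handles the rollout itself when
    use_gateway is True, otherwise defers down the MRO."""

    use_gateway: bool = True

    def rollout(self) -> str:
        if not self.use_gateway:
            # Fall through to the interception-based path.
            return super().rollout()  # type: ignore[misc]
        return "gateway"


# Mixin comes first in the bases, as in the documented usage pattern.
class MyEnv(RolloutGatewayMixinStub, CliAgentEnvStub):
    def __init__(self, use_gateway: bool = True):
        self.use_gateway = use_gateway


print(MyEnv(use_gateway=True).rollout())   # gateway
print(MyEnv(use_gateway=False).rollout())  # interception
```

The real mixin applies the same guard inside `rollout()`, `start_agent()`, and `build_env_vars()`, so a single class can serve both modes.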
+- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls diff --git a/docs/environments.md b/docs/environments.md index 84f884482..4f8e43121 100644 --- a/docs/environments.md +++ b/docs/environments.md @@ -796,5 +796,6 @@ Newer and more experimental environment classes include: - **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API) - **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin` +- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` takes over and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, it falls through to `CliAgentEnv`'s interception path. 
Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` - **`HarborEnv`** — loads Harbor-format agent benchmark tasks - **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls diff --git a/environments/AGENTS.md b/environments/AGENTS.md index da5323579..632fb0ee0 100644 --- a/environments/AGENTS.md +++ b/environments/AGENTS.md @@ -802,5 +802,6 @@ Newer and more experimental environment classes include: - **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API) - **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin` +- **`RolloutGatewayMixin`** — opt-in mixin for `CliAgentEnv` that replaces its interception-based rollout with a server-side gateway path, where the agent talks directly to the inference server's rollout gateway. Toggle between modes via the `use_gateway` attribute: when `True`, the mixin's `rollout()` takes over and manages gateway registration, tunnel setup, and trajectory fetching; when `False`, it falls through to `CliAgentEnv`'s interception path. Use with `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` - **`HarborEnv`** — loads Harbor-format agent benchmark tasks -- **`RLMEnv`** — implements Recursive Language Models for unbounded context processing. Execution supports both local and sandbox backends via `execution_backend` (`"local"` default, `"sandbox"` to run the REPL inside a Prime Sandbox). 
Context is still filesystem-based: a provided `context_dir` is copied into the working directory, or legacy JSON-serializable `context` data is written to `context.json`/`context.txt`. The RLM scaffolding prompt (filesystem availability note, REPL workflow, tool docs) is injected into the first user message wrapped in `...`, preserving any external system prompt; the model-visible prompt is stored in `state["prompt"]`, while the original input prompt is preserved in `state["raw_prompt"]`. The REPL language is configurable via `repl_language` (default: `bash`); use `repl_language="python"` to retain the Python REPL. Bash mode uses `call_bash_repl` and behaves like a terminal; Python mode uses `call_python_repl`. Sub-LLM and root-tool interception for sandboxes is routed through a Prime Tunnel unless `interception_url` is provided. Tooling can be split via `tools` (shared), `root_tools` (REPL-only), and `sub_tools` (sub-LLM tools). Fixed root tools like `llm_batch` are always present and cannot be overridden. Tool ordering is fixed tools → shared tools → role-specific tools, with per-list deduplication by name. Root tools are callable only inside the REPL; sub-LLM tools use standard tool-calling. When using the sandbox backend, the sandbox and worker are started eagerly during `setup_state`, and package installs are skipped when the package is already importable in the image. Environments can pre-set `state["rlm_fs_root_remote"]` (and optionally `state["rlm_control_dir_remote"]`) before calling `super().setup_state` to point the worker at an existing filesystem path in the sandbox. For further customization, override `get_sandbox_request`, `on_sandbox_ready`, or `customize_worker_script` on `RLMEnv`. 
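The tool-ordering rule in the `RLMEnv` paragraph above (fixed tools → shared tools → role-specific tools, deduplicated by name) can be sketched generically. This is an illustration of the merge rule as described, not the actual `RLMEnv` implementation, and it assumes the first occurrence of a name wins, which is what makes fixed root tools like `llm_batch` non-overridable:

```python
def order_tools(
    fixed: list[str], shared: list[str], role_specific: list[str]
) -> list[str]:
    """Merge tool names in fixed -> shared -> role-specific order,
    keeping the first occurrence of each name (sketch, not RLMEnv's code)."""
    seen: set[str] = set()
    ordered: list[str] = []
    for group in (fixed, shared, role_specific):
        for name in group:
            if name not in seen:
                seen.add(name)
                ordered.append(name)
    return ordered


# "llm_batch" stays first even if a shared tool reuses the name.
print(order_tools(["llm_batch"], ["search", "llm_batch"], ["search", "edit"]))
# ['llm_batch', 'search', 'edit']
```

Under this reading, a `sub_tools` entry can never shadow a fixed root tool, matching the "cannot be overridden" guarantee stated above.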
+- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py new file mode 100644 index 000000000..87bfe7b07 --- /dev/null +++ b/tests/test_rollout_gateway_env.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import httpx +import pytest +from datasets import Dataset + +import verifiers as vf +import verifiers.envs.experimental.rollout_gateway_mixin as rollout_gateway_mixin + +pytestmark = [pytest.mark.integration, pytest.mark.environments] + + +class FakeTunnel: + instances: list["FakeTunnel"] = [] + + def __init__( + self, + local_port: int, + local_addr: str = "127.0.0.1", + log_level: str | None = None, + ): + self.local_port = local_port + self.local_addr = local_addr + self.log_level = log_level + self.url: str | None = None + self.start_calls = 0 + self.stop_calls = 0 + FakeTunnel.instances.append(self) + + async def start(self) -> str: + self.start_calls += 1 + self.url = "https://unit-test.tunnel.prime.ai" + return self.url + + def sync_stop(self) -> None: + self.stop_calls += 1 + + +class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv): + def __init__(self, *, gateway_port=8000, use_gateway=True, **kwargs): + self.use_gateway = use_gateway + super().__init__(**kwargs) + if use_gateway: + self.init_gateway( + gateway_port=gateway_port, timeout_seconds=self.timeout_seconds + ) + + async def post_rollout(self, state: vf.State): + state["reward"] = 1.0 + state["test_output"] = "ok" + + +def _build_gateway_transport(tracker: dict) -> httpx.MockTransport: + trajectory = [ + { + "prompt": [{"role": "user", "content": "Hello"}], + "completion": [{"role": "assistant", "content": "reply-1"}], + "tokens": { + "prompt_ids": [1, 2], + "prompt_mask": 
[0, 0], + "completion_ids": [3], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "overlong_prompt": False, + "is_truncated": False, + }, + "reward": None, + "advantage": None, + "is_truncated": False, + "trajectory_id": "traj-1", + "extras": {}, + }, + { + "prompt": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "reply-1"}, + {"role": "user", "content": "Turn 2"}, + ], + "completion": [{"role": "assistant", "content": "reply-2"}], + "tokens": { + "prompt_ids": [1, 2, 3, 4], + "prompt_mask": [0, 0, 0, 0], + "completion_ids": [5], + "completion_mask": [1], + "completion_logprobs": [-0.2], + "overlong_prompt": False, + "is_truncated": False, + }, + "reward": None, + "advantage": None, + "is_truncated": False, + "trajectory_id": "traj-1", + "extras": {}, + }, + { + "prompt": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "reply-1"}, + {"role": "user", "content": "Turn 2"}, + {"role": "assistant", "content": "reply-2"}, + {"role": "user", "content": "Turn 3"}, + ], + "completion": [{"role": "assistant", "content": "reply-3"}], + "tokens": { + "prompt_ids": [1, 2, 3, 4, 5, 6], + "prompt_mask": [0, 0, 0, 0, 0, 0], + "completion_ids": [7], + "completion_mask": [1], + "completion_logprobs": [-0.3], + "overlong_prompt": False, + "is_truncated": False, + }, + "reward": None, + "advantage": None, + "is_truncated": False, + "trajectory_id": "traj-1", + "extras": {}, + }, + ] + + def _handler(request: httpx.Request) -> httpx.Response: + tracker["hosts"].add(request.url.host) + tracker["paths"].append(request.url.path) + path = request.url.path + + if request.method == "POST" and path.endswith("/register"): + payload = json.loads(request.content.decode("utf-8")) + tracker["register_payload"] = payload + tracker["rollout_id"] = path.split("/")[-2] + return httpx.Response(status_code=200, json={"status": "active"}) + + if request.method == "POST" and path.endswith("/unregister"): + tracker["unregister_calls"] 
+= 1 + return httpx.Response(status_code=200, json={"status": "active"}) + + if request.method == "GET" and path.endswith("/trajectory"): + tracker["trajectory_calls"] += 1 + return httpx.Response( + status_code=200, + json={ + "rollout_id": tracker["rollout_id"], + "status": "completed", + "num_turns": 3, + "model": "Qwen/Qwen3-0.6B", + "prompt": trajectory[0]["prompt"], + "completion": [ + {"role": "assistant", "content": "reply-1"}, + {"role": "user", "content": "Turn 2"}, + {"role": "assistant", "content": "reply-2"}, + {"role": "user", "content": "Turn 3"}, + {"role": "assistant", "content": "reply-3"}, + ], + "is_truncated": False, + "trajectory": trajectory, + }, + ) + + return httpx.Response(status_code=404, json={"error": f"Unhandled path {path}"}) + + return httpx.MockTransport(_handler) + + +@pytest.mark.asyncio +async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch): + FakeTunnel.instances.clear() + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + tracker = { + "paths": [], + "hosts": set(), + "register_payload": None, + "rollout_id": None, + "trajectory_calls": 0, + "unregister_calls": 0, + } + transport = _build_gateway_transport(tracker) + real_async_client = httpx.AsyncClient + client = vf.OpenAIChatCompletionsClient( + vf.ClientConfig( + api_key_var="UNIT_TEST_API_KEY", + api_base_url="http://gateway.internal:8000/v1/", + ) + ) + + def _client_factory(*args, **kwargs): + kwargs["transport"] = transport + return real_async_client(*args, **kwargs) + + monkeypatch.setattr(rollout_gateway_mixin.httpx, "AsyncClient", _client_factory) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + max_turns=10, + timeout_seconds=30.0, + ) + + env.sandbox_client.create = AsyncMock(return_value=SimpleNamespace(id="sb-123")) 
+ env.sandbox_client.wait_for_creation = AsyncMock(return_value=None) + env.sandbox_client.start_background_job = AsyncMock( + return_value=SimpleNamespace(id="job-1") + ) + env.sandbox_client.get_background_job = AsyncMock( + return_value=SimpleNamespace( + completed=True, + exit_code=0, + stdout="agent ok", + stderr="", + ) + ) + env.sandbox_client.delete = AsyncMock(return_value=None) + + rollout_input = { + "prompt": [{"role": "user", "content": "Hello"}], + "answer": "", + "example_id": 0, + "task": "gateway-test", + } + state = await env.rollout( + input=rollout_input, + client=client, + model="Qwen/Qwen3-0.6B", + sampling_args={"temperature": 0.7, "max_completion_tokens": 64}, + ) + + assert state.get("error") is None + assert state["gateway_url"] == "http://gateway.internal:8000" + assert state["tunnel_url"] == "https://unit-test.tunnel.prime.ai" + assert state["rollout_base_url"].startswith( + "https://unit-test.tunnel.prime.ai/v1/rollouts/" + ) + assert len(state["trajectory"]) == 3 + assert state["prompt"] == [{"role": "user", "content": "Hello"}] + assert state["completion"][-1]["content"] == "reply-3" + assert state["reward"] == 1.0 + + create_request = env.sandbox_client.create.await_args.args[0] + assert ( + create_request.environment_vars["OPENAI_BASE_URL"] + == f"https://unit-test.tunnel.prime.ai/v1/rollouts/{state['rollout_id']}" + ) + assert create_request.environment_vars["OPENAI_MODEL"] == "Qwen/Qwen3-0.6B" + + assert tracker["register_payload"]["max_turns"] == 10 + assert tracker["register_payload"]["sampling_params"]["temperature"] == 0.7 + assert tracker["register_payload"]["sampling_params"]["max_completion_tokens"] == 64 + assert tracker["trajectory_calls"] == 1 + assert tracker["unregister_calls"] == 1 + assert tracker["hosts"] == {"gateway.internal"} + + assert len(FakeTunnel.instances) == 1 + tunnel = FakeTunnel.instances[0] + assert tunnel.local_port == 8000 + assert tunnel.local_addr == "gateway.internal" + assert tunnel.start_calls 
== 1 + assert await env.get_gateway_tunnel_url() == "https://unit-test.tunnel.prime.ai" + assert tunnel.start_calls == 1 + + await env.teardown_gateway() + assert tunnel.stop_calls == 1 + + +@pytest.mark.asyncio +async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): + FakeTunnel.instances.clear() + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + ) + + url_a = await env.get_gateway_tunnel_url(local_addr="10.20.0.58") + url_b = await env.get_gateway_tunnel_url(local_addr="10.20.0.59") + url_a_reuse = await env.get_gateway_tunnel_url(local_addr="10.20.0.58") + + assert url_a == "https://unit-test.tunnel.prime.ai" + assert url_b == "https://unit-test.tunnel.prime.ai" + assert url_a_reuse == url_a + + assert len(FakeTunnel.instances) == 2 + assert {t.local_addr for t in FakeTunnel.instances} == {"10.20.0.58", "10.20.0.59"} + assert sum(t.start_calls for t in FakeTunnel.instances) == 2 + + with pytest.raises( + ValueError, match="local_addr is required when multiple tunnels are active" + ): + await env.get_gateway_tunnel_url() + + await env.teardown_gateway() + assert sum(t.stop_calls for t in FakeTunnel.instances) == 2 + + +@pytest.mark.asyncio +async def test_use_gateway_false_initializes_interception(monkeypatch): + """With use_gateway=False, interception server is created and gateway is not.""" + FakeTunnel.instances.clear() + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + use_gateway=False, + timeout_seconds=30.0, + ) 
+ + # Interception server should be initialized + assert env._interception_server is not None + assert env._tunnel is None + assert env._tunnel_lock is not None + + # Gateway attributes should not exist (init_gateway was never called) + assert not hasattr(env, "_http_client") + assert not hasattr(env, "_tunnels") + + # Teardowns should be safe no-ops for the inactive path + await env.teardown_gateway() # early return via use_gateway=False + await env.teardown_resources() # stops interception (which was never started) + + assert len(FakeTunnel.instances) == 0 diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 2cd3480d8..e7fa269ea 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -67,6 +67,7 @@ "ReasoningGymEnv", "GymEnv", "CliAgentEnv", + "RolloutGatewayMixin", "HarborEnv", "MCPEnv", "BrowserEnv", @@ -121,6 +122,7 @@ "PythonEnv": "verifiers.envs.python_env:PythonEnv", "GymEnv": "verifiers.envs.experimental.gym_env:GymEnv", "CliAgentEnv": "verifiers.envs.experimental.cli_agent_env:CliAgentEnv", + "RolloutGatewayMixin": "verifiers.envs.experimental.rollout_gateway_mixin:RolloutGatewayMixin", "HarborEnv": "verifiers.envs.experimental.harbor_env:HarborEnv", "MCPEnv": "verifiers.envs.experimental.mcp_env:MCPEnv", "ReasoningGymEnv": "verifiers.envs.integrations.reasoninggym_env:ReasoningGymEnv", @@ -160,6 +162,7 @@ def __getattr__(name: str): from typing import Any from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401 + from .envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin # noqa: F401 from .envs.experimental.gym_env import GymEnv # noqa: F401 from .envs.experimental.harbor_env import HarborEnv # noqa: F401 from .envs.experimental.mcp_env import MCPEnv # noqa: F401 diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 40ef74cbe..96126b506 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -622,6 +622,7 @@ async def init_state( state["tool_defs"] = 
self._normalize_tool_defs(resolved_tool_defs) or [] state["trajectory"] = [] + state["completion"] = None self._get_usage_tracker(state, create_if_missing=True) state["trajectory_id"] = uuid.uuid4().hex state["reward"] = None diff --git a/verifiers/envs/experimental/__init__.py b/verifiers/envs/experimental/__init__.py index 034cb817f..ff81549e0 100644 --- a/verifiers/envs/experimental/__init__.py +++ b/verifiers/envs/experimental/__init__.py @@ -1,3 +1,4 @@ +from verifiers.envs.experimental.rollout_gateway_mixin import RolloutGatewayMixin from verifiers.envs.experimental.sandbox_mixin import SandboxMixin -__all__ = ["SandboxMixin"] +__all__ = ["RolloutGatewayMixin", "SandboxMixin"] diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 5cfb9e42e..f4f91093b 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -84,8 +84,6 @@ def __init__( ) self.run_command = run_command self.poll_interval = poll_interval - self.interception_port = interception_port - self.interception_url = interception_url self.timeout_seconds = timeout_seconds self.docker_image = docker_image self.start_command = start_command @@ -99,7 +97,20 @@ def __init__( self.advanced_configs = advanced_configs self.labels = labels - # Tunnel and interception server + self._interception_server: InterceptionServer | None = None + self._tunnel: Tunnel | None = None + self._tunnel_lock = asyncio.Lock() + + self.init_interception(interception_port, interception_url) + + def init_interception( + self, + interception_port: int = 8765, + interception_url: str | None = None, + ): + """Initialize interception server and tunnel resources. 
Call from __init__.""" + self.interception_port = interception_port + self.interception_url = interception_url self._tunnel: Tunnel | None = None self._tunnel_lock = asyncio.Lock() self._interception_server = InterceptionServer(port=interception_port) @@ -420,7 +431,8 @@ async def teardown_resources(self): logger.warning(f"Error stopping Prime Tunnel: {e}") finally: self._tunnel = None - await self._interception_server.stop() + if self._interception_server is not None: + await self._interception_server.stop() @vf.cleanup async def cleanup_interception_context(self, state: State): @@ -437,7 +449,7 @@ async def cleanup_interception_context(self, state: State): state.pop("background_job", None) rollout_id = state.get("rollout_id") - if rollout_id: + if rollout_id and self._interception_server is not None: self._interception_server.unregister_rollout(rollout_id) @vf.stop diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py new file mode 100644 index 000000000..833b40b30 --- /dev/null +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -0,0 +1,397 @@ +import asyncio +import logging +import time +import uuid +from typing import Any +from urllib.parse import urlparse + +import httpx +from prime_sandboxes import CreateSandboxRequest +from prime_tunnel import Tunnel + +import verifiers as vf +from verifiers.clients import Client +from verifiers.types import RolloutInput, SamplingArgs, State + +logger = logging.getLogger(__name__) + + +def _tail_text(value: Any, max_chars: int = 1200) -> str: + if value is None: + return "" + text = str(value) + return text[-max_chars:] if len(text) > max_chars else text + + +class RolloutGatewayMixin: + """Opt-in mixin that replaces CliAgentEnv's interception-based rollout + with a server-side gateway path. Toggle via ``use_gateway`` attribute. + + When gateway is active, the agent talks directly to prime-rl's rollout + gateway through a prime tunnel. 
The env only manages sandbox lifecycle. + When inactive, falls through to CliAgentEnv's interception path. + + MRO: ``MyEnv → RolloutGatewayMixin → CliAgentEnv → SandboxMixin → MultiTurnEnv → Environment`` + """ + + use_gateway: bool = True + + def init_interception(self, *args, **kwargs): + if not self.use_gateway: + super().init_interception(*args, **kwargs) # ty: ignore[unresolved-attribute] + + def init_gateway( + self, + gateway_port: int = 8000, + timeout_seconds: float = 3600.0, + ): + """Initialize gateway resources. Call in __init__ when use_gateway=True.""" + self.gateway_port = gateway_port + self._gateway_timeout_seconds = timeout_seconds + self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds)) + self._tunnels: dict[str, Tunnel] = {} + self._tunnel_lock = asyncio.Lock() + + def _resolve_gateway_url(self, state: State) -> str: + client = getattr(state["client"], "client", state["client"]) + gateway_url = str(client.base_url).rstrip("/") + if gateway_url.endswith("/v1"): + gateway_url = gateway_url[:-3] + return gateway_url + + def _resolve_tunnel_local_addr(self, state: State) -> str: + gateway_url = state["gateway_url"] + parsed = urlparse(gateway_url) + host = parsed.hostname + if host is None: + raise ValueError(f"Invalid gateway URL; missing hostname: {gateway_url}") + return host + + def _rollout_endpoint(self, state: State, suffix: str) -> str: + return f"{state['gateway_url']}/v1/rollouts/{state['rollout_id']}/{suffix.lstrip('/')}" + + async def _gateway_post( + self, + state: State, + suffix: str, + payload: dict[str, Any] | None = None, + ) -> dict[str, Any]: + response = await self._http_client.post( + self._rollout_endpoint(state, suffix), + json=payload, + ) + response.raise_for_status() + if not response.content: + return {} + return response.json() + + async def _gateway_get(self, state: State, suffix: str) -> dict[str, Any]: + response = await self._http_client.get(self._rollout_endpoint(state, suffix)) + 
response.raise_for_status() + return response.json() + + async def build_env_vars(self, state: State) -> dict[str, str]: + """Override to set OPENAI_BASE_URL from rollout_base_url in gateway mode.""" + if not self.use_gateway: + return await super().build_env_vars(state) # ty: ignore[unresolved-attribute] + env_vars = dict(self.environment_vars) if self.environment_vars else {} # ty: ignore[unresolved-attribute] + env_vars["OPENAI_BASE_URL"] = state["rollout_base_url"] + env_vars.setdefault("OPENAI_TIMEOUT", "600") + env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "600") + env_vars.setdefault("HTTPX_TIMEOUT", "600") + model = state.get("model") + if model: + env_vars["OPENAI_MODEL"] = model + return env_vars + + async def register_rollout(self, state: State) -> None: + sampling_params = state.get("sampling_args") or {} + payload = { + "model": state["model"], + "sampling_params": sampling_params, + "max_turns": self.max_turns, # ty: ignore[unresolved-attribute] + "max_seq_len": self.max_seq_len, # ty: ignore[unresolved-attribute] + } + await self._gateway_post(state, "register", payload) + + async def unregister_rollout(self, state: State) -> None: + await self._gateway_post(state, "unregister") + + async def fetch_trajectory(self, state: State) -> None: + data = await self._gateway_get(state, "trajectory") + state["trajectory"] = data.get("trajectory", []) + state["prompt"] = data.get("prompt") + state["completion"] = data.get("completion") + state["is_truncated"] = bool( + data.get("is_truncated", state.get("is_truncated", False)) + ) + + async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: + """Get gateway tunnel URL, starting the tunnel if needed.""" + async with self._tunnel_lock: + if local_addr is None: + if len(self._tunnels) == 1: + tunnel = next(iter(self._tunnels.values())) + assert tunnel.url is not None, "Tunnel started but URL is None" + return tunnel.url + if len(self._tunnels) == 0: + raise ValueError("local_addr is required 
when starting tunnel")
+            raise ValueError(
+                "local_addr is required when multiple tunnels are active"
+            )
+
+        tunnel = self._tunnels.get(local_addr)
+        if tunnel is None:
+            tunnel = Tunnel(
+                local_port=self.gateway_port,
+                local_addr=local_addr,
+                log_level="debug" if logger.isEnabledFor(logging.DEBUG) else "info",
+            )
+            url = await tunnel.start()
+            self._tunnels[local_addr] = tunnel
+            logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}")
+            return url
+
+        assert tunnel.url is not None, "Tunnel started but URL is None"
+        return tunnel.url
+
+    async def start_agent(self, state: State) -> None:
+        """Start the agent command. In gateway mode, skip background completion task."""
+        if not self.use_gateway:
+            return await super().start_agent(state)  # ty: ignore[unresolved-attribute]
+        sandbox_id = state["sandbox_id"]
+        background_job = await self.sandbox_client.start_background_job(  # ty: ignore[unresolved-attribute]
+            sandbox_id,
+            self.run_command,  # ty: ignore[unresolved-attribute]
+        )
+        state["background_job"] = background_job
+        state["agent_start_time"] = time.time()
+        state["agent_completed"] = False
+
+    async def poll_job_completion(
+        self,
+        state: State,
+        sandbox_id: str,
+        background_job,
+    ) -> None:
+        """Poll until background job completes, capturing output."""
+        if not self.use_gateway:
+            return await super().poll_job_completion(state, sandbox_id, background_job)  # ty: ignore[unresolved-attribute]
+        while True:
+            status = await self.sandbox_client.get_background_job(  # ty: ignore[unresolved-attribute]
+                sandbox_id,
+                background_job,
+            )
+            if status.completed:
+                state["agent_exit_code"] = status.exit_code
+                state["agent_stdout"] = status.stdout
+                state["agent_stderr"] = status.stderr
+                if status.exit_code not in (None, 0):
+                    logger.warning(
+                        f"rollout={state.get('rollout_id')} sandbox={sandbox_id} "
+                        f"stage=agent_completed exit_code={status.exit_code} "
+                        f"stdout_tail={_tail_text(status.stdout)!r} "
+                        f"stderr_tail={_tail_text(status.stderr)!r}"
+                    )
+                else:
+                    logger.debug(
+                        f"rollout={state.get('rollout_id')} sandbox={sandbox_id} "
+                        f"stage=agent_completed exit_code={status.exit_code}"
+                    )
+                return
+            await asyncio.sleep(self.poll_interval)  # ty: ignore[unresolved-attribute]
+
+    async def wait_for_agent_completion(self, state: State) -> None:
+        """Poll for agent completion using background job API."""
+        sandbox_id = state.get("sandbox_id")
+        background_job = state.get("background_job")
+        if not sandbox_id or not background_job:
+            state["agent_completed"] = True
+            return
+
+        try:
+            await asyncio.wait_for(
+                self.poll_job_completion(state, sandbox_id, background_job),
+                timeout=self._gateway_timeout_seconds,
+            )
+        except asyncio.TimeoutError:
+            logger.warning(
+                f"rollout={state.get('rollout_id')} sandbox={state.get('sandbox_id')} "
+                f"stage=wait_for_agent_completion timed out after {self._gateway_timeout_seconds:.1f}s"
+            )
+            state["agent_timed_out"] = True
+        finally:
+            state["agent_completed"] = True
+
+    async def _render_timing(self, state: State) -> None:
+        start_time = state["timing"]["start_time"]
+        end_time = time.time()
+        generation_ms = (end_time - start_time) * 1000
+        state["timing"]["generation_ms"] = generation_ms
+        state["timing"]["total_ms"] = generation_ms
+
+    async def rollout(
+        self,
+        input: RolloutInput,
+        client: Client,
+        model: str,
+        sampling_args: SamplingArgs | None = None,
+    ) -> State:
+        if not self.use_gateway:
+            return await super().rollout(input, client, model, sampling_args)  # ty: ignore[unresolved-attribute]
+
+        state = await self.init_state(input, client, model, sampling_args)  # ty: ignore[unresolved-attribute]
+        state["rollout_id"] = f"rollout_{uuid.uuid4().hex[:8]}"
+        state["gateway_url"] = self._resolve_gateway_url(state)
+        rollout_id = state["rollout_id"]
+        info = state.get("info") or {}
+        logger.info(
+            f"rollout={rollout_id} stage=start model={model} "
+            f"example_id={info.get('instance_id') or info.get('example_id')} "
+            f"repo={info.get('repo_name')}"
+        )
+
+        rollout_registered = False
+        try:
+            await self.register_rollout(state)
+            rollout_registered = True
+            logger.debug(f"rollout={rollout_id} stage=register_rollout ok")
+
+            tunnel_local_addr = self._resolve_tunnel_local_addr(state)
+            state["tunnel_local_addr"] = tunnel_local_addr
+            logger.debug(
+                f"rollout={rollout_id} stage=resolve_tunnel_local_addr addr={tunnel_local_addr}"
+            )
+
+            tunnel_url = await self.get_gateway_tunnel_url(local_addr=tunnel_local_addr)
+            state["tunnel_url"] = tunnel_url
+            state["rollout_base_url"] = (
+                f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}"
+            )
+            logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}")
+
+            env_vars = await self.build_env_vars(state)
+            docker_image = await self.get_docker_image(state)  # ty: ignore[unresolved-attribute]
+            sandbox_request = CreateSandboxRequest(
+                name=state["rollout_id"],
+                docker_image=docker_image,
+                start_command=self.start_command,  # ty: ignore[unresolved-attribute]
+                cpu_cores=self.cpu_cores,  # ty: ignore[unresolved-attribute]
+                memory_gb=self.memory_gb,  # ty: ignore[unresolved-attribute]
+                disk_size_gb=self.disk_size_gb,  # ty: ignore[unresolved-attribute]
+                gpu_count=self.gpu_count,  # ty: ignore[unresolved-attribute]
+                timeout_minutes=self.timeout_minutes,  # ty: ignore[unresolved-attribute]
+                environment_vars=env_vars,
+                team_id=self.team_id,  # ty: ignore[unresolved-attribute]
+                advanced_configs=self.advanced_configs,  # ty: ignore[unresolved-attribute]
+                labels=self.labels or [],  # ty: ignore[unresolved-attribute]
+            )
+            logger.debug(
+                f"Creating sandbox with OPENAI_BASE_URL={env_vars.get('OPENAI_BASE_URL')} "
+                f"docker_image={docker_image}"
+            )
+            await self.create_sandbox(state, sandbox_request)  # ty: ignore[unresolved-attribute]
+            logger.info(
+                f"rollout={rollout_id} stage=create_sandbox ok "
+                f"sandbox_id={state.get('sandbox_id')} docker_image={docker_image}"
+            )
+
+            await self.start_agent(state)
+            logger.debug(
+                f"rollout={rollout_id} stage=start_agent ok "
+                f"sandbox_id={state.get('sandbox_id')}"
+            )
+            await self.wait_for_agent_completion(state)
+            logger.debug(
+                f"rollout={rollout_id} stage=wait_for_agent_completion ok "
+                f"exit_code={state.get('agent_exit_code')}"
+            )
+            await self.fetch_trajectory(state)
+            trajectory = state.get("trajectory") or []
+            logger.info(
+                f"rollout={rollout_id} stage=fetch_trajectory ok "
+                f"turns={len(trajectory)} truncated={state.get('is_truncated', False)}"
+            )
+            if len(trajectory) == 0:
+                logger.warning(
+                    f"rollout={rollout_id} stage=fetch_trajectory empty_trajectory "
+                    f"agent_exit_code={state.get('agent_exit_code')} "
+                    f"stdout_tail={_tail_text(state.get('agent_stdout'))!r} "
+                    f"stderr_tail={_tail_text(state.get('agent_stderr'))!r}"
+                )
+        except asyncio.CancelledError:
+            if rollout_registered:
+                try:
+                    await self._gateway_post(state, "cancel")
+                except Exception:
+                    pass
+            raise
+        except vf.Error as e:
+            state["error"] = e
+            logger.exception(
+                f"rollout={rollout_id} stage={type(e).__name__} vf_error={e}"
+            )
+        except Exception as e:
+            state["error"] = vf.InfraError(str(e))
+            logger.exception(
+                f"rollout={rollout_id} stage={type(e).__name__} unhandled_error={e}"
+            )
+        finally:
+            if rollout_registered:
+                try:
+                    await self.unregister_rollout(state)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to unregister rollout {state['rollout_id']}: {e}"
+                    )
+                    if state.get("error") is None:
+                        state["error"] = vf.InfraError(str(e))
+
+            if state.get("sandbox_id"):
+                try:
+                    await self._cleanup(state)  # ty: ignore[unresolved-attribute]
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to destroy sandbox {state.get('sandbox_id')}: {e}"
+                    )
+                    if state.get("error") is None:
+                        state["error"] = vf.InfraError(str(e))
+
+            if state.get("completion") is None:
+                state["completion"] = []
+            if state.get("stop_condition") is None:
+                if state.get("error") is not None:
+                    state["stop_condition"] = "has_error"
+                elif state.get("agent_timed_out", False):
+                    state["stop_condition"] = "agent_timeout"
+                else:
+                    state["stop_condition"] = "completed"
+            state["is_completed"] = True
+            await self._render_timing(state)
+            error_name = type(state["error"]).__name__ if state.get("error") else None
+            logger.info(
+                f"rollout={rollout_id} stage=finish stop={state.get('stop_condition')} "
+                f"sandbox_id={state.get('sandbox_id')} "
+                f"turns={len(state.get('trajectory', []))} "
+                f"agent_exit_code={state.get('agent_exit_code')} error={error_name}"
+            )
+
+        return state
+
+    @vf.teardown
+    async def teardown_gateway(self):
+        """Close gateway HTTP client and stop gateway tunnels."""
+        if not self.use_gateway:
+            return
+        await self._http_client.aclose()
+        async with self._tunnel_lock:
+            tunnels = list(self._tunnels.items())
+            self._tunnels = {}
+        for local_addr, tunnel in tunnels:
+            try:
+                tunnel.sync_stop()
+                logger.debug(f"Prime Tunnel stopped local_addr={local_addr}")
+            except Exception as e:
+                logger.warning(
+                    f"Error stopping Prime Tunnel local_addr={local_addr}: {e}"
+                )
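Every method in this mixin opens with the same `use_gateway` check before falling back to `super()` — a cooperative-mixin dispatch pattern that relies on Python's MRO placing the mixin before `CliAgentEnv` in the subclass's bases. A minimal, self-contained sketch of that pattern (stub classes stand in for the real `vf.CliAgentEnv` and `vf.RolloutGatewayMixin`; all names here are illustrative, not the real API):

```python
import asyncio


class BaseEnv:
    """Stands in for the interception-based CliAgentEnv rollout path."""

    async def rollout(self) -> str:
        return "interception"


class GatewayMixin:
    """Stands in for RolloutGatewayMixin: gateway path when use_gateway is True."""

    use_gateway: bool = True

    async def rollout(self) -> str:
        if not self.use_gateway:
            # Fall through to the next class in the MRO (the interception path).
            return await super().rollout()  # type: ignore[misc]
        return "gateway"


# Mixin must come first so its rollout() is found before BaseEnv's.
class MyEnv(GatewayMixin, BaseEnv):
    pass


async def main() -> tuple[str, str]:
    env = MyEnv()
    gateway_result = await env.rollout()   # use_gateway defaults to True
    env.use_gateway = False
    fallback_result = await env.rollout()  # falls back to BaseEnv.rollout
    return gateway_result, fallback_result


print(asyncio.run(main()))  # ('gateway', 'interception')
```

This mirrors the `class MyEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):` ordering documented in AGENTS.md: reversing the bases would shadow the mixin's `rollout()` entirely.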