diff --git a/docs/environments.md b/docs/environments.md index 4f8e43121..f02987c9a 100644 --- a/docs/environments.md +++ b/docs/environments.md @@ -592,7 +592,7 @@ Verifiers defines a hierarchy of error types under `vf.Error`: - `vf.ModelError` — errors from model interactions (e.g., `vf.EmptyModelResponseError`) - `vf.OverlongPromptError` — prompt exceeds model context length - `vf.ToolError` — tool-related errors (`vf.ToolParseError`, `vf.ToolCallError`) -- `vf.InfraError` — infrastructure errors (e.g., `vf.SandboxError`) +- `vf.InfraError` — infrastructure errors (e.g., `vf.SandboxError`, `vf.TunnelError`) When a `vf.Error` is raised during a rollout, it is automatically caught and stored in `state["error"]`, triggering the built-in `has_error` stop condition at the next check. This allows rollouts to terminate gracefully rather than crashing. diff --git a/pyproject.toml b/pyproject.toml index cc23cc1d7..9982c9c70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,8 +37,8 @@ dependencies = [ "nest-asyncio>=1.6.0", # for jupyter notebooks "openai>=1.108.1", "openai-agents>=0.0.7", - "prime-tunnel>=0.1.1", - "prime-sandboxes>=0.2.14", + "prime-tunnel>=0.1.4", + "prime-sandboxes>=0.2.16", "pydantic>=2.11.9", "requests", "rich", diff --git a/tests/test_rollout_gateway_env.py b/tests/test_rollout_gateway_env.py index 87bfe7b07..97223b2b7 100644 --- a/tests/test_rollout_gateway_env.py +++ b/tests/test_rollout_gateway_env.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from types import SimpleNamespace from unittest.mock import AsyncMock @@ -16,6 +17,7 @@ class FakeTunnel: instances: list["FakeTunnel"] = [] + _next_id: int = 0 def __init__( self, @@ -27,17 +29,35 @@ def __init__( self.local_addr = local_addr self.log_level = log_level self.url: str | None = None + FakeTunnel._next_id += 1 + self.tunnel_id: str = f"fake-tunnel-{FakeTunnel._next_id}" + self._is_running: bool = True + self._recent_output: list[str] = ["frpc log line 1", "frpc log line 2"] self.start_calls = 0 self.stop_calls = 0 FakeTunnel.instances.append(self) + @property + def is_running(self) -> bool: + return self._is_running + + @property + def recent_output(self) -> list[str]: + return list(self._recent_output) + async def start(self) -> str: self.start_calls += 1 + self._is_running = True self.url = "https://unit-test.tunnel.prime.ai" return self.url + async def stop(self) -> None: + self.stop_calls += 1 + self._is_running = False + def sync_stop(self) -> None: self.stop_calls += 1 + self._is_running = False class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv): @@ -167,6 +187,7 @@ def _handler(request: httpx.Request) -> httpx.Response: @pytest.mark.asyncio async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch): FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) tracker = { @@ -276,6 +297,7 @@ def _client_factory(*args, **kwargs): @pytest.mark.asyncio async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) dataset = Dataset.from_dict( @@ -317,6 +339,7 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch): async def test_use_gateway_false_initializes_interception(monkeypatch): """With use_gateway=False, interception server is created and gateway is not.""" FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) dataset = Dataset.from_dict( @@ -348,3 +371,180 @@ async def test_use_gateway_false_initializes_interception(monkeypatch): await env.teardown_resources() # stops interception (which was never started) assert len(FakeTunnel.instances) == 0 + + +@pytest.mark.asyncio +async def test_dead_tunnel_recreated_on_get_gateway_tunnel_url(monkeypatch): + """Dead tunnel is stopped and replaced when get_gateway_tunnel_url is called.""" + FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + ) + + # Start a tunnel + await env.get_gateway_tunnel_url(local_addr="10.0.0.1") + assert len(FakeTunnel.instances) == 1 + original_tunnel = FakeTunnel.instances[0] + assert original_tunnel.start_calls == 1 + + # Kill the tunnel + original_tunnel._is_running = False + + # Requesting URL again should recreate + url2 = await env.get_gateway_tunnel_url(local_addr="10.0.0.1") + assert url2 == "https://unit-test.tunnel.prime.ai" + assert len(FakeTunnel.instances) == 2 + assert original_tunnel.stop_calls == 1 + new_tunnel = FakeTunnel.instances[1] + assert new_tunnel.start_calls == 1 + assert new_tunnel._is_running is True + + await env.teardown_gateway() + + +@pytest.mark.asyncio +async def test_poll_job_completion_raises_tunnel_error_on_dead_tunnel(monkeypatch): + """poll_job_completion raises TunnelError when the tunnel dies mid-rollout.""" + FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + ) + + # Start a tunnel + await env.get_gateway_tunnel_url(local_addr="10.0.0.1") + tunnel = FakeTunnel.instances[0] + + # Mock sandbox_client.get_background_job to never complete + env.sandbox_client = SimpleNamespace( + get_background_job=AsyncMock(return_value=SimpleNamespace(completed=False)), + ) + + state = { + "rollout_id": "rollout_test123", + "tunnel_local_addr": "10.0.0.1", + } + background_job = SimpleNamespace(id="job-1") + + # Kill tunnel after a brief delay + async def kill_tunnel(): + await asyncio.sleep(0.05) + tunnel._is_running = False + + asyncio.create_task(kill_tunnel()) + + with pytest.raises(vf.TunnelError, match="Tunnel process died"): + await env.poll_job_completion(state, "sb-123", background_job) + + await env.teardown_gateway() + + +@pytest.mark.asyncio +async def test_health_monitor_restarts_dead_tunnels(monkeypatch): + """Background health monitor detects and restarts dead tunnels.""" + FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + ) + + # Start a tunnel (this also starts the health monitor) + await env.get_gateway_tunnel_url(local_addr="10.0.0.1") + assert env._tunnel_monitor_task is not None + assert not env._tunnel_monitor_task.done() + + original_tunnel = FakeTunnel.instances[0] + original_tunnel._is_running = False + + # Run the health monitor with a short interval + # Cancel the default one and start one with a short interval + env._tunnel_monitor_task.cancel() + try: + await env._tunnel_monitor_task + except asyncio.CancelledError: + pass + + env._tunnel_monitor_task = asyncio.create_task( + env._tunnel_health_monitor(interval=0.05) + ) + + # Wait for the monitor to detect and restart + await asyncio.sleep(0.2) + + assert len(FakeTunnel.instances) == 2 + assert original_tunnel.stop_calls == 1 + new_tunnel = FakeTunnel.instances[1] + assert new_tunnel.start_calls == 1 + assert new_tunnel._is_running is True + + await env.teardown_gateway() + + +@pytest.mark.asyncio +async def test_teardown_gateway_cancels_health_monitor(monkeypatch): + """teardown_gateway cancels the health monitor task.""" + FakeTunnel.instances.clear() + FakeTunnel._next_id = 0 + monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel) + + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": "Hello"}]], + "answer": [""], + "example_id": [0], + } + ) + env = GatewayCliAgentEnv( + run_command="echo run-agent", + dataset=dataset, + rubric=vf.Rubric(), + gateway_port=8000, + ) + + # Start a tunnel to create the health monitor + await env.get_gateway_tunnel_url(local_addr="10.0.0.1") + monitor_task = env._tunnel_monitor_task + assert monitor_task is not None + assert not monitor_task.done() + + await env.teardown_gateway() + + assert monitor_task.done() + assert env._tunnel_monitor_task is None diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 67c443437..1d88f304f 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -47,7 +47,7 @@ def __init__( interception_url: str | None = None, max_turns: int = -1, timeout_seconds: float = 3600.0, - poll_interval: float = 2.0, + poll_interval: float = 5.0, docker_image: str = "python:3.11-slim", start_command: str = "tail -f /dev/null", cpu_cores: int = 1, @@ -115,16 +115,19 @@ def init_interception( self._tunnel_lock = asyncio.Lock() self._interception_server = InterceptionServer(port=interception_port) - @property - def _server(self) -> InterceptionServer: - assert self._interception_server is not None - return self._interception_server - async def get_tunnel_url(self) -> str: - """Get tunnel URL, starting the tunnel if needed.""" + """Get tunnel URL, starting the tunnel if needed. Recreates dead tunnels.""" async with self._tunnel_lock: + if self._tunnel is not None and not self._tunnel.is_running: + frpc_output = "\n".join(self._tunnel.recent_output) + logger.warning( + f"Tunnel process died, recreating. frpc output:\n{frpc_output}" + ) + self._tunnel.sync_stop() + self._tunnel = None + if self._tunnel is None: - port = self._server.port + port = self._interception_server.port # ty: ignore[unresolved-attribute] if logger.isEnabledFor(logging.DEBUG): self._tunnel = Tunnel( local_port=port, @@ -146,7 +149,7 @@ async def setup_state(self, state: State) -> State: rollout_id = f"rollout_{uuid.uuid4().hex[:8]}" state["rollout_id"] = rollout_id - await self._server.start() + await self._interception_server.start() # ty: ignore[unresolved-attribute] if self.interception_url is None: tunnel_url = await self.get_tunnel_url() @@ -180,7 +183,7 @@ async def setup_state(self, state: State) -> State: await self.create_sandbox(state, sandbox_request) # Register rollout for interception - request_id_queue = self._server.register_rollout(rollout_id) + request_id_queue = self._interception_server.register_rollout(rollout_id) # ty: ignore[unresolved-attribute] state["request_id_queue"] = request_id_queue state["agent_completed"] = False @@ -281,11 +284,18 @@ async def get_prompt_messages(self, state: State) -> Messages: ) # Got a request, proceed normally state["current_request_id"] = request_id - intercept = self._server.intercepts[request_id] + intercept = self._interception_server.intercepts[request_id] # ty: ignore[unresolved-attribute] return intercept["messages"] except asyncio.TimeoutError: - # No request yet, check if agent finished or timed out + # No request yet — check tunnel liveness first + if self._tunnel is not None and not self._tunnel.is_running: + frpc_output = "\n".join(self._tunnel.recent_output) + raise vf.TunnelError( + f"Tunnel process died during rollout. " + f"frpc output:\n{frpc_output}" + ) + # Then check if agent finished or timed out if await self.check_agent_completed(state): state["agent_completed"] = True return [] @@ -367,7 +377,9 @@ async def get_model_response( ) request_id = state.get("current_request_id") - intercept = self._server.intercepts.get(request_id) if request_id else None + intercept = ( + self._interception_server.intercepts.get(request_id) if request_id else None # ty: ignore[unresolved-attribute] + ) if intercept: # Always use the configured model from state, not the intercepted model diff --git a/verifiers/envs/experimental/rollout_gateway_mixin.py b/verifiers/envs/experimental/rollout_gateway_mixin.py index 833b40b30..e0960a0bd 100644 --- a/verifiers/envs/experimental/rollout_gateway_mixin.py +++ b/verifiers/envs/experimental/rollout_gateway_mixin.py @@ -43,7 +43,7 @@ def init_interception(self, *args, **kwargs): def init_gateway( self, gateway_port: int = 8000, - timeout_seconds: float = 3600.0, + timeout_seconds: float = 21600.0, ): """Initialize gateway resources. Call in __init__ when use_gateway=True.""" self.gateway_port = gateway_port @@ -51,6 +51,7 @@ def init_gateway( self._http_client = httpx.AsyncClient(timeout=httpx.Timeout(timeout_seconds)) self._tunnels: dict[str, Tunnel] = {} self._tunnel_lock = asyncio.Lock() + self._tunnel_monitor_task: asyncio.Task | None = None def _resolve_gateway_url(self, state: State) -> str: client = getattr(state["client"], "client", state["client"]) @@ -127,7 +128,7 @@ async def fetch_trajectory(self, state: State) -> None: ) async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: - """Get gateway tunnel URL, starting the tunnel if needed.""" + """Get gateway tunnel URL, starting the tunnel if needed. Restarts dead tunnels.""" async with self._tunnel_lock: if local_addr is None: if len(self._tunnels) == 1: @@ -141,6 +142,19 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: ) tunnel = self._tunnels.get(local_addr) + + # Restart dead tunnel + if tunnel is not None and not tunnel.is_running: + frpc_output = "\n".join(tunnel.recent_output) + logger.warning( + f"Tunnel dead for local_addr={local_addr} " + f"tunnel_id={tunnel.tunnel_id}, recreating. " + f"frpc output:\n{frpc_output}" + ) + tunnel.sync_stop() + del self._tunnels[local_addr] + tunnel = None + if tunnel is None: tunnel = Tunnel( local_port=self.gateway_port, @@ -149,7 +163,20 @@ async def get_gateway_tunnel_url(self, local_addr: str | None = None) -> str: ) url = await tunnel.start() self._tunnels[local_addr] = tunnel - logger.debug(f"Prime Tunnel started local_addr={local_addr} url={url}") + logger.debug( + f"Prime Tunnel started local_addr={local_addr} " + f"tunnel_id={tunnel.tunnel_id} url={url}" + ) + + # Lazily start health monitor on first tunnel creation + if ( + self._tunnel_monitor_task is None + or self._tunnel_monitor_task.done() + ): + self._tunnel_monitor_task = asyncio.create_task( + self._tunnel_health_monitor() + ) + return url assert tunnel.url is not None, "Tunnel started but URL is None" @@ -177,7 +204,28 @@ async def poll_job_completion( """Poll until background job completes, capturing output.""" if not self.use_gateway: return await super().poll_job_completion(state, sandbox_id, background_job) # ty: ignore[unresolved-attribute] + + tunnel_local_addr = state.get("tunnel_local_addr") + while True: + # Check tunnel liveness + if tunnel_local_addr: + tunnel = self._tunnels.get(tunnel_local_addr) + if tunnel is not None and not tunnel.is_running: + frpc_output = "\n".join(tunnel.recent_output) + logger.warning( + f"rollout={state.get('rollout_id')} sandbox={sandbox_id} " + f"tunnel_id={tunnel.tunnel_id} stage=tunnel_died " + f"frpc output:\n{frpc_output}" + ) + raise vf.TunnelError( + f"Tunnel process died during rollout " + f"rollout={state.get('rollout_id')} " + f"sandbox={sandbox_id} " + f"tunnel_id={tunnel.tunnel_id}. " + f"frpc output:\n{frpc_output}" + ) + status = await self.sandbox_client.get_background_job( # ty: ignore[unresolved-attribute] sandbox_id, background_job, @@ -268,7 +316,13 @@ async def rollout( state["rollout_base_url"] = ( f"{tunnel_url.rstrip('/')}/v1/rollouts/{state['rollout_id']}" ) - logger.debug(f"rollout={rollout_id} stage=start_tunnel url={tunnel_url}") + tunnel = self._tunnels.get(tunnel_local_addr) + tunnel_id = tunnel.tunnel_id if tunnel else None + state["tunnel_id"] = tunnel_id + logger.debug( + f"rollout={rollout_id} stage=start_tunnel " + f"tunnel_id={tunnel_id} url={tunnel_url}" + ) env_vars = await self.build_env_vars(state) docker_image = await self.get_docker_image(state) # ty: ignore[unresolved-attribute] @@ -378,11 +432,62 @@ async def rollout( return state + async def _tunnel_health_monitor(self, interval: float = 30.0) -> None: + """Background task that checks tunnel liveness and restarts dead tunnels.""" + try: + while True: + await asyncio.sleep(interval) + async with self._tunnel_lock: + dead_addrs = [ + addr for addr, t in self._tunnels.items() if not t.is_running + ] + for addr in dead_addrs: + tunnel = self._tunnels[addr] + frpc_output = "\n".join(tunnel.recent_output) + logger.warning( + f"Health monitor: tunnel dead for local_addr={addr} " + f"tunnel_id={tunnel.tunnel_id}. " + f"frpc output:\n{frpc_output}" + ) + tunnel.sync_stop() + new_tunnel = Tunnel( + local_port=self.gateway_port, + local_addr=addr, + log_level="debug" + if logger.isEnabledFor(logging.DEBUG) + else "info", + ) + url = await new_tunnel.start() + self._tunnels[addr] = new_tunnel + logger.info( + f"Health monitor: restarted tunnel local_addr={addr} " + f"tunnel_id={new_tunnel.tunnel_id} url={url}" + ) + + alive = sum(1 for t in self._tunnels.values() if t.is_running) + total = len(self._tunnels) + logger.debug(f"Health monitor: {alive}/{total} tunnels alive") + except asyncio.CancelledError: + return + @vf.teardown async def teardown_gateway(self): - """Close gateway HTTP client and stop gateway tunnels.""" + """Close gateway HTTP client, cancel health monitor, and stop gateway tunnels.""" if not self.use_gateway: return + + # Cancel health monitor + if ( + self._tunnel_monitor_task is not None + and not self._tunnel_monitor_task.done() + ): + self._tunnel_monitor_task.cancel() + try: + await self._tunnel_monitor_task + except asyncio.CancelledError: + pass + self._tunnel_monitor_task = None + await self._http_client.aclose() async with self._tunnel_lock: tunnels = list(self._tunnels.items()) diff --git a/verifiers/errors.py b/verifiers/errors.py index eec0a566c..e725580e4 100644 --- a/verifiers/errors.py +++ b/verifiers/errors.py @@ -50,6 +50,12 @@ class InfraError(Error): pass +class TunnelError(InfraError): + """Raised when a tunnel process dies or becomes unreachable.""" + + pass + + class SandboxError(InfraError): """Used to catch errors while interacting with sandboxes."""