Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ Verifiers defines a hierarchy of error types under `vf.Error`:
- `vf.ModelError` — errors from model interactions (e.g., `vf.EmptyModelResponseError`)
- `vf.OverlongPromptError` — prompt exceeds model context length
- `vf.ToolError` — tool-related errors (`vf.ToolParseError`, `vf.ToolCallError`)
- `vf.InfraError` — infrastructure errors (e.g., `vf.SandboxError`)
- `vf.InfraError` — infrastructure errors (e.g., `vf.SandboxError`, `vf.TunnelError`)

When a `vf.Error` is raised during a rollout, it is automatically caught and stored in `state["error"]`, triggering the built-in `has_error` stop condition at the next check. This allows rollouts to terminate gracefully rather than crashing.

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ dependencies = [
"nest-asyncio>=1.6.0", # for jupyter notebooks
"openai>=1.108.1",
"openai-agents>=0.0.7",
"prime-tunnel>=0.1.1",
"prime-sandboxes>=0.2.14",
"prime-tunnel>=0.1.4",
"prime-sandboxes>=0.2.16",
"pydantic>=2.11.9",
"requests",
"rich",
Expand Down
200 changes: 200 additions & 0 deletions tests/test_rollout_gateway_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock
Expand All @@ -16,6 +17,7 @@

class FakeTunnel:
instances: list["FakeTunnel"] = []
_next_id: int = 0

def __init__(
self,
Expand All @@ -27,17 +29,35 @@ def __init__(
self.local_addr = local_addr
self.log_level = log_level
self.url: str | None = None
FakeTunnel._next_id += 1
self.tunnel_id: str = f"fake-tunnel-{FakeTunnel._next_id}"
self._is_running: bool = True
self._recent_output: list[str] = ["frpc log line 1", "frpc log line 2"]
self.start_calls = 0
self.stop_calls = 0
FakeTunnel.instances.append(self)

@property
def is_running(self) -> bool:
    """Whether the fake tunnel is currently 'alive'; tests flip the private `_is_running` flag directly to simulate a dead tunnel."""
    return self._is_running

@property
def recent_output(self) -> list[str]:
    """Snapshot of the fake frpc log lines, copied so callers cannot mutate the backing list."""
    return self._recent_output.copy()

async def start(self) -> str:
    """Simulate starting the tunnel: count the call, mark it live, return a fixed URL."""
    self.start_calls = self.start_calls + 1
    self._is_running = True
    fixed_url = "https://unit-test.tunnel.prime.ai"
    self.url = fixed_url
    return fixed_url

async def stop(self) -> None:
    """Simulate an async stop: mark the tunnel dead and count the call."""
    self._is_running = False
    self.stop_calls += 1

def sync_stop(self) -> None:
    """Synchronous variant of stop() for non-async teardown paths."""
    self._is_running = False
    self.stop_calls = self.stop_calls + 1


class GatewayCliAgentEnv(vf.RolloutGatewayMixin, vf.CliAgentEnv):
Expand Down Expand Up @@ -167,6 +187,7 @@ def _handler(request: httpx.Request) -> httpx.Response:
@pytest.mark.asyncio
async def test_cli_agent_env_rollout_uses_gateway_and_tunnel(monkeypatch):
FakeTunnel.instances.clear()
FakeTunnel._next_id = 0
monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

tracker = {
Expand Down Expand Up @@ -276,6 +297,7 @@ def _client_factory(*args, **kwargs):
@pytest.mark.asyncio
async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
FakeTunnel.instances.clear()
FakeTunnel._next_id = 0
monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

dataset = Dataset.from_dict(
Expand Down Expand Up @@ -317,6 +339,7 @@ async def test_cli_agent_env_maintains_tunnel_per_local_addr(monkeypatch):
async def test_use_gateway_false_initializes_interception(monkeypatch):
"""With use_gateway=False, interception server is created and gateway is not."""
FakeTunnel.instances.clear()
FakeTunnel._next_id = 0
monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

dataset = Dataset.from_dict(
Expand Down Expand Up @@ -348,3 +371,180 @@ async def test_use_gateway_false_initializes_interception(monkeypatch):
await env.teardown_resources() # stops interception (which was never started)

assert len(FakeTunnel.instances) == 0


@pytest.mark.asyncio
async def test_dead_tunnel_recreated_on_get_gateway_tunnel_url(monkeypatch):
    """Dead tunnel is stopped and replaced when get_gateway_tunnel_url is called."""
    FakeTunnel.instances.clear()
    FakeTunnel._next_id = 0
    monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

    rows = {
        "prompt": [[{"role": "user", "content": "Hello"}]],
        "answer": [""],
        "example_id": [0],
    }
    env = GatewayCliAgentEnv(
        run_command="echo run-agent",
        dataset=Dataset.from_dict(rows),
        rubric=vf.Rubric(),
        gateway_port=8000,
    )

    # Bring up the first tunnel.
    await env.get_gateway_tunnel_url(local_addr="10.0.0.1")
    assert len(FakeTunnel.instances) == 1
    first_tunnel = FakeTunnel.instances[0]
    assert first_tunnel.start_calls == 1

    # Simulate the tunnel process dying.
    first_tunnel._is_running = False

    # A second URL request must stop the dead tunnel and start a fresh one.
    recreated_url = await env.get_gateway_tunnel_url(local_addr="10.0.0.1")
    assert recreated_url == "https://unit-test.tunnel.prime.ai"
    assert len(FakeTunnel.instances) == 2
    assert first_tunnel.stop_calls == 1

    replacement = FakeTunnel.instances[1]
    assert replacement.start_calls == 1
    assert replacement._is_running is True

    await env.teardown_gateway()


@pytest.mark.asyncio
async def test_poll_job_completion_raises_tunnel_error_on_dead_tunnel(monkeypatch):
    """poll_job_completion raises TunnelError when the tunnel dies mid-rollout."""
    FakeTunnel.instances.clear()
    FakeTunnel._next_id = 0
    monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

    dataset = Dataset.from_dict(
        {
            "prompt": [[{"role": "user", "content": "Hello"}]],
            "answer": [""],
            "example_id": [0],
        }
    )
    env = GatewayCliAgentEnv(
        run_command="echo run-agent",
        dataset=dataset,
        rubric=vf.Rubric(),
        gateway_port=8000,
    )

    # Start a tunnel
    await env.get_gateway_tunnel_url(local_addr="10.0.0.1")
    tunnel = FakeTunnel.instances[0]

    # Mock sandbox_client.get_background_job to never complete
    env.sandbox_client = SimpleNamespace(
        get_background_job=AsyncMock(return_value=SimpleNamespace(completed=False)),
    )

    state = {
        "rollout_id": "rollout_test123",
        "tunnel_local_addr": "10.0.0.1",
    }
    background_job = SimpleNamespace(id="job-1")

    # Kill tunnel after a brief delay
    async def kill_tunnel():
        await asyncio.sleep(0.05)
        tunnel._is_running = False

    # Hold a reference to the task: the event loop keeps only weak references
    # to tasks, so a fire-and-forget create_task() may be garbage-collected
    # before it ever runs, leaving the tunnel alive and the poll hanging.
    killer = asyncio.create_task(kill_tunnel())

    with pytest.raises(vf.TunnelError, match="Tunnel process died"):
        await env.poll_job_completion(state, "sb-123", background_job)

    # Make sure the background task has finished before tearing down.
    await killer

    await env.teardown_gateway()


@pytest.mark.asyncio
async def test_health_monitor_restarts_dead_tunnels(monkeypatch):
    """Background health monitor detects and restarts dead tunnels."""
    FakeTunnel.instances.clear()
    FakeTunnel._next_id = 0
    monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

    dataset = Dataset.from_dict(
        {
            "prompt": [[{"role": "user", "content": "Hello"}]],
            "answer": [""],
            "example_id": [0],
        }
    )
    env = GatewayCliAgentEnv(
        run_command="echo run-agent",
        dataset=dataset,
        rubric=vf.Rubric(),
        gateway_port=8000,
    )

    # Start a tunnel (this also starts the health monitor)
    await env.get_gateway_tunnel_url(local_addr="10.0.0.1")
    assert env._tunnel_monitor_task is not None
    assert not env._tunnel_monitor_task.done()

    original_tunnel = FakeTunnel.instances[0]
    original_tunnel._is_running = False

    # Replace the default monitor with one polling on a short interval.
    env._tunnel_monitor_task.cancel()
    try:
        await env._tunnel_monitor_task
    except asyncio.CancelledError:
        pass

    env._tunnel_monitor_task = asyncio.create_task(
        env._tunnel_health_monitor(interval=0.05)
    )

    # Poll with a bounded deadline instead of a fixed sleep: a fixed 0.2s
    # wait is flaky on loaded CI machines and slower than needed elsewhere.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + 2.0
    while len(FakeTunnel.instances) < 2 and loop.time() < deadline:
        await asyncio.sleep(0.02)

    assert len(FakeTunnel.instances) == 2
    assert original_tunnel.stop_calls == 1
    new_tunnel = FakeTunnel.instances[1]
    assert new_tunnel.start_calls == 1
    assert new_tunnel._is_running is True

    await env.teardown_gateway()


@pytest.mark.asyncio
async def test_teardown_gateway_cancels_health_monitor(monkeypatch):
    """teardown_gateway cancels the health monitor task."""
    FakeTunnel.instances.clear()
    FakeTunnel._next_id = 0
    monkeypatch.setattr(rollout_gateway_mixin, "Tunnel", FakeTunnel)

    rows = {
        "prompt": [[{"role": "user", "content": "Hello"}]],
        "answer": [""],
        "example_id": [0],
    }
    env = GatewayCliAgentEnv(
        run_command="echo run-agent",
        dataset=Dataset.from_dict(rows),
        rubric=vf.Rubric(),
        gateway_port=8000,
    )

    # Starting a tunnel also spawns the background health monitor.
    await env.get_gateway_tunnel_url(local_addr="10.0.0.1")
    monitor = env._tunnel_monitor_task
    assert monitor is not None
    assert not monitor.done()

    await env.teardown_gateway()

    # Teardown must cancel the task and drop the stored reference.
    assert monitor.done()
    assert env._tunnel_monitor_task is None
38 changes: 25 additions & 13 deletions verifiers/envs/experimental/cli_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
interception_url: str | None = None,
max_turns: int = -1,
timeout_seconds: float = 3600.0,
poll_interval: float = 2.0,
poll_interval: float = 5.0,
docker_image: str = "python:3.11-slim",
start_command: str = "tail -f /dev/null",
cpu_cores: int = 1,
Expand Down Expand Up @@ -115,16 +115,19 @@ def init_interception(
self._tunnel_lock = asyncio.Lock()
self._interception_server = InterceptionServer(port=interception_port)

@property
def _server(self) -> InterceptionServer:
assert self._interception_server is not None
return self._interception_server

async def get_tunnel_url(self) -> str:
"""Get tunnel URL, starting the tunnel if needed."""
"""Get tunnel URL, starting the tunnel if needed. Recreates dead tunnels."""
async with self._tunnel_lock:
if self._tunnel is not None and not self._tunnel.is_running:
frpc_output = "\n".join(self._tunnel.recent_output)
logger.warning(
f"Tunnel process died, recreating. frpc output:\n{frpc_output}"
)
self._tunnel.sync_stop()
self._tunnel = None

if self._tunnel is None:
port = self._server.port
port = self._interception_server.port # ty: ignore[unresolved-attribute]
if logger.isEnabledFor(logging.DEBUG):
self._tunnel = Tunnel(
local_port=port,
Expand All @@ -146,7 +149,7 @@ async def setup_state(self, state: State) -> State:
rollout_id = f"rollout_{uuid.uuid4().hex[:8]}"
state["rollout_id"] = rollout_id

await self._server.start()
await self._interception_server.start() # ty: ignore[unresolved-attribute]

if self.interception_url is None:
tunnel_url = await self.get_tunnel_url()
Expand Down Expand Up @@ -180,7 +183,7 @@ async def setup_state(self, state: State) -> State:
await self.create_sandbox(state, sandbox_request)

# Register rollout for interception
request_id_queue = self._server.register_rollout(rollout_id)
request_id_queue = self._interception_server.register_rollout(rollout_id) # ty: ignore[unresolved-attribute]
state["request_id_queue"] = request_id_queue
state["agent_completed"] = False

Expand Down Expand Up @@ -281,11 +284,18 @@ async def get_prompt_messages(self, state: State) -> Messages:
)
# Got a request, proceed normally
state["current_request_id"] = request_id
intercept = self._server.intercepts[request_id]
intercept = self._interception_server.intercepts[request_id] # ty: ignore[unresolved-attribute]
return intercept["messages"]

except asyncio.TimeoutError:
# No request yet, check if agent finished or timed out
# No request yet — check tunnel liveness first
if self._tunnel is not None and not self._tunnel.is_running:
frpc_output = "\n".join(self._tunnel.recent_output)
raise vf.TunnelError(
f"Tunnel process died during rollout. "
f"frpc output:\n{frpc_output}"
)
# Then check if agent finished or timed out
if await self.check_agent_completed(state):
state["agent_completed"] = True
return []
Expand Down Expand Up @@ -367,7 +377,9 @@ async def get_model_response(
)

request_id = state.get("current_request_id")
intercept = self._server.intercepts.get(request_id) if request_id else None
intercept = (
self._interception_server.intercepts.get(request_id) if request_id else None # ty: ignore[unresolved-attribute]
)

if intercept:
# Always use the configured model from state, not the intercepted model
Expand Down
Loading
Loading