diff --git a/tests/test_cancel_protocol.py b/tests/test_cancel_protocol.py
new file mode 100644
index 000000000..454d5ca91
--- /dev/null
+++ b/tests/test_cancel_protocol.py
@@ -0,0 +1,255 @@
+"""Tests for the cancel request protocol between ZMQ client and server."""
+
+import asyncio
+import logging
+import time
+from unittest.mock import AsyncMock, patch
+
+import msgpack
+import pytest
+
+from verifiers.workers.client.zmq_env_client import ZMQEnvClient
+from verifiers.workers.server.zmq_env_server import ZMQEnvServer
+from verifiers.workers.types import (
+    CancelRequest,
+    HealthRequest,
+    HealthResponse,
+    PendingRequest,
+)
+
+
+def _make_client() -> ZMQEnvClient:
+    """Build the client every test here uses (health_check_interval=0)."""
+    return ZMQEnvClient(
+        address="tcp://127.0.0.1:5555",
+        health_check_interval=0,
+    )
+
+
+def _make_server(request_tasks: dict) -> ZMQEnvServer:
+    """Build a bare server carrying only the state _handle_cancel() reads.
+
+    __init__ is bypassed via __new__ so that no ZMQ socket is bound.
+    """
+    server = ZMQEnvServer.__new__(ZMQEnvServer)
+    server.request_tasks = request_tasks
+    server.logger = logging.getLogger("test")
+    return server
+
+
+class TestCancelRequestType:
+    """Tests for CancelRequest serialization and validation."""
+
+    def test_cancel_request_fields(self):
+        req = CancelRequest(cancel_request_ids=["abc123", "def456"])
+        assert req.request_type == "cancel"
+        assert req.cancel_request_ids == ["abc123", "def456"]
+
+    def test_cancel_request_roundtrip(self):
+        req = CancelRequest(cancel_request_ids=["abc123"])
+        dumped = req.model_dump(mode="python")
+        restored = CancelRequest.model_validate(dumped)
+        assert restored.cancel_request_ids == ["abc123"]
+        assert restored.request_type == "cancel"
+
+    def test_cancel_request_msgpack_roundtrip(self):
+        req = CancelRequest(cancel_request_ids=["a", "b", "c"])
+        packed = msgpack.packb(req.model_dump(mode="python"), use_bin_type=True)
+        unpacked = msgpack.unpackb(packed, raw=False)
+        assert unpacked["request_type"] == "cancel"
+        assert unpacked["cancel_request_ids"] == ["a", "b", "c"]
+
+
+class TestClientSendCancel:
+    """Tests for client-side cancel message sending."""
+
+    @pytest.mark.asyncio
+    async def test_send_cancel_sends_message(self):
+        """send_cancel() sends a properly formatted cancel message."""
+        client = _make_client()
+        sent_frames = []
+
+        async def capture_send(frames):
+            sent_frames.append(frames)
+
+        try:
+            with patch.object(client.socket, "send_multipart", new=capture_send):
+                await client.send_cancel(["req1", "req2"])
+
+            assert len(sent_frames) == 1
+            frames = sent_frames[0]
+            assert len(frames) == 2  # [cancel_id, payload]
+
+            payload = msgpack.unpackb(frames[1], raw=False)
+            assert payload["request_type"] == "cancel"
+            assert payload["cancel_request_ids"] == ["req1", "req2"]
+        finally:
+            # Close even when an assertion fails so the socket is not leaked.
+            await client.close()
+
+    @pytest.mark.asyncio
+    async def test_send_cancel_empty_list_is_noop(self):
+        """send_cancel() with empty list does nothing."""
+        client = _make_client()
+        send_mock = AsyncMock()
+
+        try:
+            with patch.object(client.socket, "send_multipart", new=send_mock):
+                await client.send_cancel([])
+
+            send_mock.assert_not_called()
+        finally:
+            await client.close()
+
+    @pytest.mark.asyncio
+    async def test_send_cancel_swallows_errors(self):
+        """send_cancel() does not raise on send failure."""
+        client = _make_client()
+
+        async def fail_send(frames):
+            raise RuntimeError("socket closed")
+
+        try:
+            with patch.object(client.socket, "send_multipart", new=fail_send):
+                # Should not raise
+                await client.send_cancel(["req1"])
+        finally:
+            await client.close()
+
+
+class TestClientCancelledError:
+    """Tests for CancelledError handling in send_request."""
+
+    @pytest.mark.asyncio
+    async def test_cancelled_error_cleans_up_and_sends_cancel(self):
+        """CancelledError during send_request cleans up pending entry and sends cancel."""
+        client = _make_client()
+        cancel_ids_sent = []
+        # The event loop keeps only weak references to tasks; hold the
+        # canceller here so it cannot be garbage-collected before it runs.
+        background_tasks = []
+
+        async def mock_send_multipart(frames):
+            # After send, schedule cancellation of the pending future
+            async def cancel_after():
+                await asyncio.sleep(0.05)
+                async with client.pending_lock:
+                    for pending in client.pending_requests.values():
+                        if not pending.future.done():
+                            pending.future.cancel()
+
+            background_tasks.append(asyncio.create_task(cancel_after()))
+
+        async def capture_send_cancel(request_ids):
+            cancel_ids_sent.extend(request_ids)
+
+        try:
+            with (
+                patch.object(client.socket, "connect"),
+                patch.object(client.socket, "send_multipart", new=mock_send_multipart),
+                patch.object(client, "send_cancel", new=capture_send_cancel),
+            ):
+                await client.ensure_started()
+
+                with pytest.raises(asyncio.CancelledError):
+                    await client.send_request(
+                        HealthRequest(), HealthResponse, timeout=5.0
+                    )
+
+            # Pending request should have been cleaned up
+            assert len(client.pending_requests) == 0
+            # Cancel should have been sent to server
+            assert len(cancel_ids_sent) == 1
+        finally:
+            await client.close()
+
+
+class TestCancelAllPendingSendsCancel:
+    """Tests for cancel_all_pending sending cancel messages to the server."""
+
+    @pytest.mark.asyncio
+    async def test_cancel_all_pending_sends_cancel_to_server(self):
+        """cancel_all_pending() sends a cancel message for all pending request IDs."""
+        client = _make_client()
+        loop = asyncio.get_running_loop()
+        cancel_ids_sent = []
+
+        async def capture_send_cancel(request_ids):
+            cancel_ids_sent.extend(request_ids)
+
+        try:
+            # Add pending requests
+            async with client.pending_lock:
+                for request_id in ("req_aaa", "req_bbb"):
+                    client.pending_requests[request_id] = PendingRequest(
+                        request_id=request_id,
+                        request=HealthRequest(),
+                        submitted_at=time.time(),
+                        timeout=10.0,
+                        future=loop.create_future(),
+                    )
+
+            with patch.object(client, "send_cancel", new=capture_send_cancel):
+                cancelled = await client.cancel_all_pending(
+                    "test cancel", use_cancelled=True
+                )
+
+            assert len(cancelled) == 2
+            assert set(cancel_ids_sent) == {"req_aaa", "req_bbb"}
+        finally:
+            await client.close()
+
+
+class TestServerHandleCancel:
+    """Tests for server-side cancel handling."""
+
+    @pytest.mark.asyncio
+    async def test_handle_cancel_cancels_tracked_task(self):
+        """_handle_cancel() cancels tasks tracked in request_tasks."""
+        task = asyncio.create_task(asyncio.sleep(100))
+        server = _make_server({"req123": task})
+
+        raw = {"request_type": "cancel", "cancel_request_ids": ["req123"]}
+        server._handle_cancel(raw)
+
+        # Task should have cancellation requested
+        assert task.cancelling()
+        assert "req123" not in server.request_tasks
+
+        # Let the event loop process the cancellation
+        with pytest.raises(asyncio.CancelledError):
+            await task
+        assert task.cancelled()
+
+    def test_handle_cancel_ignores_unknown_ids(self):
+        """_handle_cancel() silently ignores request IDs not in request_tasks."""
+        server = _make_server({})
+
+        raw = {"request_type": "cancel", "cancel_request_ids": ["nonexistent"]}
+        # Should not raise
+        server._handle_cancel(raw)
+
+    @pytest.mark.asyncio
+    async def test_handle_cancel_ignores_already_done_tasks(self):
+        """_handle_cancel() does not error on already-completed tasks."""
+        future = asyncio.get_running_loop().create_future()
+        future.set_result(None)
+        server = _make_server({"req_done": future})
+
+        raw = {"request_type": "cancel", "cancel_request_ids": ["req_done"]}
+        server._handle_cancel(raw)
+
+        # Should have been popped from the dict
+        assert "req_done" not in server.request_tasks
+
+    def test_handle_cancel_invalid_request(self, caplog):
+        """_handle_cancel() logs warning on invalid cancel request."""
+        server = _make_server({})
+
+        # Missing required field
+        raw = {"request_type": "cancel"}
+        with caplog.at_level(logging.WARNING, logger="test"):
+            # Should not raise
+            server._handle_cancel(raw)
+
+        assert "Invalid cancel request" in caplog.text
diff --git a/tests/test_env_crash_recovery.py b/tests/test_env_crash_recovery.py
index 5dc58982c..7dc01185f 100644
--- a/tests/test_env_crash_recovery.py
+++ b/tests/test_env_crash_recovery.py
@@ -4,6 +4,7 @@
 import time
 from unittest.mock import patch
 
+import msgpack
 import pytest
 
 from verifiers.workers.client.zmq_env_client import ZMQEnvClient
@@ -92,7 +93,16 @@ async def test_retry_after_recovery(self):
 
         attempt_count = 0
 
-        async def mock_send(*args, **kwargs):
+        async def mock_send(frames, *args, **kwargs):
             nonlocal attempt_count
+            # Ignore cancel messages sent by send_cancel()
+            if len(frames) == 2:
+                try:
+                    payload = msgpack.unpackb(frames[1], raw=False)
+                    if payload.get("request_type") == "cancel":
+                        return
+                except Exception:
+                    pass  # undecodable payload: count it as a normal send
+
             attempt_count += 1
 
diff --git a/verifiers/workers/__init__.py b/verifiers/workers/__init__.py
index e7def4a49..c1f5fe8cc 100644
--- a/verifiers/workers/__init__.py
+++ b/verifiers/workers/__init__.py
@@ -3,6 +3,7 @@
 from verifiers.workers.types import (
     BaseRequest,
     BaseResponse,
+    CancelRequest,
     HealthRequest,
     HealthResponse,
     RunGroupRequest,
@@ -15,6 +16,7 @@
     # types
     "BaseRequest",
     "BaseResponse",
+    "CancelRequest",
     "HealthRequest",
     "HealthResponse",
     "RunRolloutRequest",
diff --git a/verifiers/workers/client/zmq_env_client.py b/verifiers/workers/client/zmq_env_client.py
index e839c6726..ba594d33c 100644
--- a/verifiers/workers/client/zmq_env_client.py
+++ b/verifiers/workers/client/zmq_env_client.py
@@ -16,6 +16,7 @@
 from verifiers.workers.types import (
     BaseRequest,
     BaseResponseT,
+    CancelRequest,
     HealthRequest,
     HealthResponse,
     PendingRequest,
@@ -167,6 +168,7 @@ async def cancel_all_pending(
 
             # Collect metadata before clearing
             cancelled_requests = list(self.pending_requests.values())
+            request_ids = [r.request_id for r in cancelled_requests]
 
             for pending_req in cancelled_requests:
                 if not pending_req.future.done():
@@ -178,8 +180,40 @@ async def cancel_all_pending(
             # Clear tracking dict
             self.pending_requests.clear()
 
+        # Notify server to stop work on these requests (best-effort)
+        await self.send_cancel(request_ids)
+
         return cancelled_requests
 
+    async def send_cancel(self, request_ids: list[str]) -> None:
+        """Send a cancel message to the server for the given request IDs.
+
+        Fire-and-forget: failures are logged but do not raise.
+
+        Args:
+            request_ids: IDs of the requests to cancel; an empty list
+                sends nothing.
+        """
+        if not request_ids:
+            return
+        try:
+            cancel_req = CancelRequest(cancel_request_ids=request_ids)
+            payload = cast(
+                bytes,
+                msgpack.packb(
+                    cancel_req.model_dump(mode="python", warnings=False),
+                    default=msgpack_encoder,
+                    use_bin_type=True,
+                ),
+            )
+            cancel_id = uuid.uuid4().hex
+            await self.socket.send_multipart([cancel_id.encode(), payload])
+            self.logger.debug(
+                f"Sent cancel for {len(request_ids)} request(s) to env server {self.name}"
+            )
+        except Exception as e:
+            self.logger.debug(f"Failed to send cancel to env server {self.name}: {e}")
+
     async def receive_loop(self):
         """Continuously receive responses from environment servers."""
         while True:
@@ -290,6 +324,13 @@ async def send_request(
 
         try:
             raw_response = await asyncio.wait_for(future, timeout=effective_timeout)
+        except asyncio.CancelledError:
+            # Task was cancelled externally (e.g. scheduler timeout).
+            # Clean up our pending entry and tell the server to stop work.
+            async with self.pending_lock:
+                self.pending_requests.pop(request_id, None)
+            await self.send_cancel([request_id])
+            raise
         except asyncio.TimeoutError:
             # Clean up on timeout
             async with self.pending_lock:
diff --git a/verifiers/workers/server/zmq_env_server.py b/verifiers/workers/server/zmq_env_server.py
index 8b9f8d256..66b69a4f3 100644
--- a/verifiers/workers/server/zmq_env_server.py
+++ b/verifiers/workers/server/zmq_env_server.py
@@ -11,6 +11,7 @@
 from verifiers.workers.server.env_server import EnvServer
 from verifiers.workers.types import (
     BaseResponse,
+    CancelRequest,
     RunGroupRequest,
     RunRolloutRequest,
 )
@@ -70,6 +71,9 @@ def __init__(self, *args, address: str = "tcp://127.0.0.1:5000", **kwargs):
         self.socket.setsockopt(zmq.LINGER, 0)  # discard msgs on socket close
         self.socket.bind(self.address)
 
+        # Map request_id → asyncio.Task for cancel support
+        self.request_tasks: dict[str, asyncio.Task] = {}
+
         # Health check runs in a separate process (immune to env workload)
         self.stop_health = mp.Event()
         self.health_process: mp.Process | None = None
@@ -121,15 +125,39 @@ async def serve(self, stop_event: asyncio.Event | None = None) -> None:
                     )
                     continue
 
-                client_id, request_id, payload_bytes = frames
+                client_id, request_id_bytes, payload_bytes = frames
+                request_id = request_id_bytes.decode()
+
+                # Peek at request type to handle cancels inline
+                try:
+                    raw = msgpack.unpackb(payload_bytes, raw=False)
+                except Exception:
+                    # NOTE(review): no reply is sent on this path, so the
+                    # client only notices via its own timeout -- confirm.
+                    self.logger.warning(
+                        f"Failed to deserialize message {request_id[:7]}"
+                    )
+                    continue
+
+                # A non-dict payload cannot be a cancel; hand it to
+                # process_request(), whose try block covers raw.get().
+                if isinstance(raw, dict) and raw.get("request_type") == "cancel":
+                    self._handle_cancel(raw)
+                    continue
 
                 # Process in background, tracking the task for cleanup
                 task = asyncio.create_task(
-                    self.process_request(client_id, request_id, payload_bytes)
+                    self.process_request(client_id, request_id, raw)
                 )
                 self.pending_tasks.add(task)
                 task.add_done_callback(self.pending_tasks.discard)
 
+                # Track request_id → task for cancel support
+                self.request_tasks[request_id] = task
+                task.add_done_callback(
+                    lambda _t, _rid=request_id: self.request_tasks.pop(_rid, None)
+                )
+
             except asyncio.CancelledError:
                 break
             except Exception as e:
@@ -159,6 +187,7 @@ async def close(self):
                 task.cancel()
             await asyncio.gather(*self.pending_tasks, return_exceptions=True)
             self.pending_tasks.clear()
+            self.request_tasks.clear()
 
         await self.close_cached_clients()
 
@@ -183,18 +212,34 @@ async def log_stats_loop(self, interval: float = 10.0):
 
             self.logger.info(message)
 
+    def _handle_cancel(self, raw: dict) -> None:
+        """Cancel server-side tasks for the given request IDs.
+
+        A payload that fails CancelRequest validation is logged and ignored.
+        Each listed ID is removed from request_tasks; its task is cancelled
+        only if it has not finished yet. Unknown IDs are skipped.
+        """
+        try:
+            cancel_req = CancelRequest.model_validate(raw)
+        except Exception as e:
+            self.logger.warning(f"Invalid cancel request: {e}")
+            return
+
+        for rid in cancel_req.cancel_request_ids:
+            task = self.request_tasks.pop(rid, None)
+            if task is not None and not task.done():
+                task.cancel()
+                self.logger.debug(f"Cancelled server task for request {rid[:7]}")
+
     async def process_request(
         self,
         client_id: bytes,
-        request_id_bytes: bytes,
-        payload_bytes: bytes,
+        request_id: str,
+        raw: dict,
     ):
-        request_id = request_id_bytes.decode()
         response: BaseResponse
 
         try:
-            # deserialize request
-            raw = msgpack.unpackb(payload_bytes, raw=False)
             request_type = raw.get("request_type")
             request_id = raw.get("request_id", request_id)
 
diff --git a/verifiers/workers/types.py b/verifiers/workers/types.py
index 25ddbbc94..9227b48f7 100644
--- a/verifiers/workers/types.py
+++ b/verifiers/workers/types.py
@@ -68,6 +68,13 @@ class RunGroupRequest(BaseRequest):
     state_columns: list[str] | None
 
 
+class CancelRequest(BaseRequest):
+    """Request asking the server to cancel in-flight requests by ID."""
+
+    request_type: Literal["cancel"] = "cancel"
+    cancel_request_ids: list[str]
+
+
 class RunGroupResponse(BaseResponse):
     outputs: list[CoercedRolloutOutput] | None = None
 