Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,27 @@ async def test_failed_abort(tmp_socket):
await client.check_health()

# Trigger an abort on the client side.
# This request ID does not exist, and will cause the engine to error
await client.abort(request_id="foo")
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")

# Future generation requests will now fail
# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())

# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored

await abort_task

# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
Expand Down
7 changes: 6 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class RPCAbortRequest:
request_id: str


class RPCHealthRequest:
pass


class RPCStartupRequest(Enum):
IS_SERVER_READY = 1

Expand All @@ -52,7 +56,8 @@ class RPCStartupResponse:
tracing_enabled: bool


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

Expand Down
51 changes: 29 additions & 22 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
Expand Down Expand Up @@ -94,9 +95,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for ack of check_health requests.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
Expand All @@ -123,28 +124,34 @@ def get_data_socket(self) -> Iterator[Socket]:
finally:
socket.close(linger=0)

async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
async def run_check_health_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health.

The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.

The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
"""

try:
while True:
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
break

if await self.health_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe.
await self._send_one_way_rpc_request(
RPCHealthRequest(), self.input_socket)

# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
else:
# Heartbeat received- check the message
# Server sent a health status message unprompted.
await self._check_success(
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)
error_message="Health check failed.",
socket=self.health_socket)

logger.debug("Heartbeat successful.")
logger.debug("Health probe successful.")

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
Expand Down Expand Up @@ -227,7 +234,7 @@ async def setup(self):

# Start health_loop.
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))

def close(self):
"""Destroy the ZeroMQ Context."""
Expand Down
77 changes: 17 additions & 60 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

Expand All @@ -17,10 +15,10 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
Expand Down Expand Up @@ -93,30 +91,16 @@ def __init__(self,
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# Send health status back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

# Error state.
self._errored_with: Optional[BaseException] = None

# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
Expand Down Expand Up @@ -147,8 +131,6 @@ def start(self):
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
Expand All @@ -162,7 +144,6 @@ def start(self):
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine

Expand Down Expand Up @@ -201,11 +182,9 @@ def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")

Expand All @@ -221,6 +200,7 @@ def run_engine_loop(self):

def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""

try:
return self.engine.step()
except SystemExit:
Expand Down Expand Up @@ -249,9 +229,10 @@ def handle_new_input(self):
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
raise ValueError("Unknown RPCRequest Type: {request}")
Comment on lines +232 to +235
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Missing f-string prefix causes incorrect error message.

Line 235 is missing the f prefix, so {request} will be printed literally instead of the actual request value.

Proposed fix
-                    raise ValueError("Unknown RPCRequest Type: {request}")
+                    raise ValueError(f"Unknown RPCRequest Type: {request}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
raise ValueError("Unknown RPCRequest Type: {request}")
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError(f"Unknown RPCRequest Type: {request}")
🤖 Prompt for AI Agents
In `@vllm/engine/multiprocessing/engine.py` around lines 232 - 235, The error
message in the RPC request handling branch prints the literal "{request}"
because the f-string prefix is missing; update the raise in the else branch that
currently references RPCHealthRequest and _handle_health_request to use an
f-string (or otherwise format the request) so the actual request object is
interpolated into the ValueError message (e.g., change "Unknown RPCRequest Type:
{request}" to an f-string including request).


except Exception as e:
self._set_errored(e)
Expand Down Expand Up @@ -298,32 +279,13 @@ def _handle_abort_request(self, request: RPCAbortRequest):
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)

def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()

logger.debug("Exiting MQLLMEngine heartbeat thread")

def _heartbeat(self):
# Send unhealthy if engine has already errored
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)

# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))

else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
Comment on lines +282 to +288
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing early return after sending unhealthy status.

If self._errored_with is not None, the method sends an unhealthy status but then continues to call self.engine.check_health(). This could either succeed (misleading) or raise an exception that overwrites the original error context. Add a return after _send_unhealthy.

Proposed fix
     def _handle_health_request(self):
         if self._errored_with is not None:
             self._send_unhealthy(self._errored_with)
+            return
 
         # Raises error if unhealthy.
         self.engine.check_health()
         self._send_healthy()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))
else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
return
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
🤖 Prompt for AI Agents
In `@vllm/engine/multiprocessing/engine.py` around lines 282 - 288, In
_handle_health_request, when self._errored_with is not None the code calls
self._send_unhealthy(self._errored_with) but then continues to call
self.engine.check_health() which can mask or overwrite the original error;
modify _handle_health_request so that after calling
_send_unhealthy(self._errored_with) it immediately returns (i.e., add an early
return) to avoid further health checks and preserve the original error context;
reference: function _handle_health_request, attribute _errored_with, methods
_send_unhealthy, engine.check_health, and _send_healthy.


def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
Expand All @@ -333,14 +295,12 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):

def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False)

def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
Expand All @@ -353,9 +313,6 @@ def _set_errored(self, e: BaseException):
if self._errored_with is None:
self._errored_with = e

def _alive(self):
self._last_alive_time = time.time()


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):
Expand Down