Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,5 @@ dev = [
"pytest-cov>=5.0",
"pyright>=1.1",
"ruff>=0.8",
"openai>=2.30.0",
]
2 changes: 1 addition & 1 deletion src/mlx_stack/core/onboarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def generate_config(
if port == litellm_port:
port += 1

# TODO(#17): re-enable continuous_batching (waybarrios/vllm-mlx#211).
vllm_flags: dict[str, Any] = {
"continuous_batching": True,
"use_paged_cache": True,
}
if mapping.model.tool_calling:
Expand Down
3 changes: 2 additions & 1 deletion src/mlx_stack/core/stack_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ def build_vllm_flags(entry: CatalogEntry) -> dict[str, Any]:
Returns:
A dict of vllm flags.
"""
# TODO(#17): re-enable continuous_batching once vllm-mlx ships a fix for
# the missing return in load_model_with_fallback (waybarrios/vllm-mlx#211).
flags: dict[str, Any] = {
"continuous_batching": True,
"use_paged_cache": True,
}

Expand Down
32 changes: 29 additions & 3 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pathlib import Path
from typing import Any

import httpx
import psutil
import pytest
import yaml
Expand Down Expand Up @@ -115,7 +116,12 @@ def wait_for_port_free(port: int, timeout: float = 10.0) -> bool:


def kill_processes_on_port(port: int) -> None:
"""Best-effort kill any process bound to the given port via lsof."""
"""Best-effort kill any process bound to the given port via lsof.

Excludes the current process to avoid killing pytest itself when it
has client connections to the port being cleaned up.
"""
my_pid = os.getpid()
try:
result = subprocess.run(
["lsof", "-ti", f":{port}"],
Expand All @@ -126,7 +132,7 @@ def kill_processes_on_port(port: int) -> None:
if result.returncode == 0 and result.stdout.strip():
for pid_str in result.stdout.strip().split("\n"):
pid_str = pid_str.strip()
if pid_str.isdigit():
if pid_str.isdigit() and int(pid_str) != my_pid:
with contextlib.suppress(OSError):
os.kill(int(pid_str), signal.SIGKILL)
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
Expand Down Expand Up @@ -330,6 +336,21 @@ def start_vllm(
total_timeout=timeout,
)

# Warmup: the health check only proves the HTTP server is up.
# The first inference triggers MLX weight loading and JIT compilation,
# which can take significantly longer. Send a throwaway request so
# that callers' inference timeouts measure generation, not cold start.
with contextlib.suppress(httpx.TimeoutException, httpx.HTTPError):
httpx.post(
f"http://127.0.0.1:{port}/v1/chat/completions",
json={
"model": model_source,
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 1,
},
timeout=timeout,
)

return managed

def start_litellm(
Expand Down Expand Up @@ -404,11 +425,16 @@ def stop_all(self) -> None:
if svc.pid is not None:
with contextlib.suppress(OSError):
os.kill(svc.pid, signal.SIGKILL)

time.sleep(1)

# Kill any orphaned child processes still holding the port
for svc in self._services:
kill_processes_on_port(svc.port)

# Wait for ports to be freed
for svc in self._services:
wait_for_port_free(svc.port, timeout=10.0)
wait_for_port_free(svc.port, timeout=15.0)

# Clean up PID files
pids_dir = self._mlx_home / "pids"
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_model_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def _min_memory_for_entry(entry: CatalogEntry) -> float:

def _build_vllm_flags(entry: CatalogEntry) -> dict:
"""Build vllm flags based on model capabilities."""
# TODO(#17): re-enable continuous_batching (waybarrios/vllm-mlx#211)
flags: dict = {
"continuous_batching": True,
"use_paged_cache": True,
}
if entry.capabilities.tool_calling:
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_tool_calling(
{"role": "user", "content": "What is the weather in San Francisco?"},
],
"tools": [tool_definition],
"max_tokens": 100,
"max_tokens": 500,
},
timeout=120.0,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_stack_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ def test_full_lifecycle(

# After context manager exits, services are stopped
# Verify ports are freed
assert wait_for_port_free(vllm_port, timeout=10.0), (
assert wait_for_port_free(vllm_port, timeout=15.0), (
f"Port {vllm_port} still bound after shutdown"
)
assert wait_for_port_free(litellm_port, timeout=10.0), (
assert wait_for_port_free(litellm_port, timeout=15.0), (
f"Port {litellm_port} still bound after shutdown"
)

Expand Down
5 changes: 1 addition & 4 deletions tests/unit/test_cli_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,9 @@ class TestVLLMFlags:
"""Tests for vllm_flags generation."""

def test_base_flags_always_present(self) -> None:
"""continuous_batching and use_paged_cache always present."""
"""use_paged_cache always present."""
entry = _make_entry(tool_calling=False, thinking=False)
flags = build_vllm_flags(entry)
assert flags["continuous_batching"] is True
assert flags["use_paged_cache"] is True

def test_tool_calling_flags(self) -> None:
Expand Down Expand Up @@ -303,7 +302,6 @@ def test_combined_tool_and_thinking_flags(self) -> None:
reasoning_parser="nemotron",
)
flags = build_vllm_flags(entry)
assert flags["continuous_batching"] is True
assert flags["use_paged_cache"] is True
assert flags["enable_auto_tool_choice"] is True
assert flags["tool_call_parser"] == "hermes"
Expand Down Expand Up @@ -1132,7 +1130,6 @@ def test_vllm_flags_in_generated_stack(self, mlx_stack_home: Path) -> None:

for tier in result["stack"]["tiers"]:
flags = tier["vllm_flags"]
assert flags["continuous_batching"] is True
assert flags["use_paged_cache"] is True


Expand Down
Loading
Loading