From 9ef98a5a797966fba06701b706bc0dbac0997e44 Mon Sep 17 00:00:00 2001 From: Wes Eklund Date: Sat, 4 Apr 2026 10:07:36 -0400 Subject: [PATCH 1/3] refactor(tests): replace brittle mock-heavy tests with behavioral tests and shared factories Replace 10-deep @patch stacks in TestRunUp with a FakeServiceLayer test double that mocks only at the OS boundary (subprocess, HTTP, signals). Consolidate ~50 duplicate helper functions (_make_entry x9, _make_stack_yaml x5, etc.) into tests/factories.py. Structure all tests with AAA comments. Net result: -577 lines, same 1481 tests passing, tests now assert behavior not mock wiring. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/factories.py | 298 ++++++ tests/fakes.py | 190 ++++ tests/unit/conftest.py | 106 ++ tests/unit/test_cli_down.py | 246 ++--- tests/unit/test_cli_init.py | 288 +++--- tests/unit/test_cli_logs.py | 73 +- tests/unit/test_cli_models.py | 172 ++- tests/unit/test_cli_pull.py | 393 +++++-- tests/unit/test_cli_recommend.py | 377 ++++--- tests/unit/test_cli_status.py | 306 +++--- tests/unit/test_cli_up.py | 1492 +++++++++------------------ tests/unit/test_cross_area.py | 162 +-- tests/unit/test_lifecycle_fixes.py | 224 ++-- tests/unit/test_log_rotation.py | 52 +- tests/unit/test_log_viewer.py | 79 +- tests/unit/test_models.py | 66 +- tests/unit/test_ops_cross_area.py | 47 +- tests/unit/test_robustness_fixes.py | 136 +-- tests/unit/test_scoring.py | 186 ++-- tests/unit/test_watchdog.py | 180 ++-- 20 files changed, 2545 insertions(+), 2528 deletions(-) create mode 100644 tests/factories.py create mode 100644 tests/fakes.py create mode 100644 tests/unit/conftest.py diff --git a/tests/factories.py b/tests/factories.py new file mode 100644 index 0000000..afa8e9d --- /dev/null +++ b/tests/factories.py @@ -0,0 +1,298 @@ +"""Shared test data factories for mlx-stack tests. + +Consolidates all duplicate helper functions (_make_entry, _make_stack_yaml, etc.) +into a single module. Every test file imports from here instead of defining +its own copy. + +Factory functions are plain functions (not pytest fixtures) so they can be +called with different arguments in the same test. Filesystem helpers accept +a ``Path`` parameter for the target directory. +""" + +from __future__ import annotations + +import gzip +from pathlib import Path +from typing import Any + +import yaml + +from mlx_stack.core.catalog import ( + BenchmarkResult, + Capabilities, + CatalogEntry, + QualityScores, + QuantSource, +) +from mlx_stack.core.hardware import HardwareProfile + +# --------------------------------------------------------------------------- # +# Catalog data factories +# --------------------------------------------------------------------------- # + + +def make_entry( + model_id: str = "test-model", + name: str | None = None, + family: str = "Test", + params_b: float = 8.0, + architecture: str = "transformer", + quality_overall: int = 70, + quality_coding: int = 65, + quality_reasoning: int = 60, + quality_instruction: int = 72, + tool_calling: bool = True, + tool_call_parser: str | None = "hermes", + thinking: bool = False, + reasoning_parser: str | None = None, + vision: bool = False, + benchmarks: dict[str, BenchmarkResult] | None = None, + memory_gb: float = 5.5, + tags: list[str] | None = None, + gated: bool = False, + disk_size_gb: float = 4.5, + sources: dict[str, QuantSource] | None = None, +) -> CatalogEntry: + """Create a ``CatalogEntry`` for testing. + + Unified superset of all former ``_make_entry`` helpers. Every parameter + has a sensible default so callers only override what they care about. + """ + if name is None: + name = f"Test {model_id}" + + if sources is None: + sources = { + "int4": QuantSource( + hf_repo=f"mlx-community/{model_id}-4bit", + disk_size_gb=disk_size_gb, + ), + } + + if benchmarks is None: + benchmarks = { + "m4-max-128": BenchmarkResult( + prompt_tps=100.0, gen_tps=50.0, memory_gb=memory_gb, + ), + } + + return CatalogEntry( + id=model_id, + name=name, + family=family, + params_b=params_b, + architecture=architecture, + min_mlx_lm_version="0.22.0", + sources=sources, + capabilities=Capabilities( + tool_calling=tool_calling, + tool_call_parser=tool_call_parser if tool_calling else None, + thinking=thinking, + reasoning_parser=reasoning_parser if thinking else None, + vision=vision, + ), + quality=QualityScores( + overall=quality_overall, + coding=quality_coding, + reasoning=quality_reasoning, + instruction_following=quality_instruction, + ), + benchmarks=benchmarks, + tags=tags if tags is not None else [], + gated=gated, + ) + + +def make_test_catalog() -> list[CatalogEntry]: + """Standard two-model catalog: big-model (49B) + fast-model (3B). + + Matches the pattern used in ``test_cli_up``, ``test_lifecycle_fixes``, + and ``test_cross_area``. + """ + return [ + make_entry("big-model", params_b=49.0, memory_gb=30.0), + make_entry("fast-model", params_b=3.0, memory_gb=2.0), + ] + + +# --------------------------------------------------------------------------- # +# Hardware profile factory +# --------------------------------------------------------------------------- # + + +def make_profile( + chip: str = "Apple M4 Max", + gpu_cores: int = 40, + memory_gb: int = 128, + bandwidth_gbps: float = 546.0, + is_estimate: bool = False, +) -> HardwareProfile: + """Create a ``HardwareProfile`` for testing.""" + return HardwareProfile( + chip=chip, + gpu_cores=gpu_cores, + memory_gb=memory_gb, + bandwidth_gbps=bandwidth_gbps, + is_estimate=is_estimate, + ) + + +def make_small_profile() -> HardwareProfile: + """M4 Pro 24 GB — memory-constrained profile for budget tests.""" + return make_profile( + chip="Apple M4 Pro", + gpu_cores=20, + memory_gb=24, + bandwidth_gbps=273.0, + ) + + +# --------------------------------------------------------------------------- # +# Stack YAML factories +# --------------------------------------------------------------------------- # + + +def make_stack_yaml( + tiers: list[dict[str, Any]] | None = None, + schema_version: int = 1, + litellm_port: int = 4000, +) -> dict[str, Any]: + """Create a stack definition dict for testing. + + The ``litellm_port`` parameter is stored as top-level metadata (matches + the ``test_cli_up`` variant). The default two-tier stack uses + ``standard`` (port 8000) and ``fast`` (port 8001). + """ + if tiers is None: + tiers = [ + { + "name": "standard", + "model": "big-model", + "quant": "int4", + "source": "mlx-community/big-model-4bit", + "port": 8000, + "vllm_flags": { + "continuous_batching": True, + "use_paged_cache": True, + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + }, + }, + { + "name": "fast", + "model": "fast-model", + "quant": "int4", + "source": "mlx-community/fast-model-4bit", + "port": 8001, + "vllm_flags": { + "continuous_batching": True, + "use_paged_cache": True, + }, + }, + ] + return { + "schema_version": schema_version, + "name": "default", + "hardware_profile": "m4-max-128", + "intent": "balanced", + "created": "2026-03-24T00:00:00+00:00", + "tiers": tiers, + } + + +def write_stack_yaml( + mlx_stack_home: Path, + stack: dict[str, Any] | None = None, +) -> Path: + """Write a stack YAML file to ``mlx_stack_home`` and return its path.""" + if stack is None: + stack = make_stack_yaml() + stacks_dir = mlx_stack_home / "stacks" + stacks_dir.mkdir(parents=True, exist_ok=True) + stack_path = stacks_dir / "default.yaml" + stack_path.write_text(yaml.dump(stack, default_flow_style=False)) + return stack_path + + +def write_litellm_yaml(mlx_stack_home: Path) -> Path: + """Write a minimal ``litellm.yaml`` config and return its path.""" + litellm_config = { + "model_list": [ + { + "model_name": "standard", + "litellm_params": { + "model": "openai/big-model", + "api_base": "http://localhost:8000/v1", + "api_key": "dummy", + }, + }, + ], + } + litellm_path = mlx_stack_home / "litellm.yaml" + litellm_path.write_text(yaml.dump(litellm_config, default_flow_style=False)) + return litellm_path + + +# --------------------------------------------------------------------------- # +# PID file helpers +# --------------------------------------------------------------------------- # + + +def create_pid_file( + mlx_stack_home: Path, + service_name: str, + pid: int | str = 12345, +) -> Path: + """Create a PID file in ``mlx_stack_home/pids/`` and return its path.""" + pids_dir = mlx_stack_home / "pids" + pids_dir.mkdir(parents=True, exist_ok=True) + pid_path = pids_dir / f"{service_name}.pid" + pid_path.write_text(str(pid)) + return pid_path + + +# --------------------------------------------------------------------------- # +# Log file helpers +# --------------------------------------------------------------------------- # + + +def create_log_file( + logs_dir: Path, + service: str, + content: str = "", + size_mb: float = 0, +) -> Path: + """Create a log file. + + If *content* is provided it is written as-is. Otherwise *size_mb* bytes + of filler are written. Returns the log file path. + """ + logs_dir.mkdir(parents=True, exist_ok=True) + log_path = logs_dir / f"{service}.log" + if content: + log_path.write_text(content) + elif size_mb > 0: + log_path.write_bytes(b"x" * int(size_mb * 1024 * 1024)) + else: + log_path.touch() + return log_path + + +def create_archive( + logs_dir: Path, + service: str, + number: int, + content: str, +) -> Path: + """Create a gzip log archive and return its path.""" + logs_dir.mkdir(parents=True, exist_ok=True) + archive_path = logs_dir / f"{service}.log.{number}.gz" + with gzip.open(str(archive_path), "wb") as f: + f.write(content.encode("utf-8")) + return archive_path + + +def read_gz(path: Path) -> bytes: + """Read and decompress a gzip file.""" + with gzip.open(str(path), "rb") as f: + return f.read() diff --git a/tests/fakes.py b/tests/fakes.py new file mode 100644 index 0000000..7315ef0 --- /dev/null +++ b/tests/fakes.py @@ -0,0 +1,190 @@ +"""Fake service layer for behavioral testing of stack orchestration. + +Replaces 10-deep ``@patch`` stacks with a single, configurable test double. +The default behaviour is *happy path* — everything succeeds. Tests configure +only the failures they care about via explicit methods like ``fail_port()``, +``fail_health()``, or ``hold_lock()``. + +Usage in tests:: + + def test_startup(stack_on_disk, fake_services): + # Arrange — default: all services start successfully + # Act + result = run_up() + # Assert + assert all(t.status == "healthy" for t in result.tiers) +""" + +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +from typing import Any + +from mlx_stack.core.process import ( + HealthCheckError, + HealthCheckResult, + LockError, + ServiceInfo, +) +from tests.factories import make_test_catalog + + +class FakeServiceLayer: + """Test double for OS-boundary functions used by ``run_up`` / watchdog. + + Defaults produce a successful startup. Call configuration methods to + inject specific failure scenarios before calling the code under test. + """ + + def __init__(self) -> None: + # --- Configuration (set before act) --- + self._port_conflicts: dict[int, tuple[int, str]] = {} + self._health_failures: set[int] = set() + self._start_failures: set[str] = set() + self._alive_pids: dict[int, bool] = {} + self._lock_held: bool = False + self._dependency_failures: set[str] = set() + self._model_check_failures: dict[str, str] = {} + self._catalog = make_test_catalog() + self._config_overrides: dict[str, Any] = {} + + # --- Recording (assert after act) --- + self.started: list[str] = [] + self.health_checked: list[int] = [] + self.dependencies_checked: list[str] = [] + + # Internal counter for deterministic PIDs + self._next_pid = 10000 + + # ------------------------------------------------------------------ # + # Configuration methods — call these in the Arrange phase + # ------------------------------------------------------------------ # + + def fail_port(self, port: int, pid: int = 99999, name: str = "other") -> None: + """Make ``check_port_conflict`` report *port* as occupied.""" + self._port_conflicts[port] = (pid, name) + + def fail_health(self, port: int) -> None: + """Make ``wait_for_healthy`` raise ``HealthCheckError`` for *port*.""" + self._health_failures.add(port) + + def fail_start(self, service_name: str) -> None: + """Make ``start_service`` raise for *service_name*.""" + self._start_failures.add(service_name) + + def set_alive(self, pid: int, alive: bool = True) -> None: + """Control what ``is_process_alive`` returns for a specific PID.""" + self._alive_pids[pid] = alive + + def hold_lock(self) -> None: + """Make ``acquire_lock`` raise ``LockError``.""" + self._lock_held = True + + def fail_dependency(self, name: str) -> None: + """Make ``ensure_dependency`` raise for *name*.""" + self._dependency_failures.add(name) + + def fail_model_check(self, tier_name: str, message: str) -> None: + """Make ``check_local_model_exists`` return an error for *tier_name*.""" + self._model_check_failures[tier_name] = message + + def set_catalog(self, catalog: list[Any]) -> None: + """Override the catalog returned by ``load_catalog``.""" + self._catalog = catalog + + def set_config(self, key: str, value: Any) -> None: + """Override a config value returned by ``get_value``.""" + self._config_overrides[key] = value + + # ------------------------------------------------------------------ # + # Fake implementations — patched into the module under test + # ------------------------------------------------------------------ # + + def start_service( + self, + service_name: str, + cmd: list[str], + port: int, + env: dict[str, str] | None = None, + log_dir: Path | None = None, + ) -> ServiceInfo: + """Fake ``process.start_service``.""" + self.started.append(service_name) + if service_name in self._start_failures: + msg = f"Fake start failure for {service_name}" + raise RuntimeError(msg) + pid = self._next_pid + port + return ServiceInfo( + name=service_name, + pid=pid, + port=port, + log_path=Path(f"/tmp/{service_name}.log"), + pid_path=Path(f"/tmp/{service_name}.pid"), + ) + + def wait_for_healthy(self, **kwargs: Any) -> HealthCheckResult: + """Fake ``process.wait_for_healthy``.""" + port = kwargs.get("port", 0) + self.health_checked.append(port) + if port in self._health_failures: + msg = f"Timeout after 120s waiting for port {port}" + raise HealthCheckError(msg) + return HealthCheckResult(healthy=True, response_time=0.1, status_code=200) + + def check_port_conflict(self, port: int) -> tuple[int, str] | None: + """Fake ``process.check_port_conflict``.""" + return self._port_conflicts.get(port) + + def is_process_alive(self, pid: int) -> bool: + """Fake ``process.is_process_alive``.""" + return self._alive_pids.get(pid, False) + + def cleanup_stale_pid(self, service_name: str) -> bool: + """Fake ``process.cleanup_stale_pid``.""" + return True + + @contextmanager + def acquire_lock(self): + """Fake ``process.acquire_lock``.""" + if self._lock_held: + raise LockError("Lock held by another process (fake)") + yield + + def ensure_dependency(self, name: str) -> None: + """Fake ``deps.ensure_dependency``.""" + self.dependencies_checked.append(name) + if name in self._dependency_failures: + from mlx_stack.core.deps import DependencyInstallError + + msg = f"Failed to install {name} (fake)" + raise DependencyInstallError(msg) + + def which(self, name: str) -> str: + """Fake ``shutil.which``.""" + return f"/usr/local/bin/{name}" + + def check_local_model_exists(self, tier: dict[str, Any]) -> str | None: + """Fake ``stack_up.check_local_model_exists``.""" + tier_name = tier.get("name", "") + return self._model_check_failures.get(tier_name) + + def load_catalog(self) -> list[Any]: + """Fake ``catalog.load_catalog``.""" + return self._catalog + + def get_value(self, key: str) -> Any: + """Fake ``config.get_value``.""" + if key in self._config_overrides: + return self._config_overrides[key] + defaults: dict[str, Any] = { + "litellm-port": 4000, + "openrouter-key": "", + "default-quant": "int4", + "memory-budget-pct": 40, + "model-dir": "~/.mlx-stack/models", + "auto-health-check": True, + "log-max-size-mb": 50, + "log-max-files": 5, + } + return defaults.get(key, "") diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..fb397a3 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,106 @@ +"""Unit-test-specific fixtures. + +Provides ``fake_services``, ``stack_on_disk``, and directory helpers that +build on the shared ``mlx_stack_home`` fixture from ``tests/conftest.py``. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from tests.factories import write_litellm_yaml, write_stack_yaml +from tests.fakes import FakeServiceLayer + +# --------------------------------------------------------------------------- # +# Directory helpers +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def pids_dir(mlx_stack_home: Path) -> Path: + """Create and return the ``pids/`` subdirectory.""" + d = mlx_stack_home / "pids" + d.mkdir(parents=True, exist_ok=True) + return d + + +@pytest.fixture +def logs_dir(mlx_stack_home: Path) -> Path: + """Create and return the ``logs/`` subdirectory.""" + d = mlx_stack_home / "logs" + d.mkdir(parents=True, exist_ok=True) + return d + + +# --------------------------------------------------------------------------- # +# Stack-on-disk fixture +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def stack_on_disk(mlx_stack_home: Path) -> Path: + """Write the default stack + litellm YAML and return the home path.""" + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + return mlx_stack_home + + +# --------------------------------------------------------------------------- # +# FakeServiceLayer fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def fake_services(monkeypatch: pytest.MonkeyPatch) -> FakeServiceLayer: + """Provide a ``FakeServiceLayer`` patched into ``mlx_stack.core.stack_up``. + + The default configuration produces a fully successful startup. Tests + call configuration methods (e.g. ``fake_services.fail_port(8000)``) + in the *Arrange* phase to inject specific failure scenarios. + """ + layer = FakeServiceLayer() + _patch_stack_up(monkeypatch, layer) + return layer + + +@pytest.fixture +def fake_watchdog_services(monkeypatch: pytest.MonkeyPatch) -> FakeServiceLayer: + """Provide a ``FakeServiceLayer`` patched into ``mlx_stack.core.watchdog``. + + Same as ``fake_services`` but targets the watchdog module's imports. + """ + layer = FakeServiceLayer() + _patch_watchdog(monkeypatch, layer) + return layer + + +# --------------------------------------------------------------------------- # +# Internal patch helpers +# --------------------------------------------------------------------------- # + + +def _patch_stack_up(mp: pytest.MonkeyPatch, layer: FakeServiceLayer) -> None: + """Apply monkeypatches for ``mlx_stack.core.stack_up``.""" + prefix = "mlx_stack.core.stack_up" + mp.setattr(f"{prefix}.start_service", layer.start_service) + mp.setattr(f"{prefix}.wait_for_healthy", layer.wait_for_healthy) + mp.setattr(f"{prefix}.check_port_conflict", layer.check_port_conflict) + mp.setattr(f"{prefix}.is_process_alive", layer.is_process_alive) + mp.setattr(f"{prefix}.cleanup_stale_pid", layer.cleanup_stale_pid) + mp.setattr(f"{prefix}.acquire_lock", layer.acquire_lock) + mp.setattr(f"{prefix}.ensure_dependency", layer.ensure_dependency) + mp.setattr(f"{prefix}.shutil.which", layer.which) + mp.setattr(f"{prefix}.check_local_model_exists", layer.check_local_model_exists) + mp.setattr(f"{prefix}.load_catalog", layer.load_catalog) + mp.setattr(f"{prefix}.get_value", layer.get_value) + + +def _patch_watchdog(mp: pytest.MonkeyPatch, layer: FakeServiceLayer) -> None: + """Apply monkeypatches for ``mlx_stack.core.watchdog``.""" + prefix = "mlx_stack.core.watchdog" + mp.setattr(f"{prefix}.start_service", layer.start_service) + mp.setattr(f"{prefix}.acquire_lock", layer.acquire_lock) + mp.setattr(f"{prefix}.is_process_alive", layer.is_process_alive) + mp.setattr(f"{prefix}.get_value", layer.get_value) diff --git a/tests/unit/test_cli_down.py b/tests/unit/test_cli_down.py index e98cc02..611ff7a 100644 --- a/tests/unit/test_cli_down.py +++ b/tests/unit/test_cli_down.py @@ -14,11 +14,9 @@ from __future__ import annotations from pathlib import Path -from typing import Any from unittest.mock import MagicMock, patch import pytest -import yaml from click.testing import CliRunner from mlx_stack.cli.main import cli @@ -36,78 +34,7 @@ _stop_single_service, run_down, ) - -# --------------------------------------------------------------------------- # -# Fixtures — reusable test data -# --------------------------------------------------------------------------- # - - -def _make_stack_yaml( - tiers: list[dict[str, Any]] | None = None, - schema_version: int = 1, -) -> dict[str, Any]: - """Create a stack definition dict for testing.""" - if tiers is None: - tiers = [ - { - "name": "standard", - "model": "big-model", - "quant": "int4", - "source": "mlx-community/big-model-4bit", - "port": 8000, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - { - "name": "fast", - "model": "fast-model", - "quant": "int4", - "source": "mlx-community/fast-model-4bit", - "port": 8001, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - ] - return { - "schema_version": schema_version, - "name": "default", - "hardware_profile": "m4-max-128", - "intent": "balanced", - "created": "2026-03-24T00:00:00+00:00", - "tiers": tiers, - } - - -def _write_stack_yaml( - mlx_stack_home: Path, - stack: dict[str, Any] | None = None, -) -> Path: - """Write a stack YAML file and return its path.""" - if stack is None: - stack = _make_stack_yaml() - stacks_dir = mlx_stack_home / "stacks" - stacks_dir.mkdir(parents=True, exist_ok=True) - stack_path = stacks_dir / "default.yaml" - stack_path.write_text(yaml.dump(stack, default_flow_style=False)) - return stack_path - - -def _create_pid_file( - mlx_stack_home: Path, - service_name: str, - pid: int | str = 12345, -) -> Path: - """Create a PID file in the pids directory.""" - pids_dir = mlx_stack_home / "pids" - pids_dir.mkdir(parents=True, exist_ok=True) - pid_path = pids_dir / f"{service_name}.pid" - pid_path.write_text(str(pid)) - return pid_path - +from tests.factories import create_pid_file, make_stack_yaml, write_stack_yaml # --------------------------------------------------------------------------- # # Tests — _get_tier_names_from_stack @@ -119,12 +46,15 @@ class TestGetTierNamesFromStack: def test_returns_tier_names_from_stack(self, mlx_stack_home: Path) -> None: """Returns tier names from a valid stack definition.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) + # Act with patch("mlx_stack.core.stack_down.load_catalog") as mock_catalog: mock_catalog.side_effect = Exception("no catalog") names = _get_tier_names_from_stack() + # Assert assert set(names) == {"standard", "fast"} def test_returns_empty_on_missing_stack(self, mlx_stack_home: Path) -> None: @@ -138,7 +68,7 @@ class TestGetValidTierNames: def test_returns_valid_tier_names(self, mlx_stack_home: Path) -> None: """Returns tier names from the stack definition.""" - _write_stack_yaml(mlx_stack_home) + write_stack_yaml(mlx_stack_home) names = _get_valid_tier_names() assert set(names) == {"standard", "fast"} @@ -158,8 +88,10 @@ class TestStopSingleService: def test_stops_running_process_gracefully(self, mlx_stack_home: Path) -> None: """VAL-DOWN-002: Stops a running process gracefully.""" - _create_pid_file(mlx_stack_home, "fast", 12345) + # Arrange + create_pid_file(mlx_stack_home, "fast", 12345) + # Act with ( patch("mlx_stack.core.stack_down.read_pid_file", return_value=12345), patch("mlx_stack.core.stack_down.is_process_alive", return_value=True), @@ -170,6 +102,7 @@ def test_stops_running_process_gracefully(self, mlx_stack_home: Path) -> None: ): result = _stop_single_service("fast") + # Assert assert result.name == "fast" assert result.status == "stopped" assert result.graceful is True @@ -177,8 +110,10 @@ def test_stops_running_process_gracefully(self, mlx_stack_home: Path) -> None: def test_stops_running_process_forced(self, mlx_stack_home: Path) -> None: """VAL-DOWN-002: Reports forced shutdown when SIGKILL needed.""" - _create_pid_file(mlx_stack_home, "fast", 99999) + # Arrange + create_pid_file(mlx_stack_home, "fast", 99999) + # Act with ( patch("mlx_stack.core.stack_down.read_pid_file", return_value=99999), patch("mlx_stack.core.stack_down.is_process_alive", return_value=True), @@ -189,13 +124,16 @@ def test_stops_running_process_forced(self, mlx_stack_home: Path) -> None: ): result = _stop_single_service("fast") + # Assert assert result.status == "stopped" assert result.graceful is False def test_handles_stale_pid(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Detects stale PID and cleans up.""" - _create_pid_file(mlx_stack_home, "fast", 12345) + # Arrange + create_pid_file(mlx_stack_home, "fast", 12345) + # Act with ( patch("mlx_stack.core.stack_down.read_pid_file", return_value=12345), patch("mlx_stack.core.stack_down.is_process_alive", return_value=False), @@ -203,6 +141,7 @@ def test_handles_stale_pid(self, mlx_stack_home: Path) -> None: ): result = _stop_single_service("fast") + # Assert assert result.status == "stale" assert result.pid == 12345 assert "already dead" in (result.error or "") @@ -210,8 +149,10 @@ def test_handles_stale_pid(self, mlx_stack_home: Path) -> None: def test_handles_corrupt_pid(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Handles corrupt PID file gracefully.""" - _create_pid_file(mlx_stack_home, "fast", "not-a-number") + # Arrange + create_pid_file(mlx_stack_home, "fast", "not-a-number") + # Act with ( patch( "mlx_stack.core.stack_down.read_pid_file", @@ -221,6 +162,7 @@ def test_handles_corrupt_pid(self, mlx_stack_home: Path) -> None: ): result = _stop_single_service("fast") + # Assert assert result.status == "corrupt" assert result.pid is None assert "non-numeric" in (result.error or "") @@ -252,10 +194,11 @@ def test_nothing_to_stop_no_pid_files(self, mlx_stack_home: Path) -> None: def test_full_shutdown_correct_order(self, mlx_stack_home: Path) -> None: """VAL-DOWN-001: LiteLLM stopped first, then servers in reverse order.""" - _write_stack_yaml(mlx_stack_home) - _create_pid_file(mlx_stack_home, "standard", 1001) - _create_pid_file(mlx_stack_home, "fast", 1002) - _create_pid_file(mlx_stack_home, "litellm", 1003) + # Arrange + write_stack_yaml(mlx_stack_home) + create_pid_file(mlx_stack_home, "standard", 1001) + create_pid_file(mlx_stack_home, "fast", 1002) + create_pid_file(mlx_stack_home, "litellm", 1003) shutdown_order: list[str] = [] @@ -268,6 +211,7 @@ def mock_stop(service_name: str) -> ServiceStopResult: graceful=True, ) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_down._stop_single_service", side_effect=mock_stop), @@ -278,6 +222,7 @@ def mock_stop(service_name: str) -> ServiceStopResult: result = run_down() + # Assert assert result.nothing_to_stop is False assert len(result.services) == 3 @@ -288,8 +233,10 @@ def mock_stop(service_name: str) -> ServiceStopResult: def test_lockfile_acquired_during_shutdown(self, mlx_stack_home: Path) -> None: """VAL-DOWN-003: Lockfile acquired during down.""" - _create_pid_file(mlx_stack_home, "litellm", 1001) + # Arrange + create_pid_file(mlx_stack_home, "litellm", 1001) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch( @@ -307,12 +254,15 @@ def test_lockfile_acquired_during_shutdown(self, mlx_stack_home: Path) -> None: run_down() + # Assert mock_lock.assert_called_once() def test_lockfile_error_propagated(self, mlx_stack_home: Path) -> None: """VAL-DOWN-003: Lock conflict raises LockError.""" - _create_pid_file(mlx_stack_home, "fast", 1001) + # Arrange + create_pid_file(mlx_stack_home, "fast", 1001) + # Act / Assert with ( patch( "mlx_stack.core.stack_down.acquire_lock", @@ -324,10 +274,11 @@ def test_lockfile_error_propagated(self, mlx_stack_home: Path) -> None: def test_tier_filter_stops_only_specified_tier(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: --tier stops only the specified tier.""" - _write_stack_yaml(mlx_stack_home) - _create_pid_file(mlx_stack_home, "standard", 1001) - _create_pid_file(mlx_stack_home, "fast", 1002) - _create_pid_file(mlx_stack_home, "litellm", 1003) + # Arrange + write_stack_yaml(mlx_stack_home) + create_pid_file(mlx_stack_home, "standard", 1001) + create_pid_file(mlx_stack_home, "fast", 1002) + create_pid_file(mlx_stack_home, "litellm", 1003) stopped_services: list[str] = [] @@ -340,6 +291,7 @@ def mock_stop(service_name: str) -> ServiceStopResult: graceful=True, ) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_down._stop_single_service", side_effect=mock_stop), @@ -349,35 +301,40 @@ def mock_stop(service_name: str) -> ServiceStopResult: result = run_down(tier_filter="fast") - # Only 'fast' should be stopped + # Assert — only 'fast' should be stopped assert len(result.services) == 1 assert result.services[0].name == "fast" assert stopped_services == ["fast"] def test_tier_filter_invalid_tier_raises_error(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: Invalid tier name errors with valid tier list.""" - _write_stack_yaml(mlx_stack_home) + write_stack_yaml(mlx_stack_home) with pytest.raises(DownError, match="Unknown tier 'nonexistent'"): run_down(tier_filter="nonexistent") def test_tier_filter_invalid_tier_shows_valid_list(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: Error message includes valid tier names.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) + # Act with pytest.raises(DownError) as exc_info: run_down(tier_filter="nonexistent") + # Assert error_msg = str(exc_info.value) assert "fast" in error_msg assert "standard" in error_msg def test_tier_filter_not_running(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: --tier for a valid but not-running tier.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) # Create a PID file for a different tier to avoid "nothing to stop" - _create_pid_file(mlx_stack_home, "standard", 1001) + create_pid_file(mlx_stack_home, "standard", 1001) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, ): @@ -386,14 +343,17 @@ def test_tier_filter_not_running(self, mlx_stack_home: Path) -> None: result = run_down(tier_filter="fast") + # Assert assert len(result.services) == 1 assert result.services[0].name == "fast" assert result.services[0].status == "not-running" def test_stale_pid_detected_and_cleaned(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Stale PIDs detected, reported, and cleaned up.""" - _create_pid_file(mlx_stack_home, "fast", 12345) + # Arrange + create_pid_file(mlx_stack_home, "fast", 12345) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch( @@ -411,14 +371,17 @@ def test_stale_pid_detected_and_cleaned(self, mlx_stack_home: Path) -> None: result = run_down() + # Assert assert len(result.services) == 1 assert result.services[0].status == "stale" assert "already dead" in (result.services[0].error or "") def test_corrupt_pid_detected_and_cleaned(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Corrupt PID files reported and removed.""" - _create_pid_file(mlx_stack_home, "fast", "garbage") + # Arrange + create_pid_file(mlx_stack_home, "fast", "garbage") + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch( @@ -436,15 +399,17 @@ def test_corrupt_pid_detected_and_cleaned(self, mlx_stack_home: Path) -> None: result = run_down() + # Assert assert len(result.services) == 1 assert result.services[0].status == "corrupt" def test_mixed_stale_and_running_services(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Mix of stale, corrupt, and running services processed.""" - _write_stack_yaml(mlx_stack_home) - _create_pid_file(mlx_stack_home, "standard", 1001) - _create_pid_file(mlx_stack_home, "fast", "bad") - _create_pid_file(mlx_stack_home, "litellm", 1003) + # Arrange + write_stack_yaml(mlx_stack_home) + create_pid_file(mlx_stack_home, "standard", 1001) + create_pid_file(mlx_stack_home, "fast", "bad") + create_pid_file(mlx_stack_home, "litellm", 1003) results_map = { "litellm": ServiceStopResult( @@ -470,6 +435,7 @@ def test_mixed_stale_and_running_services(self, mlx_stack_home: Path) -> None: def mock_stop(service_name: str) -> ServiceStopResult: return results_map[service_name] + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_down._stop_single_service", side_effect=mock_stop), @@ -480,6 +446,7 @@ def mock_stop(service_name: str) -> ServiceStopResult: result = run_down() + # Assert assert len(result.services) == 3 statuses = {s.name: s.status for s in result.services} assert statuses["litellm"] == "stopped" @@ -488,9 +455,11 @@ def mock_stop(service_name: str) -> ServiceStopResult: def test_orphaned_services_cleaned_up(self, mlx_stack_home: Path) -> None: """Services not in stack definition are still cleaned up.""" - _write_stack_yaml(mlx_stack_home) - _create_pid_file(mlx_stack_home, "orphaned-service", 9999) + # Arrange + write_stack_yaml(mlx_stack_home) + create_pid_file(mlx_stack_home, "orphaned-service", 9999) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch( @@ -509,16 +478,19 @@ def test_orphaned_services_cleaned_up(self, mlx_stack_home: Path) -> None: result = run_down() + # Assert assert len(result.services) == 1 assert result.services[0].name == "orphaned-service" assert result.services[0].status == "stopped" def test_pid_files_cleaned_after_shutdown(self, mlx_stack_home: Path) -> None: """VAL-DOWN-003 / VAL-CROSS-001: PID files deleted after termination.""" + # Arrange pids_dir = mlx_stack_home / "pids" pids_dir.mkdir(parents=True, exist_ok=True) (pids_dir / "fast.pid").write_text("12345") + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.process.psutil") as mock_psutil, @@ -536,13 +508,15 @@ def test_pid_files_cleaned_after_shutdown(self, mlx_stack_home: Path) -> None: run_down() - # PID file should be removed + # Assert assert not (pids_dir / "fast.pid").exists() def test_no_stack_definition_falls_back(self, mlx_stack_home: Path) -> None: """When no stack exists, PID files are still processed.""" - _create_pid_file(mlx_stack_home, "fast", 12345) + # Arrange + create_pid_file(mlx_stack_home, "fast", 12345) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch( @@ -560,6 +534,7 @@ def test_no_stack_definition_falls_back(self, mlx_stack_home: Path) -> None: result = run_down() + # Assert assert len(result.services) == 1 assert result.services[0].status == "stopped" @@ -591,6 +566,7 @@ def test_nothing_to_stop_message(self, mlx_stack_home: Path) -> None: def test_shutdown_displays_summary_table(self, mlx_stack_home: Path) -> None: """VAL-DOWN-001: Displays shutdown summary.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -614,10 +590,12 @@ def test_shutdown_displays_summary_table(self, mlx_stack_home: Path) -> None: ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result): runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 0 assert "Shutdown Summary" in result.output assert "litellm" in result.output @@ -626,6 +604,7 @@ def test_shutdown_displays_summary_table(self, mlx_stack_home: Path) -> None: def test_shutdown_shows_graceful_method(self, mlx_stack_home: Path) -> None: """VAL-DOWN-002: Reports graceful vs forced per service.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -643,10 +622,12 @@ def test_shutdown_shows_graceful_method(self, mlx_stack_home: Path) -> None: ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result): runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 0 assert "graceful" in result.output assert "SIGTERM" in result.output @@ -654,6 +635,7 @@ def test_shutdown_shows_graceful_method(self, mlx_stack_home: Path) -> None: def test_forced_sigkill_explicit_in_output(self, mlx_stack_home: Path) -> None: """VAL-DOWN-002: SIGKILL escalation explicitly reported per service.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -665,10 +647,12 @@ def test_forced_sigkill_explicit_in_output(self, mlx_stack_home: Path) -> None: ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result): runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 0 # Must explicitly say "forced (SIGKILL)" for force-killed service assert "forced (SIGKILL)" in result.output @@ -677,6 +661,7 @@ def test_forced_sigkill_explicit_in_output(self, mlx_stack_home: Path) -> None: def test_graceful_sigterm_explicit_in_output(self, mlx_stack_home: Path) -> None: """VAL-DOWN-002: Graceful shutdown explicitly reports SIGTERM.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -688,15 +673,18 @@ def test_graceful_sigterm_explicit_in_output(self, mlx_stack_home: Path) -> None ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result): runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 0 assert "graceful (SIGTERM)" in result.output def test_tier_filter_option(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: --tier option passed to run_down.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -708,10 +696,12 @@ def test_tier_filter_option(self, mlx_stack_home: Path) -> None: ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result) as mock_run: runner = CliRunner() result = runner.invoke(cli, ["down", "--tier", "fast"]) + # Assert assert result.exit_code == 0 mock_run.assert_called_once_with(tier_filter="fast") @@ -741,6 +731,7 @@ def test_lock_error_shows_message(self, mlx_stack_home: Path) -> None: def test_stale_pid_shown_in_output(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Stale PID reported in output.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -752,15 +743,18 @@ def test_stale_pid_shown_in_output(self, mlx_stack_home: Path) -> None: ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result): runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 0 assert "stale" in result.output.lower() def test_corrupt_pid_shown_in_output(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Corrupt PID reported in output.""" + # Arrange mock_result = DownResult( services=[ ServiceStopResult( @@ -772,10 +766,12 @@ def test_corrupt_pid_shown_in_output(self, mlx_stack_home: Path) -> None: ], ) + # Act with patch("mlx_stack.cli.down.run_down", return_value=mock_result): runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 0 assert "corrupt" in result.output.lower() @@ -798,6 +794,7 @@ def test_down_appears_in_main_help(self, mlx_stack_home: Path) -> None: def test_no_traceback_on_error(self, mlx_stack_home: Path) -> None: """Errors never show Python tracebacks.""" + # Act with patch( "mlx_stack.cli.down.run_down", side_effect=DownError("Something went wrong"), @@ -805,6 +802,7 @@ def test_no_traceback_on_error(self, mlx_stack_home: Path) -> None: runner = CliRunner() result = runner.invoke(cli, ["down"]) + # Assert assert result.exit_code == 1 assert "Traceback" not in result.output assert "Error:" in result.output @@ -820,13 +818,15 @@ class TestDownEndToEnd: def test_full_lifecycle_cleanup(self, mlx_stack_home: Path) -> None: """VAL-CROSS-001: After down, zero PID files remain.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) pids_dir = mlx_stack_home / "pids" pids_dir.mkdir(parents=True, exist_ok=True) (pids_dir / "standard.pid").write_text("12345") (pids_dir / "fast.pid").write_text("12346") (pids_dir / "litellm.pid").write_text("12347") + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.process.psutil") as mock_psutil, @@ -855,19 +855,21 @@ def test_full_lifecycle_cleanup(self, mlx_stack_home: Path) -> None: ): run_down() - # All PID files should be cleaned up + # Assert remaining_pids = list(pids_dir.glob("*.pid")) assert remaining_pids == [], f"PID files remain: {remaining_pids}" def test_selective_tier_leaves_others(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: --tier leaves other services' PID files intact.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) pids_dir = mlx_stack_home / "pids" pids_dir.mkdir(parents=True, exist_ok=True) (pids_dir / "standard.pid").write_text("1001") (pids_dir / "fast.pid").write_text("1002") (pids_dir / "litellm.pid").write_text("1003") + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.process.psutil") as mock_psutil, @@ -884,7 +886,7 @@ def test_selective_tier_leaves_others(self, mlx_stack_home: Path) -> None: run_down(tier_filter="fast") - # Only fast PID should be removed + # Assert — only fast PID should be removed assert not (pids_dir / "fast.pid").exists() # Others should remain assert (pids_dir / "standard.pid").exists() @@ -892,11 +894,13 @@ def test_selective_tier_leaves_others(self, mlx_stack_home: Path) -> None: def test_corrupt_and_stale_mixed(self, mlx_stack_home: Path) -> None: """VAL-DOWN-005: Mixed corrupt + stale PID files all cleaned.""" + # Arrange pids_dir = mlx_stack_home / "pids" pids_dir.mkdir(parents=True, exist_ok=True) (pids_dir / "fast.pid").write_text("not_a_pid") (pids_dir / "standard.pid").write_text("99999") + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.process.psutil") as mock_psutil, @@ -911,7 +915,7 @@ def test_corrupt_and_stale_mixed(self, mlx_stack_home: Path) -> None: result = run_down() - # Both PID files should be cleaned + # Assert remaining = list(pids_dir.glob("*.pid")) assert remaining == [], f"PID files remain: {remaining}" @@ -930,10 +934,11 @@ class TestShutdownOrder: def test_litellm_stopped_before_model_servers(self, mlx_stack_home: Path) -> None: """VAL-DOWN-001: LiteLLM stopped first.""" - _write_stack_yaml(mlx_stack_home) - _create_pid_file(mlx_stack_home, "standard", 1001) - _create_pid_file(mlx_stack_home, "fast", 1002) - _create_pid_file(mlx_stack_home, "litellm", 1003) + # Arrange + write_stack_yaml(mlx_stack_home) + create_pid_file(mlx_stack_home, "standard", 1001) + create_pid_file(mlx_stack_home, "fast", 1002) + create_pid_file(mlx_stack_home, "litellm", 1003) order: list[str] = [] @@ -946,6 +951,7 @@ def mock_stop(name: str) -> ServiceStopResult: graceful=True, ) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_down._stop_single_service", side_effect=mock_stop), @@ -956,11 +962,12 @@ def mock_stop(name: str) -> ServiceStopResult: run_down() + # Assert assert order[0] == "litellm" def test_model_servers_reversed(self, mlx_stack_home: Path) -> None: """VAL-DOWN-001: Model servers stopped in reverse startup order.""" - # Stack with 3 tiers of different sizes + # Arrange — stack with 3 tiers of different sizes tiers = [ { "name": "standard", @@ -987,11 +994,11 @@ def test_model_servers_reversed(self, mlx_stack_home: Path) -> None: "vllm_flags": {}, }, ] - stack = _make_stack_yaml(tiers=tiers) - _write_stack_yaml(mlx_stack_home, stack) + stack = make_stack_yaml(tiers=tiers) + write_stack_yaml(mlx_stack_home, stack) for tier in tiers: - _create_pid_file(mlx_stack_home, tier["name"], 1000) + create_pid_file(mlx_stack_home, tier["name"], 1000) order: list[str] = [] @@ -1004,6 +1011,7 @@ def mock_stop(name: str) -> ServiceStopResult: graceful=True, ) + # Act with ( patch("mlx_stack.core.stack_down.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_down._stop_single_service", side_effect=mock_stop), @@ -1014,7 +1022,7 @@ def mock_stop(name: str) -> ServiceStopResult: run_down() - # With no catalog, tiers are in stack definition order + # Assert — with no catalog, tiers are in stack definition order # Reversed means last-defined tier stops first (among model servers) model_servers = [s for s in order if s != "litellm"] # Reverse of the startup order; startup order = stack definition order diff --git a/tests/unit/test_cli_init.py b/tests/unit/test_cli_init.py index ce9b69f..9c1b566 100644 --- a/tests/unit/test_cli_init.py +++ b/tests/unit/test_cli_init.py @@ -32,13 +32,7 @@ from click.testing import CliRunner from mlx_stack.cli.main import cli -from mlx_stack.core.catalog import ( - BenchmarkResult, - Capabilities, - CatalogEntry, - QualityScores, - QuantSource, -) +from mlx_stack.core.catalog import BenchmarkResult, CatalogEntry from mlx_stack.core.hardware import HardwareProfile from mlx_stack.core.stack_init import ( InitError, @@ -47,99 +41,23 @@ detect_missing_models, run_init, ) +from tests.factories import make_entry, make_profile # --------------------------------------------------------------------------- # -# Fixtures — reusable test data +# Helpers — test-specific data builders (shared factories in tests.factories) # --------------------------------------------------------------------------- # -def _make_profile( - chip: str = "Apple M4 Max", - gpu_cores: int = 40, - memory_gb: int = 128, - bandwidth_gbps: float = 546.0, - is_estimate: bool = False, -) -> HardwareProfile: - """Create a HardwareProfile for testing.""" - return HardwareProfile( - chip=chip, - gpu_cores=gpu_cores, - memory_gb=memory_gb, - bandwidth_gbps=bandwidth_gbps, - is_estimate=is_estimate, - ) - - -def _make_small_profile() -> HardwareProfile: - """Create a small-memory profile (32 GB).""" - return HardwareProfile( - chip="Apple M4 Pro", - gpu_cores=20, - memory_gb=32, - bandwidth_gbps=273.0, - is_estimate=False, - ) - - -def _make_entry( - model_id: str = "test-model", - name: str = "Test Model", - family: str = "Test", - params_b: float = 8.0, - architecture: str = "transformer", - quality_overall: int = 70, - quality_coding: int = 65, - quality_reasoning: int = 60, - quality_instruction: int = 72, - tool_calling: bool = True, - tool_call_parser: str | None = "hermes", - thinking: bool = False, - reasoning_parser: str | None = None, - benchmarks: dict[str, BenchmarkResult] | None = None, - tags: list[str] | None = None, - memory_gb: float = 5.5, - gated: bool = False, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - if benchmarks is None: - benchmarks = { - "m4-pro-32": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=memory_gb), - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=memory_gb), - } - return CatalogEntry( - id=model_id, - name=name, - family=family, - params_b=params_b, - architecture=architecture, - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource(hf_repo=f"mlx-community/{model_id}-4bit", disk_size_gb=4.5), - }, - capabilities=Capabilities( - tool_calling=tool_calling, - tool_call_parser=tool_call_parser if tool_calling else None, - thinking=thinking, - reasoning_parser=reasoning_parser if thinking else None, - vision=False, - ), - quality=QualityScores( - overall=quality_overall, - coding=quality_coding, - reasoning=quality_reasoning, - instruction_following=quality_instruction, - ), - benchmarks=benchmarks, - tags=tags or [], - gated=gated, - ) - - def _make_test_catalog() -> list[CatalogEntry]: - """Create a test catalog with several models.""" + """Create a four-model test catalog using the shared ``make_entry`` factory. + + The catalog composition matters for test correctness: big-model and + fast-model are standard/fast tier candidates, longctx-model exercises + architecture variety, and medium-model is used by --add/--remove tests. + """ return [ # High quality, slow — standard tier candidate - _make_entry( + make_entry( model_id="big-model", name="Big Model 49B", params_b=49.0, @@ -157,7 +75,7 @@ def _make_test_catalog() -> list[CatalogEntry]: memory_gb=30.0, ), # Fast, small — fast tier candidate - _make_entry( + make_entry( model_id="fast-model", name="Fast Model 3B", params_b=3.0, @@ -173,7 +91,7 @@ def _make_test_catalog() -> list[CatalogEntry]: memory_gb=2.0, ), # Long context architecture — longctx tier candidate - _make_entry( + make_entry( model_id="longctx-model", name="LongCtx Model 32B", params_b=32.0, @@ -191,7 +109,7 @@ def _make_test_catalog() -> list[CatalogEntry]: memory_gb=20.0, ), # Medium model for add/remove tests - _make_entry( + make_entry( model_id="medium-model", name="Medium Model 8B", params_b=8.0, @@ -259,50 +177,80 @@ class TestVLLMFlags: def test_base_flags_always_present(self) -> None: """continuous_batching and use_paged_cache always present.""" - entry = _make_entry(tool_calling=False, thinking=False) + # Arrange + entry = make_entry(tool_calling=False, thinking=False) + + # Act flags = build_vllm_flags(entry) + + # Assert assert flags["continuous_batching"] is True assert flags["use_paged_cache"] is True def test_tool_calling_flags(self) -> None: """Tool-calling models get enable_auto_tool_choice and tool_call_parser.""" - entry = _make_entry(tool_calling=True, tool_call_parser="hermes") + # Arrange + entry = make_entry(tool_calling=True, tool_call_parser="hermes") + + # Act flags = build_vllm_flags(entry) + + # Assert assert flags["enable_auto_tool_choice"] is True assert flags["tool_call_parser"] == "hermes" def test_no_tool_calling_flags_without_capability(self) -> None: """Non-tool-calling models don't get tool-calling flags.""" - entry = _make_entry(tool_calling=False) + # Arrange + entry = make_entry(tool_calling=False) + + # Act flags = build_vllm_flags(entry) + + # Assert assert "enable_auto_tool_choice" not in flags assert "tool_call_parser" not in flags def test_thinking_model_gets_reasoning_parser(self) -> None: """Thinking-capable models get reasoning_parser.""" - entry = _make_entry( + # Arrange + entry = make_entry( tool_calling=False, thinking=True, reasoning_parser="deepseek_r1", ) + + # Act flags = build_vllm_flags(entry) + + # Assert assert flags["reasoning_parser"] == "deepseek_r1" def test_no_reasoning_parser_without_thinking(self) -> None: """Non-thinking models don't get reasoning_parser.""" - entry = _make_entry(thinking=False) + # Arrange + entry = make_entry(thinking=False) + + # Act flags = build_vllm_flags(entry) + + # Assert assert "reasoning_parser" not in flags def test_combined_tool_and_thinking_flags(self) -> None: """Model with both tool-calling and thinking gets all flags.""" - entry = _make_entry( + # Arrange + entry = make_entry( tool_calling=True, tool_call_parser="hermes", thinking=True, reasoning_parser="nemotron", ) + + # Act flags = build_vllm_flags(entry) + + # Assert assert flags["continuous_batching"] is True assert flags["use_paged_cache"] is True assert flags["enable_auto_tool_choice"] is True @@ -320,89 +268,106 @@ class TestStackDefinitionGeneration: def test_schema_version_is_1(self, mlx_stack_home: Path) -> None: """Stack definition has schema_version: 1.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert assert result["stack"]["schema_version"] == 1 def test_hardware_profile_matches(self, mlx_stack_home: Path) -> None: """hardware_profile matches the detected profile ID.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert assert result["stack"]["hardware_profile"] == profile.profile_id def test_intent_matches(self, mlx_stack_home: Path) -> None: """intent field matches the selected intent.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="agent-fleet", force=True) + # Assert assert result["stack"]["intent"] == "agent-fleet" def test_name_is_default(self, mlx_stack_home: Path) -> None: """Stack name is 'default'.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert assert result["stack"]["name"] == "default" def test_created_timestamp_is_iso8601(self, mlx_stack_home: Path) -> None: """created field is a valid ISO 8601 timestamp.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert created = result["stack"]["created"] - # Should parse as ISO 8601 dt = datetime.fromisoformat(created) assert dt is not None def test_tiers_have_required_fields(self, mlx_stack_home: Path) -> None: """Each tier has name, model, quant, source, port, vllm_flags.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert for tier in result["stack"]["tiers"]: assert "name" in tier assert "model" in tier @@ -413,7 +378,7 @@ def test_tiers_have_required_fields(self, mlx_stack_home: Path) -> None: def test_tier_ports_are_unique(self, mlx_stack_home: Path) -> None: """All tier ports are unique.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -428,7 +393,7 @@ def test_tier_ports_are_unique(self, mlx_stack_home: Path) -> None: def test_tier_ports_dont_conflict_with_litellm(self, mlx_stack_home: Path) -> None: """No tier port equals the LiteLLM port (default 4000).""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -452,7 +417,7 @@ class TestFileGeneration: def test_stack_yaml_written(self, mlx_stack_home: Path) -> None: """Stack YAML file is written to the correct path.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -471,7 +436,7 @@ def test_stack_yaml_written(self, mlx_stack_home: Path) -> None: def test_litellm_yaml_written(self, mlx_stack_home: Path) -> None: """LiteLLM YAML file is written to the correct path.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -489,7 +454,7 @@ def test_litellm_yaml_written(self, mlx_stack_home: Path) -> None: def test_directory_auto_created(self, clean_mlx_stack_home: Path) -> None: """VAL-INIT-011: Directories are auto-created if missing.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() with ( @@ -512,7 +477,7 @@ class TestLiteLLMConfigContent: def test_model_list_has_correct_count(self, mlx_stack_home: Path) -> None: """model_list has one entry per local tier.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -528,7 +493,7 @@ def test_model_list_has_correct_count(self, mlx_stack_home: Path) -> None: def test_api_base_matches_tier_port(self, mlx_stack_home: Path) -> None: """api_base URLs match tier ports.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -550,7 +515,7 @@ def test_api_base_matches_tier_port(self, mlx_stack_home: Path) -> None: def test_model_uses_openai_prefix(self, mlx_stack_home: Path) -> None: """Model identifiers use openai/ prefix.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -565,7 +530,7 @@ def test_model_uses_openai_prefix(self, mlx_stack_home: Path) -> None: def test_api_key_is_dummy(self, mlx_stack_home: Path) -> None: """api_key is 'dummy' for local models.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -580,7 +545,7 @@ def test_api_key_is_dummy(self, mlx_stack_home: Path) -> None: def test_router_settings_present(self, mlx_stack_home: Path) -> None: """VAL-INIT-013: router_settings present with correct values.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -597,7 +562,7 @@ def test_router_settings_present(self, mlx_stack_home: Path) -> None: def test_fallback_chain_present(self, mlx_stack_home: Path) -> None: """VAL-INIT-006: Fallback chain references valid tier names.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -627,7 +592,7 @@ class TestCloudFallback: def test_cloud_fallback_with_key(self, mlx_stack_home: Path) -> None: """VAL-INIT-007: Cloud fallback added with OpenRouter key.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -663,7 +628,7 @@ def config_side_effect(key: str): def test_no_cloud_without_key(self, mlx_stack_home: Path) -> None: """No cloud fallback when OpenRouter key is empty.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -691,18 +656,17 @@ class TestOverwriteProtection: def test_overwrite_blocked_without_force(self, mlx_stack_home: Path) -> None: """VAL-INIT-009: Existing stack requires --force to overwrite.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - - # Create initial stack with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): run_init(intent="balanced", force=True) - # Try again without force + # Act / Assert with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), @@ -712,23 +676,24 @@ def test_overwrite_blocked_without_force(self, mlx_stack_home: Path) -> None: def test_overwrite_allowed_with_force(self, mlx_stack_home: Path) -> None: """--force allows overwriting existing stack.""" - profile = _make_profile() + # Arrange + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): run_init(intent="balanced", force=True) - # Overwrite with force + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert assert result["stack"]["schema_version"] == 1 @@ -742,7 +707,7 @@ class TestAddRemove: def test_remove_tier(self, mlx_stack_home: Path) -> None: """VAL-INIT-008: --remove excludes a tier.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -761,7 +726,7 @@ def test_remove_tier(self, mlx_stack_home: Path) -> None: def test_remove_invalid_tier_errors(self, mlx_stack_home: Path) -> None: """Removing a non-existent tier raises an error.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -778,7 +743,7 @@ def test_remove_invalid_tier_errors(self, mlx_stack_home: Path) -> None: def test_add_model(self, mlx_stack_home: Path) -> None: """VAL-INIT-008: --add includes an additional model.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -797,7 +762,7 @@ def test_add_model(self, mlx_stack_home: Path) -> None: def test_add_unknown_model_errors(self, mlx_stack_home: Path) -> None: """Adding an unknown model raises an error.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -814,7 +779,7 @@ def test_add_unknown_model_errors(self, mlx_stack_home: Path) -> None: def test_invalid_intent_errors(self, mlx_stack_home: Path) -> None: """Invalid intent raises an error.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -896,7 +861,7 @@ class TestCLIInit: def test_accept_defaults_completes(self, mlx_stack_home: Path) -> None: """VAL-INIT-001: --accept-defaults completes without prompts.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -911,7 +876,7 @@ def test_accept_defaults_completes(self, mlx_stack_home: Path) -> None: def test_accept_defaults_with_intent(self, mlx_stack_home: Path) -> None: """--accept-defaults combined with --intent works.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -926,7 +891,7 @@ def test_accept_defaults_with_intent(self, mlx_stack_home: Path) -> None: def test_overwrite_without_force_exits_error(self, mlx_stack_home: Path) -> None: """VAL-INIT-009: Without --force, existing stack causes error exit.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -946,7 +911,7 @@ def test_overwrite_without_force_exits_error(self, mlx_stack_home: Path) -> None def test_force_allows_overwrite(self, mlx_stack_home: Path) -> None: """--force allows overwriting existing stack.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -963,7 +928,7 @@ def test_force_allows_overwrite(self, mlx_stack_home: Path) -> None: def test_output_shows_file_paths(self, mlx_stack_home: Path) -> None: """VAL-INIT-012: Output shows file paths.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -979,7 +944,7 @@ def test_output_shows_file_paths(self, mlx_stack_home: Path) -> None: def test_output_shows_tier_assignments(self, mlx_stack_home: Path) -> None: """VAL-INIT-012: Output shows tier assignments.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -994,7 +959,7 @@ def test_output_shows_tier_assignments(self, mlx_stack_home: Path) -> None: def test_output_shows_next_steps(self, mlx_stack_home: Path) -> None: """VAL-INIT-012: Output shows next-step instructions.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1010,7 +975,7 @@ def test_output_shows_next_steps(self, mlx_stack_home: Path) -> None: def test_missing_models_shows_pull_suggestion(self, mlx_stack_home: Path) -> None: """VAL-INIT-010: Missing models show pull suggestion.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1026,7 +991,7 @@ def test_missing_models_shows_pull_suggestion(self, mlx_stack_home: Path) -> Non def test_generated_stack_yaml_is_valid(self, mlx_stack_home: Path) -> None: """VAL-INIT-002: Stack YAML is valid and parseable.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1050,7 +1015,7 @@ def test_generated_stack_yaml_is_valid(self, mlx_stack_home: Path) -> None: def test_generated_litellm_yaml_is_valid(self, mlx_stack_home: Path) -> None: """VAL-INIT-005: LiteLLM YAML is valid and parseable.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1071,7 +1036,7 @@ def test_generated_litellm_yaml_is_valid(self, mlx_stack_home: Path) -> None: def test_add_option_works(self, mlx_stack_home: Path) -> None: """--add works via CLI.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1086,7 +1051,7 @@ def test_add_option_works(self, mlx_stack_home: Path) -> None: def test_remove_option_works(self, mlx_stack_home: Path) -> None: """--remove works via CLI.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1101,7 +1066,7 @@ def test_remove_option_works(self, mlx_stack_home: Path) -> None: def test_different_intents_produce_different_stacks(self, mlx_stack_home: Path) -> None: """VAL-CROSS-005: Different intents produce different selections.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1120,7 +1085,7 @@ def test_different_intents_produce_different_stacks(self, mlx_stack_home: Path) def test_vllm_flags_in_generated_stack(self, mlx_stack_home: Path) -> None: """VAL-INIT-004: vllm_flags have correct feature flags.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1200,7 +1165,7 @@ def test_raises_when_no_ports_available(self) -> None: def test_port_detection_in_full_init(self, mlx_stack_home: Path) -> None: """Port-in-use detection is exercised during full init flow.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1231,7 +1196,7 @@ class TestTotalEstimatedMemory: def test_total_memory_in_result(self, mlx_stack_home: Path) -> None: """run_init returns total_memory_gb summing all tier memory.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1248,7 +1213,7 @@ def test_total_memory_in_result(self, mlx_stack_home: Path) -> None: def test_total_memory_displayed_in_summary(self, mlx_stack_home: Path) -> None: """VAL-INIT-012: Terminal summary shows total estimated memory.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1266,7 +1231,7 @@ def test_total_memory_displayed_in_summary(self, mlx_stack_home: Path) -> None: def test_total_memory_sum_is_correct(self, mlx_stack_home: Path) -> None: """Total memory is the sum of individual tier memory_gb values.""" - profile = _make_profile() + profile = make_profile() catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) @@ -1294,18 +1259,17 @@ class TestGatedModelExclusion: def test_init_excludes_gated_models(self, mlx_stack_home: Path) -> None: """Default init excludes gated models from tier assignments.""" - profile = _make_profile() + # Arrange + profile = make_profile() _write_profile(mlx_stack_home, profile) - - # Create catalog where the best model is gated catalog = [ - _make_entry( + make_entry( model_id="gated-best", name="Gated Best", quality_overall=99, gated=True, ), - _make_entry( + make_entry( model_id="open-good", name="Open Good", quality_overall=70, @@ -1313,26 +1277,29 @@ def test_init_excludes_gated_models(self, mlx_stack_home: Path) -> None: ), ] + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), ): result = run_init(intent="balanced", force=True) + # Assert tier_model_ids = {t["model"] for t in result["stack"]["tiers"]} assert "gated-best" not in tier_model_ids assert "open-good" in tier_model_ids def test_add_gated_model_warns(self, mlx_stack_home: Path) -> None: """Adding a gated model via --add produces a warning.""" - profile = _make_profile() + # Arrange + profile = make_profile() _write_profile(mlx_stack_home, profile) - catalog = [ - _make_entry(model_id="open-model", name="Open Model"), - _make_entry(model_id="gated-model", name="Gated Model", gated=True), + make_entry(model_id="open-model", name="Open Model"), + make_entry(model_id="gated-model", name="Gated Model", gated=True), ] + # Act with ( patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), patch("mlx_stack.core.stack_init.load_profile", return_value=profile), @@ -1343,6 +1310,7 @@ def test_add_gated_model_warns(self, mlx_stack_home: Path) -> None: force=True, ) + # Assert warnings = result["warnings"] gated_warnings = [w for w in warnings if "gated" in w.lower()] assert len(gated_warnings) >= 1 diff --git a/tests/unit/test_cli_logs.py b/tests/unit/test_cli_logs.py index c24bcb0..fc81d9f 100644 --- a/tests/unit/test_cli_logs.py +++ b/tests/unit/test_cli_logs.py @@ -6,35 +6,13 @@ from __future__ import annotations -import gzip from pathlib import Path from unittest.mock import patch from click.testing import CliRunner from mlx_stack.cli.logs import logs - -# --------------------------------------------------------------------------- # -# Helpers -# --------------------------------------------------------------------------- # - - -def _create_log(logs_dir: Path, service: str, content: str = "") -> Path: - """Create a log file for a service.""" - logs_dir.mkdir(parents=True, exist_ok=True) - log_path = logs_dir / f"{service}.log" - log_path.write_text(content) - return log_path - - -def _create_archive(logs_dir: Path, service: str, number: int, content: str) -> Path: - """Create a gzip archive for a service.""" - logs_dir.mkdir(parents=True, exist_ok=True) - archive_path = logs_dir / f"{service}.log.{number}.gz" - with gzip.open(str(archive_path), "wb") as f: - f.write(content.encode("utf-8")) - return archive_path - +from tests.factories import create_archive, create_log_file # --------------------------------------------------------------------------- # # Listing (no args) @@ -46,13 +24,16 @@ class TestListLogs: def test_lists_log_files(self, mlx_stack_home: Path) -> None: """Lists available log files with sizes and times.""" + # Arrange logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "some log content here\n") - _create_log(logs_dir, "litellm", "litellm output\n") + create_log_file(logs_dir, "fast", "some log content here\n") + create_log_file(logs_dir, "litellm", "litellm output\n") + # Act runner = CliRunner() result = runner.invoke(logs, []) + # Assert assert result.exit_code == 0 assert "fast.log" in result.output assert "litellm.log" in result.output @@ -87,13 +68,16 @@ class TestViewLogs: def test_shows_tail_default(self, mlx_stack_home: Path) -> None: """Shows last 50 lines by default.""" + # Arrange logs_dir = mlx_stack_home / "logs" lines = [f"line {i}" for i in range(100)] - _create_log(logs_dir, "fast", "\n".join(lines)) + create_log_file(logs_dir, "fast", "\n".join(lines)) + # Act runner = CliRunner() result = runner.invoke(logs, ["fast"]) + # Assert assert result.exit_code == 0 output_lines = result.output.strip().splitlines() assert len(output_lines) == 50 @@ -104,7 +88,7 @@ def test_tail_n_lines(self, mlx_stack_home: Path) -> None: """--tail N shows exactly N lines.""" logs_dir = mlx_stack_home / "logs" lines = [f"line {i}" for i in range(50)] - _create_log(logs_dir, "fast", "\n".join(lines)) + create_log_file(logs_dir, "fast", "\n".join(lines)) runner = CliRunner() result = runner.invoke(logs, ["fast", "--tail", "10"]) @@ -116,7 +100,7 @@ def test_tail_n_lines(self, mlx_stack_home: Path) -> None: def test_empty_log(self, mlx_stack_home: Path) -> None: """Empty log shows informational message.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "") + create_log_file(logs_dir, "fast", "") runner = CliRunner() result = runner.invoke(logs, ["fast"]) @@ -136,8 +120,8 @@ class TestServiceFiltering: def test_service_argument(self, mlx_stack_home: Path) -> None: """Positional service argument shows correct log.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "fast content\n") - _create_log(logs_dir, "standard", "standard content\n") + create_log_file(logs_dir, "fast", "fast content\n") + create_log_file(logs_dir, "standard", "standard content\n") runner = CliRunner() result = runner.invoke(logs, ["fast"]) @@ -148,7 +132,7 @@ def test_service_argument(self, mlx_stack_home: Path) -> None: def test_service_flag(self, mlx_stack_home: Path) -> None: """--service flag shows correct log.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "fast content\n") + create_log_file(logs_dir, "fast", "fast content\n") runner = CliRunner() result = runner.invoke(logs, ["--service", "fast"]) @@ -159,7 +143,7 @@ def test_service_flag(self, mlx_stack_home: Path) -> None: def test_invalid_service_error(self, mlx_stack_home: Path) -> None: """Invalid service name produces clear error.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "content") + create_log_file(logs_dir, "fast", "content") runner = CliRunner() result = runner.invoke(logs, ["nonexistent"]) @@ -228,7 +212,7 @@ def test_rotate_specific_service(self, mlx_stack_home: Path) -> None: def test_rotate_no_rotation_needed(self, mlx_stack_home: Path) -> None: """Reports 'no rotation needed' when files are small.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "small content") + create_log_file(logs_dir, "fast", "small content") with patch( "mlx_stack.core.log_viewer.get_value", @@ -251,7 +235,7 @@ def test_rotate_no_logs(self, mlx_stack_home: Path) -> None: def test_rotate_invalid_service(self, mlx_stack_home: Path) -> None: """--rotate with invalid --service shows error.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "content") + create_log_file(logs_dir, "fast", "content") runner = CliRunner() result = runner.invoke(logs, ["--rotate", "--service", "invalid"]) @@ -271,19 +255,21 @@ class TestAllLogs: def test_shows_archives_and_current(self, mlx_stack_home: Path) -> None: """--all shows archived and current logs chronologically.""" + # Arrange logs_dir = mlx_stack_home / "logs" - _create_archive(logs_dir, "fast", 2, "oldest archived") - _create_archive(logs_dir, "fast", 1, "newest archived") - _create_log(logs_dir, "fast", "current content") + create_archive(logs_dir, "fast", 2, "oldest archived") + create_archive(logs_dir, "fast", 1, "newest archived") + create_log_file(logs_dir, "fast", "current content") + # Act runner = CliRunner() result = runner.invoke(logs, ["fast", "--all"]) + # Assert assert result.exit_code == 0 assert "oldest archived" in result.output assert "newest archived" in result.output assert "current content" in result.output - # Check chronological order oldest_pos = result.output.index("oldest archived") newest_pos = result.output.index("newest archived") current_pos = result.output.index("current content") @@ -304,8 +290,8 @@ def test_all_no_logs(self, mlx_stack_home: Path) -> None: def test_all_archives_only_no_current(self, mlx_stack_home: Path) -> None: """--all shows archives even when current log file is missing.""" logs_dir = mlx_stack_home / "logs" - _create_archive(logs_dir, "fast", 2, "old archived content") - _create_archive(logs_dir, "fast", 1, "recent archived content") + create_archive(logs_dir, "fast", 2, "old archived content") + create_archive(logs_dir, "fast", 1, "recent archived content") # No fast.log current file runner = CliRunner() @@ -377,13 +363,16 @@ class TestEdgeCases: def test_service_argument_takes_precedence_over_flag(self, mlx_stack_home: Path) -> None: """Positional argument takes precedence over --service flag.""" + # Arrange logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "fast content\n") - _create_log(logs_dir, "standard", "standard content\n") + create_log_file(logs_dir, "fast", "fast content\n") + create_log_file(logs_dir, "standard", "standard content\n") + # Act runner = CliRunner() result = runner.invoke(logs, ["fast", "--service", "standard"]) + # Assert assert result.exit_code == 0 assert "fast content" in result.output diff --git a/tests/unit/test_cli_models.py b/tests/unit/test_cli_models.py index 42894c8..1da1c50 100644 --- a/tests/unit/test_cli_models.py +++ b/tests/unit/test_cli_models.py @@ -21,12 +21,9 @@ from mlx_stack.cli.main import cli from mlx_stack.core.catalog import ( BenchmarkResult, - Capabilities, CatalogEntry, - QualityScores, QuantSource, ) -from mlx_stack.core.hardware import HardwareProfile from mlx_stack.core.models import ( format_size, get_models_directory, @@ -34,111 +31,66 @@ list_catalog_models, scan_local_models, ) +from tests.factories import make_entry, make_profile # --------------------------------------------------------------------------- # # Fixtures — reusable test data # --------------------------------------------------------------------------- # -def _make_profile( - chip: str = "Apple M4 Max", - gpu_cores: int = 40, - memory_gb: int = 128, - bandwidth_gbps: float = 546.0, - is_estimate: bool = False, -) -> HardwareProfile: - """Create a HardwareProfile for testing.""" - return HardwareProfile( - chip=chip, - gpu_cores=gpu_cores, - memory_gb=memory_gb, - bandwidth_gbps=bandwidth_gbps, - is_estimate=is_estimate, - ) - - -def _make_entry( - model_id: str = "test-model", - name: str = "Test Model", - family: str = "Test", - params_b: float = 8.0, - architecture: str = "transformer", - quality_overall: int = 70, - tool_calling: bool = True, - benchmarks: dict[str, BenchmarkResult] | None = None, - tags: list[str] | None = None, - memory_gb: float = 5.5, - disk_size_gb: float = 4.5, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - if benchmarks is None: - benchmarks = { - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=memory_gb), - "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=memory_gb), - } - return CatalogEntry( - id=model_id, - name=name, - family=family, - params_b=params_b, - architecture=architecture, - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource( - hf_repo=f"mlx-community/{model_id}-4bit", - disk_size_gb=disk_size_gb, - ), - "int8": QuantSource( - hf_repo=f"mlx-community/{model_id}-8bit", - disk_size_gb=disk_size_gb * 2, - ), - }, - capabilities=Capabilities( - tool_calling=tool_calling, - tool_call_parser="hermes" if tool_calling else None, - thinking=False, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores( - overall=quality_overall, - coding=65, - reasoning=60, - instruction_following=72, - ), - benchmarks=benchmarks, - tags=tags or [], - ) - - def _make_test_catalog() -> list[CatalogEntry]: - """Create a test catalog with several models.""" + """Create a test catalog with several models. + + Uses multi-benchmark / multi-source entries needed by the models tests + (m4-max-128 + m4-pro-48 benchmarks, int4 + int8 sources). + """ return [ - _make_entry( + make_entry( model_id="qwen3.5-8b", name="Qwen 3.5 8B", family="Qwen 3.5", params_b=8.0, - memory_gb=5.5, disk_size_gb=4.5, tags=["balanced", "agent-ready"], + benchmarks={ + "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), + "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), + }, + sources={ + "int4": QuantSource(hf_repo="mlx-community/qwen3.5-8b-4bit", disk_size_gb=4.5), + "int8": QuantSource(hf_repo="mlx-community/qwen3.5-8b-8bit", disk_size_gb=9.0), + }, ), - _make_entry( + make_entry( model_id="nemotron-8b", name="Nemotron 8B", family="Nemotron", params_b=8.0, - memory_gb=5.0, disk_size_gb=4.2, tags=["agent-ready"], + benchmarks={ + "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.0), + "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.0), + }, + sources={ + "int4": QuantSource(hf_repo="mlx-community/nemotron-8b-4bit", disk_size_gb=4.2), + "int8": QuantSource(hf_repo="mlx-community/nemotron-8b-8bit", disk_size_gb=8.4), + }, ), - _make_entry( + make_entry( model_id="gemma3-12b", name="Gemma 3 12B", family="Gemma 3", params_b=12.0, - memory_gb=7.5, disk_size_gb=6.5, + benchmarks={ + "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=7.5), + "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=7.5), + }, + sources={ + "int4": QuantSource(hf_repo="mlx-community/gemma3-12b-4bit", disk_size_gb=6.5), + "int8": QuantSource(hf_repo="mlx-community/gemma3-12b-8bit", disk_size_gb=13.0), + }, ), ] @@ -278,15 +230,12 @@ def test_matches_catalog_names(self, mlx_stack_home: Path) -> None: def test_active_stack_indicator(self, mlx_stack_home: Path) -> None: """Marks models as active when referenced by the active stack.""" + # Arrange models_dir = mlx_stack_home / "models" stacks_dir = mlx_stack_home / "stacks" catalog = _make_test_catalog() - - # Create model dirs _create_model_dir(models_dir, "qwen3.5-8b-4bit", size_bytes=1000) _create_model_dir(models_dir, "random-model", size_bytes=1000) - - # Create stack referencing qwen3.5-8b stack_tiers = [ { "name": "fast", @@ -298,9 +247,11 @@ def test_active_stack_indicator(self, mlx_stack_home: Path) -> None: ] _create_stack_yaml(stacks_dir, stack_tiers) + # Act stack = yaml.safe_load((stacks_dir / "default.yaml").read_text()) result = scan_local_models(models_dir=models_dir, catalog=catalog, stack=stack) + # Assert active = {m.name: m.is_active for m in result} assert active["qwen3.5-8b-4bit"] is True assert active["random-model"] is False @@ -404,12 +355,14 @@ def test_lists_all_catalog_entries(self) -> None: def test_includes_hardware_benchmarks(self) -> None: """Includes gen_tps and memory_gb when profile matches catalog benchmarks.""" + # Arrange catalog = _make_test_catalog() - profile = _make_profile() + profile = make_profile() + # Act result = list_catalog_models(catalog=catalog, profile=profile, local_models=[]) - # All our test entries have m4-max-128 benchmarks + # Assert — all test entries have m4-max-128 benchmarks for cm in result: assert cm.gen_tps is not None assert cm.memory_gb is not None @@ -419,7 +372,7 @@ def test_estimated_for_unknown_hardware(self) -> None: """Shows estimated values when hardware doesn't match catalog benchmarks.""" catalog = _make_test_catalog() # Use a profile that doesn't match any catalog benchmark key - profile = _make_profile( + profile = make_profile( chip="Apple M3", memory_gb=24, bandwidth_gbps=100.0, @@ -443,13 +396,16 @@ def test_no_profile_no_benchmarks(self, mlx_stack_home: Path) -> None: def test_marks_local_models(self, mlx_stack_home: Path) -> None: """Indicates which catalog models are available locally.""" + # Arrange catalog = _make_test_catalog() models_dir = mlx_stack_home / "models" _create_model_dir(models_dir, "qwen3.5-8b-4bit", size_bytes=1000) - local_models = scan_local_models(models_dir=models_dir, catalog=catalog) + + # Act result = list_catalog_models(catalog=catalog, profile=None, local_models=local_models) + # Assert local_map = {cm.id: cm.is_local for cm in result} assert local_map["qwen3.5-8b"] is True assert local_map["nemotron-8b"] is False @@ -546,21 +502,25 @@ def test_lists_local_models_with_size_and_quant(self, mlx_stack_home: Path) -> N Verifies through the core scan function that models are discovered with correct size and quant, then checks CLI renders them. """ + # Arrange models_dir = mlx_stack_home / "models" _create_model_dir(models_dir, "Qwen3.5-8B-4bit", size_bytes=4_500_000_000) - # Verify core scanning works correctly + # Act — core scan local_models = scan_local_models(models_dir=models_dir, catalog=[]) + + # Assert — core data assert len(local_models) == 1 assert local_models[0].name == "Qwen3.5-8B-4bit" assert local_models[0].disk_size_bytes == 4_500_000_000 assert local_models[0].quant == "int4" - # Verify CLI renders something (table with Local Models title) + # Act — CLI rendering runner = CliRunner() with patch("mlx_stack.cli.models.load_catalog", return_value=[]): result = runner.invoke(cli, ["models"]) + # Assert — CLI output assert result.exit_code == 0 assert "Local Models" in result.output assert "Model" in result.output # Table header @@ -735,7 +695,7 @@ def test_shows_all_catalog_entries(self, mlx_stack_home: Path) -> None: def test_shows_hardware_specific_benchmarks(self, mlx_stack_home: Path) -> None: """VAL-MODELS-004: Shows benchmark data for current hardware.""" catalog = _make_test_catalog() - profile = _make_profile() + profile = make_profile() runner = CliRunner() with ( @@ -757,7 +717,7 @@ def test_estimated_values_with_tilde(self, mlx_stack_home: Path) -> None: """VAL-MODELS-004: Unknown hardware shows estimated values labeled as estimates.""" catalog = _make_test_catalog() # Profile that doesn't match any benchmark key - profile = _make_profile( + profile = make_profile( chip="Apple M3", memory_gb=24, bandwidth_gbps=100.0, @@ -893,19 +853,19 @@ class TestCatalogFilters: def test_filter_by_family(self, mlx_stack_home: Path) -> None: """--family filters catalog to matching family only.""" catalog = [ - _make_entry( + make_entry( model_id="qwen-a", name="Qwen A", family="Qwen 3.5", tags=["balanced"], ), - _make_entry( + make_entry( model_id="qwen-b", name="Qwen B", family="Qwen 3.5", tags=["balanced"], ), - _make_entry( + make_entry( model_id="gemma-a", name="Gemma A", family="Gemma 3", @@ -928,7 +888,7 @@ def test_filter_by_family(self, mlx_stack_home: Path) -> None: def test_filter_by_family_case_insensitive(self, mlx_stack_home: Path) -> None: """--family is case-insensitive.""" catalog = [ - _make_entry(model_id="q1", name="Qwen Model", family="Qwen 3.5"), + make_entry(model_id="q1", name="Qwen Model", family="Qwen 3.5"), ] runner = CliRunner() @@ -944,12 +904,12 @@ def test_filter_by_family_case_insensitive(self, mlx_stack_home: Path) -> None: def test_filter_by_tag(self, mlx_stack_home: Path) -> None: """--tag filters catalog to models with the specified tag.""" catalog = [ - _make_entry( + make_entry( model_id="agent-model", name="Agent Model", tags=["agent-ready", "balanced"], ), - _make_entry( + make_entry( model_id="basic-model", name="Basic Model", tags=["balanced"], @@ -970,12 +930,12 @@ def test_filter_by_tag(self, mlx_stack_home: Path) -> None: def test_filter_by_tool_calling(self, mlx_stack_home: Path) -> None: """--tool-calling filters to tool-calling-capable models only.""" catalog = [ - _make_entry( + make_entry( model_id="with-tools", name="With Tools", tool_calling=True, ), - _make_entry( + make_entry( model_id="no-tools", name="No Tools", tool_calling=False, @@ -996,21 +956,21 @@ def test_filter_by_tool_calling(self, mlx_stack_home: Path) -> None: def test_combined_filters(self, mlx_stack_home: Path) -> None: """Multiple filters are applied together (AND logic).""" catalog = [ - _make_entry( + make_entry( model_id="match", name="Match Both", family="Qwen 3.5", tool_calling=True, tags=["agent-ready"], ), - _make_entry( + make_entry( model_id="family-only", name="Family Only", family="Qwen 3.5", tool_calling=False, tags=[], ), - _make_entry( + make_entry( model_id="tools-only", name="Tools Only", family="Gemma 3", @@ -1037,7 +997,7 @@ def test_combined_filters(self, mlx_stack_home: Path) -> None: def test_no_matches_message(self, mlx_stack_home: Path) -> None: """Shows informative message when no models match filters.""" catalog = [ - _make_entry(model_id="x", name="Some Model", family="Gemma 3"), + make_entry(model_id="x", name="Some Model", family="Gemma 3"), ] runner = CliRunner() @@ -1053,7 +1013,7 @@ def test_no_matches_message(self, mlx_stack_home: Path) -> None: def test_filter_flags_imply_catalog(self, mlx_stack_home: Path) -> None: """Using --family without --catalog still shows catalog (auto-enables).""" catalog = [ - _make_entry(model_id="q1", name="Qwen Model", family="Qwen 3.5"), + make_entry(model_id="q1", name="Qwen Model", family="Qwen 3.5"), ] runner = CliRunner() diff --git a/tests/unit/test_cli_pull.py b/tests/unit/test_cli_pull.py index cb2c1cf..81d2581 100644 --- a/tests/unit/test_cli_pull.py +++ b/tests/unit/test_cli_pull.py @@ -25,11 +25,7 @@ from mlx_stack.cli.main import cli from mlx_stack.core.catalog import ( - BenchmarkResult, - Capabilities, - CatalogEntry, CatalogError, - QualityScores, QuantSource, ) from mlx_stack.core.pull import ( @@ -51,61 +47,29 @@ save_inventory, validate_quant, ) +from tests.factories import make_entry # --------------------------------------------------------------------------- # # Fixtures — reusable test data # --------------------------------------------------------------------------- # -def _make_entry( - model_id: str = "qwen3.5-8b", - name: str = "Qwen 3.5 8B", - family: str = "Qwen 3.5", - params_b: float = 8.0, - architecture: str = "transformer", - tool_calling: bool = True, - disk_size_gb: float = 4.5, - disk_size_gb_int8: float = 8.5, - disk_size_gb_bf16: float = 16.0, - gated: bool = False, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - return CatalogEntry( - id=model_id, - name=name, - family=family, - params_b=params_b, - architecture=architecture, - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource( - hf_repo=f"mlx-community/{model_id}-4bit", - disk_size_gb=disk_size_gb, - ), - "int8": QuantSource( - hf_repo=f"mlx-community/{model_id}-8bit", - disk_size_gb=disk_size_gb_int8, - ), - "bf16": QuantSource( - hf_repo=f"Qwen/{name.replace(' ', '-')}", - disk_size_gb=disk_size_gb_bf16, - convert_from=True, - ), - }, - capabilities=Capabilities( - tool_calling=tool_calling, - tool_call_parser="hermes" if tool_calling else None, - thinking=False, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores(overall=68, coding=65, reasoning=62, instruction_following=72), - benchmarks={ - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), - }, - tags=["balanced", "agent-ready"], - gated=gated, - ) +# Pull-specific sources include int4, int8, and a bf16 conversion source. +_PULL_SOURCES = { + "int4": QuantSource( + hf_repo="mlx-community/qwen3.5-8b-4bit", + disk_size_gb=4.5, + ), + "int8": QuantSource( + hf_repo="mlx-community/qwen3.5-8b-8bit", + disk_size_gb=8.5, + ), + "bf16": QuantSource( + hf_repo="Qwen/Qwen-3.5-8B", + disk_size_gb=16.0, + convert_from=True, + ), +} # =========================================================================== # @@ -143,23 +107,56 @@ class TestResolveSource: def test_mlx_community_preferred_int4(self) -> None: """mlx-community source is used for int4 (convert_from=False).""" - entry = _make_entry() + # Arrange + entry = make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + ) + + # Act source, source_type = resolve_source(entry, "int4") + + # Assert assert source_type == "mlx-community" assert "mlx-community" in source.hf_repo assert source.convert_from is False def test_mlx_community_preferred_int8(self) -> None: """mlx-community source is used for int8.""" - entry = _make_entry() + # Arrange + entry = make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + ) + + # Act source, source_type = resolve_source(entry, "int8") + + # Assert assert source_type == "mlx-community" assert "mlx-community" in source.hf_repo def test_conversion_fallback_bf16(self) -> None: """bf16 with convert_from=True uses conversion source type.""" - entry = _make_entry() + # Arrange + entry = make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + ) + + # Act source, source_type = resolve_source(entry, "bf16") + + # Assert assert source_type == "converted" assert source.convert_from is True assert "Qwen/" in source.hf_repo @@ -168,28 +165,16 @@ def test_unavailable_quant_raises(self) -> None: """Requesting unavailable quant raises PullError.""" import pytest - # Create entry with only int4 source - entry = CatalogEntry( - id="test-model", - name="Test Model", - family="Test", - params_b=1.0, - architecture="transformer", - min_mlx_lm_version="0.22.0", + # Arrange -- entry with only int4 source (shared factory default) + entry = make_entry( + model_id="test-model", + tool_calling=False, sources={ "int4": QuantSource(hf_repo="mlx-community/test-4bit", disk_size_gb=1.0), }, - capabilities=Capabilities( - tool_calling=False, - tool_call_parser=None, - thinking=False, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores(overall=50, coding=50, reasoning=50, instruction_following=50), - benchmarks={}, - tags=[], ) + + # Act / Assert with pytest.raises(PullError, match="Quantization 'int8' is not available"): resolve_source(entry, "int8") @@ -410,7 +395,13 @@ def test_successful_download( mlx_stack_home: Path, ) -> None: """VAL-PULL-001: Successful pull downloads to correct location.""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] result = pull_model("qwen3.5-8b", quant="int4", catalog=catalog) assert result.model_id == "qwen3.5-8b" @@ -428,7 +419,13 @@ def test_updates_inventory_after_download( mlx_stack_home: Path, ) -> None: """VAL-PULL-007: Model appears in inventory after pull.""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] pull_model("qwen3.5-8b", quant="int4", catalog=catalog) inv = load_inventory() @@ -442,7 +439,13 @@ def test_invalid_model_id_raises(self, mlx_stack_home: Path) -> None: """VAL-PULL-009: Invalid model ID raises InvalidModelError.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(InvalidModelError, match="not found in catalog"): pull_model("nonexistent-model", catalog=catalog) @@ -450,7 +453,13 @@ def test_invalid_model_suggests_catalog(self, mlx_stack_home: Path) -> None: """VAL-PULL-009: Error message suggests models --catalog.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(InvalidModelError, match="models --catalog"): pull_model("bad-model", catalog=catalog) @@ -458,7 +467,13 @@ def test_invalid_quant_rejected(self, mlx_stack_home: Path) -> None: """VAL-PULL-003: Invalid quant is rejected.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(PullError, match="Invalid quantization"): pull_model("qwen3.5-8b", quant="int6", catalog=catalog) @@ -471,7 +486,13 @@ def test_insufficient_disk_space( """VAL-PULL-004: Insufficient space raises DiskSpaceError.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(DiskSpaceError, match="Insufficient disk space"): pull_model("qwen3.5-8b", quant="int4", catalog=catalog) @@ -484,7 +505,13 @@ def test_disk_space_error_shows_requirements( """VAL-PULL-004: Disk space error shows required vs available.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(DiskSpaceError, match=r"(?s)Required: 4\.5 GB.*Available: 2\.0 GB"): pull_model("qwen3.5-8b", quant="int4", catalog=catalog) @@ -497,7 +524,13 @@ def test_already_downloaded_detected( mlx_stack_home: Path, ) -> None: """VAL-PULL-006: Already-downloaded model skips re-download.""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] # Create the model directory with a file to simulate existing download models_dir = mlx_stack_home / "models" @@ -518,7 +551,13 @@ def test_force_redownloads( mlx_stack_home: Path, ) -> None: """VAL-PULL-006: --force re-downloads existing model.""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] # Create the model directory with a file models_dir = mlx_stack_home / "models" @@ -539,7 +578,13 @@ def test_uses_config_default_quant( mlx_stack_home: Path, ) -> None: """VAL-PULL-003: Without --quant, uses config default (int4).""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] result = pull_model("qwen3.5-8b", quant=None, catalog=catalog) assert result.quant == "int4" @@ -554,7 +599,13 @@ def test_uses_config_custom_quant( mlx_stack_home: Path, ) -> None: """VAL-PULL-003: Uses config default-quant when set to int8.""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] result = pull_model("qwen3.5-8b", quant=None, catalog=catalog) assert result.quant == "int8" @@ -567,7 +618,13 @@ def test_bf16_uses_conversion( mlx_stack_home: Path, ) -> None: """VAL-PULL-002: bf16 with convert_from uses mlx_lm conversion.""" - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] result = pull_model("qwen3.5-8b", quant="bf16", catalog=catalog) assert result.source_type == "converted" @@ -583,7 +640,13 @@ def test_download_error_propagated( """VAL-PULL-008: Download error is propagated.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(DownloadError, match="Network error"): pull_model("qwen3.5-8b", quant="int4", catalog=catalog) @@ -598,7 +661,13 @@ def test_conversion_error_propagated( """VAL-PULL-010: Conversion error is propagated.""" import pytest - catalog = [_make_entry()] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with pytest.raises(ConversionError, match="Convert failed"): pull_model("qwen3.5-8b", quant="bf16", catalog=catalog) @@ -941,7 +1010,13 @@ def test_pull_successful( mlx_stack_home: Path, ) -> None: """VAL-PULL-001: Successful pull via CLI.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b", "--quant", "int4"]) @@ -956,7 +1031,13 @@ def test_pull_invalid_model( mlx_stack_home: Path, ) -> None: """VAL-PULL-009: Invalid model exits with error.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "nonexistent-model"]) @@ -971,7 +1052,13 @@ def test_pull_invalid_quant( mlx_stack_home: Path, ) -> None: """VAL-PULL-003: Invalid quant exits with error.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b", "--quant", "int6"]) @@ -987,7 +1074,13 @@ def test_pull_insufficient_space( mlx_stack_home: Path, ) -> None: """VAL-PULL-004: Insufficient space exits with clear error.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b"]) @@ -1005,7 +1098,13 @@ def test_pull_already_exists( mlx_stack_home: Path, ) -> None: """VAL-PULL-006: Already-downloaded model reports exists.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] # Create the model directory models_dir = mlx_stack_home / "models" @@ -1030,7 +1129,13 @@ def test_pull_force_redownloads( mlx_stack_home: Path, ) -> None: """VAL-PULL-006: --force re-downloads even when model exists.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] # Create existing model models_dir = mlx_stack_home / "models" @@ -1059,7 +1164,13 @@ def test_pull_network_error( mlx_stack_home: Path, ) -> None: """VAL-PULL-008: Network error shows clear message.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b"]) @@ -1080,7 +1191,13 @@ def test_pull_conversion_error( mlx_stack_home: Path, ) -> None: """VAL-PULL-010: Conversion error shows clear message.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b", "--quant", "bf16"]) @@ -1098,7 +1215,13 @@ def test_pull_updates_inventory( mlx_stack_home: Path, ) -> None: """VAL-PULL-007: After pull, model is in inventory.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b"]) @@ -1120,7 +1243,13 @@ def test_pull_shows_progress_info( mlx_stack_home: Path, ) -> None: """VAL-PULL-005: Progress-related info shown during download.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b"]) @@ -1145,7 +1274,13 @@ def test_pull_with_bench_flag( mlx_stack_home: Path, ) -> None: """VAL-CROSS-014: --bench flag runs benchmark and shows output.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] with patch("mlx_stack.core.benchmark.run_benchmark") as mock_bench: mock_bench.return_value = MagicMock( @@ -1175,7 +1310,13 @@ def test_pull_bench_calls_run_benchmark_with_save( """pull --bench calls run_benchmark(save=True) to persist results.""" from mlx_stack.core.benchmark import BenchmarkResult_ - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] mock_result = BenchmarkResult_( model_id="qwen3.5-8b", quant="int4", @@ -1203,7 +1344,13 @@ def test_pull_no_traceback_on_error( mlx_stack_home: Path, ) -> None: """No Python traceback shown for any error scenario.""" - mock_catalog.return_value = [_make_entry()] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + )] mock_download.side_effect = DownloadError("Test error") runner = CliRunner() @@ -1244,7 +1391,13 @@ def test_pulled_model_appears_in_models( mlx_stack_home: Path, ) -> None: """VAL-PULL-007: After pull, model appears in mlx-stack models output.""" - entry = _make_entry() + entry = make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + ) mock_pull_catalog.return_value = [entry] mock_models_catalog.return_value = [entry] @@ -1284,7 +1437,14 @@ def test_gated_model_without_token_raises( """Gated model without HF token raises GatedModelError.""" import pytest - catalog = [_make_entry(gated=True)] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + gated=True, + )] with pytest.raises(GatedModelError, match="requires HuggingFace authentication"): pull_model("qwen3.5-8b", quant="int4", catalog=catalog) @@ -1299,7 +1459,14 @@ def test_gated_model_with_token_proceeds( mlx_stack_home: Path, ) -> None: """Gated model with valid token proceeds to download.""" - catalog = [_make_entry(gated=True)] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + gated=True, + )] result = pull_model("qwen3.5-8b", quant="int4", catalog=catalog) assert result.already_existed is False mock_download.assert_called_once() @@ -1311,7 +1478,14 @@ def test_non_gated_model_skips_token_check( mlx_stack_home: Path, ) -> None: """Non-gated model does not check for token.""" - catalog = [_make_entry(gated=False)] + catalog = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + gated=False, + )] with patch("mlx_stack.core.pull.download_model"): with patch("mlx_stack.core.pull.check_disk_space", return_value=(True, 100.0)): pull_model("qwen3.5-8b", quant="int4", catalog=catalog) @@ -1357,7 +1531,14 @@ def test_cli_gated_error_shows_auth_required( mlx_stack_home: Path, ) -> None: """CLI shows 'Authentication required' for gated model errors.""" - mock_catalog.return_value = [_make_entry(gated=True)] + mock_catalog.return_value = [make_entry( + model_id="qwen3.5-8b", + name="Qwen 3.5 8B", + family="Qwen 3.5", + sources=_PULL_SOURCES, + tags=["balanced", "agent-ready"], + gated=True, + )] runner = CliRunner() result = runner.invoke(cli, ["pull", "qwen3.5-8b"]) diff --git a/tests/unit/test_cli_recommend.py b/tests/unit/test_cli_recommend.py index 2f4687a..1b8cb13 100644 --- a/tests/unit/test_cli_recommend.py +++ b/tests/unit/test_cli_recommend.py @@ -28,103 +28,24 @@ from mlx_stack.cli.main import cli from mlx_stack.cli.recommend import parse_budget -from mlx_stack.core.catalog import ( - BenchmarkResult, - Capabilities, - CatalogEntry, - QualityScores, - QuantSource, -) -from mlx_stack.core.hardware import HardwareProfile +from mlx_stack.core.catalog import BenchmarkResult, CatalogEntry +from tests.factories import make_entry, make_profile # --------------------------------------------------------------------------- # -# Fixtures — reusable test data +# Recommend-specific catalog — 5 diverse models needed by most tests here # --------------------------------------------------------------------------- # -def _make_profile( - chip: str = "Apple M4 Max", - gpu_cores: int = 40, - memory_gb: int = 128, - bandwidth_gbps: float = 546.0, - is_estimate: bool = False, -) -> HardwareProfile: - """Create a HardwareProfile for testing.""" - return HardwareProfile( - chip=chip, - gpu_cores=gpu_cores, - memory_gb=memory_gb, - bandwidth_gbps=bandwidth_gbps, - is_estimate=is_estimate, - ) - - -def _make_small_profile() -> HardwareProfile: - """Create a small-memory profile (32 GB).""" - return HardwareProfile( - chip="Apple M4 Pro", - gpu_cores=20, - memory_gb=32, - bandwidth_gbps=273.0, - is_estimate=False, - ) - - -def _make_entry( - model_id: str = "test-model", - name: str = "Test Model", - family: str = "Test", - params_b: float = 8.0, - architecture: str = "transformer", - quality_overall: int = 70, - quality_coding: int = 65, - quality_reasoning: int = 60, - quality_instruction: int = 72, - tool_calling: bool = True, - tool_call_parser: str | None = "hermes", - thinking: bool = False, - benchmarks: dict[str, BenchmarkResult] | None = None, - tags: list[str] | None = None, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - if benchmarks is None: - benchmarks = { - "m4-pro-32": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), - } - return CatalogEntry( - id=model_id, - name=name, - family=family, - params_b=params_b, - architecture=architecture, - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource(hf_repo=f"test/{model_id}-4bit", disk_size_gb=4.5), - }, - capabilities=Capabilities( - tool_calling=tool_calling, - tool_call_parser=tool_call_parser if tool_calling else None, - thinking=thinking, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores( - overall=quality_overall, - coding=quality_coding, - reasoning=quality_reasoning, - instruction_following=quality_instruction, - ), - benchmarks=benchmarks, - tags=tags or [], - ) +def _make_recommend_catalog() -> list[CatalogEntry]: + """Build a diverse test catalog for recommendation tests. - -def _make_test_catalog() -> list[CatalogEntry]: - """Build a diverse test catalog for recommendation tests.""" + This catalog has specific models with known quality/speed/architecture + characteristics that the tier-assignment tests depend on. It differs from + the shared ``make_test_catalog`` (2 generic models) so it is kept local. + """ return [ # High quality model (standard tier candidate) - _make_entry( + make_entry( model_id="high-quality-32b", name="High Quality 32B", family="Quality", @@ -141,7 +62,7 @@ def _make_test_catalog() -> list[CatalogEntry]: tags=["quality"], ), # Fast small model (fast tier candidate) - _make_entry( + make_entry( model_id="fast-0.8b", name="Fast 0.8B", family="Fast", @@ -158,7 +79,7 @@ def _make_test_catalog() -> list[CatalogEntry]: tags=["fast"], ), # Medium model - _make_entry( + make_entry( model_id="medium-8b", name="Medium 8B", family="Medium", @@ -175,7 +96,7 @@ def _make_test_catalog() -> list[CatalogEntry]: tags=["balanced"], ), # Longctx model (mamba2-hybrid architecture) - _make_entry( + make_entry( model_id="longctx-32b", name="LongCtx 32B", family="LongCtx", @@ -193,7 +114,7 @@ def _make_test_catalog() -> list[CatalogEntry]: tags=["long-context"], ), # Large model that only fits on big systems - _make_entry( + make_entry( model_id="huge-72b", name="Huge 72B", family="Huge", @@ -276,11 +197,15 @@ def test_models_within_budget( mlx_stack_home: Path, ) -> None: """All recommended models have memory ≤ computed budget.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 # Huge 72B requires 42 GB, budget is 128*0.4=51.2 GB, so it should appear # All smaller models should also appear @@ -295,11 +220,15 @@ def test_small_budget_excludes_large_models( mlx_stack_home: Path, ) -> None: """Models exceeding explicit budget are excluded.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--budget", "10gb"]) + + # Assert assert result.exit_code == 0 # Only models with memory ≤ 10 GB should appear assert "Fast 0.8B" in result.output @@ -326,11 +255,15 @@ def test_default_budget_64gb( mlx_stack_home: Path, ) -> None: """On 64 GB system, budget = 25.6 GB. 32B models (20 GB) fit.""" - mock_load_profile.return_value = _make_profile(memory_gb=64) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=64) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 # 25.6 GB budget: 20 GB models fit, 42 GB model doesn't assert "25.6 GB" in result.output @@ -345,11 +278,15 @@ def test_default_budget_128gb( mlx_stack_home: Path, ) -> None: """On 128 GB system, budget = 51.2 GB. All models fit.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) + + # Assert assert result.exit_code == 0 # 51.2 GB budget: all models fit (42 GB largest) assert "51.2 GB" in result.output @@ -373,12 +310,16 @@ def test_budget_override( mlx_stack_home: Path, ) -> None: """--budget 30gb overrides default on 64 GB machine.""" - mock_load_profile.return_value = _make_profile(memory_gb=64) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=64) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() # Default budget would be 25.6 GB; override to 30 GB result = runner.invoke(cli, ["recommend", "--budget", "30gb", "--show-all"]) + + # Assert assert result.exit_code == 0 assert "30.0 GB" in result.output # 20 GB models fit (they didn't need override, but 30 GB includes them) @@ -393,11 +334,15 @@ def test_budget_override_excludes_when_tight( mlx_stack_home: Path, ) -> None: """--budget 4gb excludes most models.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--budget", "4gb", "--show-all"]) + + # Assert assert result.exit_code == 0 # Only 0.6 GB model fits assert "Fast 0.8B" in result.output @@ -426,12 +371,16 @@ def test_balanced_vs_agent_fleet_different_tiers( instead of the intent-weighted composite score, so both intents produced identical tier assignments. """ - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result_balanced = runner.invoke(cli, ["recommend", "--intent", "balanced"]) result_agent = runner.invoke(cli, ["recommend", "--intent", "agent-fleet"]) + + # Assert assert result_balanced.exit_code == 0 assert result_agent.exit_code == 0 @@ -472,11 +421,15 @@ def test_standard_is_highest_composite_score( mlx_stack_home: Path, ) -> None: """Standard tier gets the model with the highest intent-weighted composite score.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 output_lines = result.output.split("\n") standard_line = [line for line in output_lines if "standard" in line.lower()] @@ -494,11 +447,15 @@ def test_fast_is_highest_tps( mlx_stack_home: Path, ) -> None: """Fast tier gets the highest gen_tps model (different from standard).""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 output_lines = result.output.split("\n") fast_line = [ @@ -519,11 +476,15 @@ def test_longctx_is_mamba2( mlx_stack_home: Path, ) -> None: """Longctx tier gets a mamba2-hybrid model if budget allows.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 output_lines = result.output.split("\n") longctx_line = [line for line in output_lines if "longctx" in line.lower()] @@ -548,11 +509,15 @@ def test_large_memory_three_tiers( mlx_stack_home: Path, ) -> None: """128 GB system gets up to 3 tiers.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "standard" in result.output.lower() assert "fast" in result.output.lower() @@ -567,10 +532,11 @@ def test_small_budget_fewer_tiers( mlx_stack_home: Path, ) -> None: """With very small budget, only 1-2 tiers available.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] # Use only small models in catalog catalog = [ - _make_entry( + make_entry( model_id="tiny-1b", name="Tiny 1B", quality_overall=30, @@ -581,9 +547,12 @@ def test_small_budget_fewer_tiers( ] mock_load_catalog.return_value = catalog # type: ignore[attr-defined] + # Act runner = CliRunner() # Budget 5gb - only one model fits result = runner.invoke(cli, ["recommend", "--budget", "5gb"]) + + # Assert assert result.exit_code == 0 # With only 1 model, can only have 1 tier assert "standard" in result.output.lower() @@ -607,12 +576,16 @@ def test_uses_existing_profile( mlx_stack_home: Path, ) -> None: """When profile.json exists, it is used.""" - profile = _make_profile(memory_gb=64) + # Arrange + profile = make_profile(memory_gb=64) mock_load_profile.return_value = profile # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "64 GB" in result.output @@ -627,12 +600,16 @@ def test_auto_detects_when_no_profile( mlx_stack_home: Path, ) -> None: """When no profile.json, auto-detect in memory (no file write).""" + # Arrange mock_load_profile.return_value = None # type: ignore[attr-defined] - mock_detect.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_detect.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "detecting hardware" in result.output.lower() mock_detect.assert_called_once() # type: ignore[attr-defined] @@ -653,12 +630,16 @@ def test_hardware_detection_failure( """If auto-detect fails, exits with error.""" from mlx_stack.core.hardware import HardwareError + # Arrange mock_load_profile.return_value = None # type: ignore[attr-defined] mock_detect.side_effect = HardwareError("Not Apple Silicon") # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code != 0 assert "Not Apple Silicon" in result.output @@ -680,20 +661,22 @@ def test_estimated_label_shown( mlx_stack_home: Path, ) -> None: """Unknown hardware profile shows estimated labels and bench suggestion.""" - # Use a profile_id that doesn't match any catalog benchmarks - profile = _make_profile( + # Arrange — use a profile_id that doesn't match any catalog benchmarks + profile = make_profile( chip="Apple M6 Ultra", memory_gb=256, bandwidth_gbps=800.0, is_estimate=True, ) mock_load_profile.return_value = profile # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert — should show "(est.)" or "estimated" in output assert result.exit_code == 0 - # Should show "(est.)" or "estimated" in output assert "est." in result.output.lower() assert "bench --save" in result.output @@ -715,11 +698,15 @@ def test_default_shows_tier_table( mlx_stack_home: Path, ) -> None: """Default output shows Recommended Stack with tier names.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "Recommended Stack" in result.output assert "standard" in result.output.lower() @@ -733,11 +720,15 @@ def test_show_all_lists_all_models( mlx_stack_home: Path, ) -> None: """--show-all shows all budget-fitting models sorted by score.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) + + # Assert assert result.exit_code == 0 assert "All Budget-Fitting Models" in result.output # All 5 models should appear (51.2 GB budget, all fit) @@ -756,8 +747,8 @@ def test_show_all_contains_score_column( mlx_stack_home: Path, ) -> None: """--show-all output includes Score column.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) @@ -773,8 +764,8 @@ def test_default_shows_memory_and_tps_columns( mlx_stack_home: Path, ) -> None: """Default tier output includes Gen TPS and Memory columns.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["recommend"]) @@ -804,8 +795,9 @@ def test_cloud_fallback_with_key( mlx_stack_home: Path, ) -> None: """Cloud fallback shown when OpenRouter key is set.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] def side_effect(key: str) -> object: if key == "openrouter-key": @@ -816,8 +808,11 @@ def side_effect(key: str) -> object: mock_get_value.side_effect = side_effect # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "Cloud Fallback" in result.output assert "OpenRouter" in result.output @@ -831,11 +826,15 @@ def test_no_cloud_fallback_without_key( mlx_stack_home: Path, ) -> None: """Cloud fallback NOT shown when no OpenRouter key.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "Cloud Fallback" not in result.output @@ -857,14 +856,17 @@ def test_no_stack_files_written( mlx_stack_home: Path, ) -> None: """Running recommend does not create stacks/ or litellm.yaml.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] - + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] stacks_dir = mlx_stack_home / "stacks" litellm_file = mlx_stack_home / "litellm.yaml" + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert not stacks_dir.exists() assert not litellm_file.exists() @@ -878,14 +880,17 @@ def test_no_files_with_show_all( mlx_stack_home: Path, ) -> None: """--show-all also does not write files.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] - + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] stacks_dir = mlx_stack_home / "stacks" litellm_file = mlx_stack_home / "litellm.yaml" + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) + + # Assert assert result.exit_code == 0 assert not stacks_dir.exists() assert not litellm_file.exists() @@ -899,11 +904,15 @@ def test_display_only_message( mlx_stack_home: Path, ) -> None: """Output includes display-only notice and init suggestion.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 assert "no files were written" in result.output.lower() assert "init" in result.output.lower() @@ -926,11 +935,15 @@ def test_zero_models_fitting_budget( mlx_stack_home: Path, ) -> None: """Budget too small for any model produces clear error.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--budget", "0.1gb"]) + + # Assert assert result.exit_code != 0 assert "no models fit" in result.output.lower() @@ -974,11 +987,15 @@ def test_catalog_load_failure( """Catalog load failure shows clear error.""" from mlx_stack.core.catalog import CatalogError - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] mock_load_catalog.side_effect = CatalogError("Corrupt catalog") # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code != 0 assert "Corrupt catalog" in result.output assert "Traceback" not in result.output @@ -1001,12 +1018,16 @@ def test_profile_id_used_for_benchmarks( mlx_stack_home: Path, ) -> None: """The profile's profile_id is used to look up catalog benchmarks.""" - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 # Profile is m4-max-128, benchmark data for this profile should be used assert "Apple M4 Max" in result.output @@ -1029,9 +1050,10 @@ def test_saved_benchmarks_used( mlx_stack_home: Path, ) -> None: """Saved benchmark data overrides catalog data in scoring.""" - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] # Write saved benchmarks benchmarks_dir = mlx_stack_home / "benchmarks" @@ -1045,8 +1067,11 @@ def test_saved_benchmarks_used( } (benchmarks_dir / f"{profile.profile_id}.json").write_text(json.dumps(saved_data)) + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) + + # Assert assert result.exit_code == 0 # The saved benchmark gen_tps (100.0) should be used instead of catalog (77.0) assert "100.0" in result.output @@ -1071,8 +1096,8 @@ def test_config_budget_pct_used( mlx_stack_home: Path, ) -> None: """memory-budget-pct from config is used when no --budget flag.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] def side_effect(key: str) -> object: if key == "memory-budget-pct": @@ -1083,8 +1108,11 @@ def side_effect(key: str) -> object: mock_get_value.side_effect = side_effect # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) + + # Assert assert result.exit_code == 0 # Budget should be 76.8 GB (60% of 128) assert "76.8 GB" in result.output @@ -1132,15 +1160,17 @@ def test_no_profile_written_on_auto_detect( mlx_stack_home: Path, ) -> None: """Auto-detection during recommend must NOT write profile.json.""" + # Arrange mock_load_profile.return_value = None # type: ignore[attr-defined] - mock_detect.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_detect.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend"]) - assert result.exit_code == 0 - # No profile.json should be written + # Assert + assert result.exit_code == 0 profile_path = mlx_stack_home / "profile.json" assert not profile_path.exists() @@ -1153,18 +1183,19 @@ def test_no_files_written_any_flag_combo( mlx_stack_home: Path, ) -> None: """No files created under any recommend flag combination.""" - mock_load_profile.return_value = _make_profile(memory_gb=128) # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] - import os + # Arrange + mock_load_profile.return_value = make_profile(memory_gb=128) # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] + files_before = set() for root, _dirs, files in os.walk(str(mlx_stack_home)): for f in files: files_before.add(os.path.join(root, f)) + # Act runner = CliRunner() - # Test multiple flag combos for flags in [ ["recommend"], ["recommend", "--show-all"], @@ -1174,6 +1205,7 @@ def test_no_files_written_any_flag_combo( result = runner.invoke(cli, flags) assert result.exit_code == 0 + # Assert files_after = set() for root, _dirs, files in os.walk(str(mlx_stack_home)): for f in files: @@ -1200,11 +1232,11 @@ def test_malformed_benchmark_json_warning( mlx_stack_home: Path, ) -> None: """Malformed numeric values in saved benchmarks fall through gracefully.""" - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] - # Write malformed saved benchmarks with non-numeric gen_tps benchmarks_dir = mlx_stack_home / "benchmarks" benchmarks_dir.mkdir(parents=True) saved_data = { @@ -1216,9 +1248,11 @@ def test_malformed_benchmark_json_warning( } (benchmarks_dir / f"{profile.profile_id}.json").write_text(json.dumps(saved_data)) + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) - # Must not crash with ValueError traceback + + # Assert — must not crash with ValueError traceback assert result.exit_code == 0 assert "Traceback" not in result.output # Should still show recommendations from catalog data @@ -1233,16 +1267,19 @@ def test_corrupt_benchmark_file_warning( mlx_stack_home: Path, ) -> None: """Corrupt JSON file in benchmarks produces warning, not traceback.""" - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile # type: ignore[attr-defined] - mock_load_catalog.return_value = _make_test_catalog() # type: ignore[attr-defined] + mock_load_catalog.return_value = _make_recommend_catalog() # type: ignore[attr-defined] - # Write corrupt JSON benchmarks_dir = mlx_stack_home / "benchmarks" benchmarks_dir.mkdir(parents=True) (benchmarks_dir / f"{profile.profile_id}.json").write_text("{{{invalid json") + # Act runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) + + # Assert assert result.exit_code == 0 assert "Traceback" not in result.output diff --git a/tests/unit/test_cli_status.py b/tests/unit/test_cli_status.py index 5ec82a6..ed0c422 100644 --- a/tests/unit/test_cli_status.py +++ b/tests/unit/test_cli_status.py @@ -30,78 +30,7 @@ run_status, status_to_dict, ) - -# --------------------------------------------------------------------------- # -# Fixtures — reusable test data -# --------------------------------------------------------------------------- # - - -def _make_stack_yaml( - tiers: list[dict[str, Any]] | None = None, - schema_version: int = 1, -) -> dict[str, Any]: - """Create a stack definition dict for testing.""" - if tiers is None: - tiers = [ - { - "name": "standard", - "model": "big-model", - "quant": "int4", - "source": "mlx-community/big-model-4bit", - "port": 8000, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - { - "name": "fast", - "model": "fast-model", - "quant": "int4", - "source": "mlx-community/fast-model-4bit", - "port": 8001, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - ] - return { - "schema_version": schema_version, - "name": "default", - "hardware_profile": "m4-max-128", - "intent": "balanced", - "created": "2026-03-24T00:00:00+00:00", - "tiers": tiers, - } - - -def _write_stack_yaml( - mlx_stack_home: Path, - stack: dict[str, Any] | None = None, -) -> Path: - """Write a stack YAML file and return its path.""" - if stack is None: - stack = _make_stack_yaml() - stacks_dir = mlx_stack_home / "stacks" - stacks_dir.mkdir(parents=True, exist_ok=True) - stack_path = stacks_dir / "default.yaml" - stack_path.write_text(yaml.dump(stack, default_flow_style=False)) - return stack_path - - -def _create_pid_file( - mlx_stack_home: Path, - service_name: str, - pid: int | str = 12345, -) -> Path: - """Create a PID file in the pids directory.""" - pids_dir = mlx_stack_home / "pids" - pids_dir.mkdir(parents=True, exist_ok=True) - pid_path = pids_dir / f"{service_name}.pid" - pid_path.write_text(str(pid)) - return pid_path - +from tests.factories import create_pid_file, make_stack_yaml, write_stack_yaml # --------------------------------------------------------------------------- # # Tests — _load_stack_for_status @@ -113,8 +42,13 @@ class TestLoadStackForStatus: def test_loads_valid_stack(self, mlx_stack_home: Path) -> None: """Valid stack definition is loaded successfully.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) + + # Act stack = _load_stack_for_status() + + # Assert assert stack is not None assert len(stack["tiers"]) == 2 @@ -125,28 +59,43 @@ def test_returns_none_when_no_stack(self, mlx_stack_home: Path) -> None: def test_returns_none_on_invalid_yaml(self, mlx_stack_home: Path) -> None: """Returns None when the stack YAML is malformed.""" + # Arrange stacks_dir = mlx_stack_home / "stacks" stacks_dir.mkdir(parents=True, exist_ok=True) (stacks_dir / "default.yaml").write_text("{{{invalid yaml") + + # Act stack = _load_stack_for_status() + + # Assert assert stack is None def test_returns_none_on_non_dict(self, mlx_stack_home: Path) -> None: """Returns None when the stack YAML is not a mapping.""" + # Arrange stacks_dir = mlx_stack_home / "stacks" stacks_dir.mkdir(parents=True, exist_ok=True) (stacks_dir / "default.yaml").write_text("- just a list item") + + # Act stack = _load_stack_for_status() + + # Assert assert stack is None def test_returns_none_on_missing_tiers(self, mlx_stack_home: Path) -> None: """Returns None when tiers are missing.""" + # Arrange stacks_dir = mlx_stack_home / "stacks" stacks_dir.mkdir(parents=True, exist_ok=True) (stacks_dir / "default.yaml").write_text( yaml.dump({"schema_version": 1, "name": "default"}) ) + + # Act stack = _load_stack_for_status() + + # Assert assert stack is None @@ -160,7 +109,10 @@ class TestRunStatus: def test_no_stack_returns_message(self, mlx_stack_home: Path) -> None: """VAL-STATUS-005: No stack configured reports suggestion.""" + # Act result = run_status() + + # Assert assert result.no_stack is True assert result.message is not None assert "init" in result.message.lower() @@ -175,8 +127,8 @@ def test_all_stopped_no_pid_files( mlx_stack_home: Path, ) -> None: """VAL-STATUS-005 / VAL-CROSS-002: No PID files → all stopped.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "stopped", "pid": None, @@ -184,8 +136,10 @@ def test_all_stopped_no_pid_files( "response_time": None, } + # Act result = run_status() + # Assert assert result.no_stack is False assert len(result.services) == 3 # 2 tiers + litellm for svc in result.services: @@ -201,8 +155,8 @@ def test_all_healthy( mlx_stack_home: Path, ) -> None: """VAL-CROSS-002: After up → all healthy.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "healthy", "pid": 12345, @@ -210,8 +164,10 @@ def test_all_healthy( "response_time": 0.05, } + # Act result = run_status() + # Assert assert result.no_stack is False assert len(result.services) == 3 for svc in result.services: @@ -229,7 +185,7 @@ def test_five_distinct_states( mlx_stack_home: Path, ) -> None: """VAL-STATUS-001: Five distinct states correctly classified.""" - # Create a stack with 4 tiers so we can show all 5 states + # Arrange — stack with 4 tiers so we can show all 5 states # (4 tiers + litellm = 5 services) tiers = [ { @@ -265,8 +221,8 @@ def test_five_distinct_states( "vllm_flags": {}, }, ] - stack = _make_stack_yaml(tiers=tiers) - _write_stack_yaml(mlx_stack_home, stack) + stack = make_stack_yaml(tiers=tiers) + write_stack_yaml(mlx_stack_home, stack) status_map = { "tier-healthy": { @@ -314,8 +270,10 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act result = run_status() + # Assert states = {svc.tier: svc.status for svc in result.services} assert states["tier-healthy"] == "healthy" assert states["tier-degraded"] == "degraded" @@ -332,7 +290,8 @@ def test_crashed_service_no_uptime( mlx_stack_home: Path, ) -> None: """VAL-STATUS-006: Stale PIDs detected as 'crashed', uptime is '-'.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -351,8 +310,10 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act result = run_status() + # Assert standard = next(s for s in result.services if s.tier == "standard") assert standard.status == "crashed" assert standard.pid == 99999 @@ -368,7 +329,8 @@ def test_mixed_states( mlx_stack_home: Path, ) -> None: """Mixed states: healthy + stopped services.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -387,8 +349,10 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act result = run_status() + # Assert statuses = {s.tier: s.status for s in result.services} assert statuses["standard"] == "healthy" assert statuses["fast"] == "stopped" @@ -402,8 +366,8 @@ def test_custom_litellm_port( mlx_stack_home: Path, ) -> None: """LiteLLM uses configured port for health checks.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "stopped", "pid": None, @@ -411,8 +375,10 @@ def test_custom_litellm_port( "response_time": None, } + # Act result = run_status() + # Assert litellm = next(s for s in result.services if s.tier == "litellm") assert litellm.port == 5001 @@ -425,8 +391,8 @@ def test_failure_on_one_service_does_not_block_others( mlx_stack_home: Path, ) -> None: """VAL-STATUS-002: Failure on one service doesn't prevent checks on others.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) call_count = 0 def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: @@ -448,9 +414,10 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act result = run_status() - # All 3 services should be checked + # Assert assert call_count == 3 statuses = {s.tier: s.status for s in result.services} assert statuses["standard"] == "down" @@ -476,6 +443,7 @@ def test_no_stack_serialization(self) -> None: def test_services_serialization(self) -> None: """VAL-STATUS-004: All fields present in JSON output.""" + # Arrange result = StatusResult( services=[ ServiceStatus( @@ -501,8 +469,10 @@ def test_services_serialization(self) -> None: ], ) + # Act data = status_to_dict(result) + # Assert assert len(data["services"]) == 2 assert data["no_stack"] is False assert data["message"] is None @@ -525,6 +495,7 @@ def test_services_serialization(self) -> None: def test_valid_json_roundtrip(self) -> None: """JSON output is parseable by json.loads.""" + # Arrange result = StatusResult( services=[ ServiceStatus( @@ -540,10 +511,12 @@ def test_valid_json_roundtrip(self) -> None: ], ) + # Act data = status_to_dict(result) json_str = json.dumps(data) parsed = json.loads(json_str) + # Assert assert parsed == data @@ -564,10 +537,10 @@ def test_no_pid_files_modified( mlx_stack_home: Path, ) -> None: """VAL-STATUS-007: Status does not modify PID files.""" - _write_stack_yaml(mlx_stack_home) - pid_path = _create_pid_file(mlx_stack_home, "standard", 99999) + # Arrange + write_stack_yaml(mlx_stack_home) + pid_path = create_pid_file(mlx_stack_home, "standard", 99999) original_content = pid_path.read_text() - mock_status.return_value = { "status": "crashed", "pid": 99999, @@ -575,9 +548,10 @@ def test_no_pid_files_modified( "response_time": None, } + # Act run_status() - # PID file should still exist with same content + # Assert assert pid_path.exists() assert pid_path.read_text() == original_content @@ -590,21 +564,21 @@ def test_no_lockfile_acquired( mlx_stack_home: Path, ) -> None: """VAL-STATUS-007: Status does not acquire lockfile.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "stopped", "pid": None, "uptime": None, "response_time": None, } - - # Create a lockfile to simulate concurrent operation lock_path = mlx_stack_home / "lock" lock_path.touch() - # Status should succeed regardless of lockfile state + # Act result = run_status() + + # Assert assert result.no_stack is False @patch("mlx_stack.core.stack_status.get_service_status") @@ -616,14 +590,12 @@ def test_no_files_created( mlx_stack_home: Path, ) -> None: """VAL-STATUS-007: Status does not create any new files.""" - _write_stack_yaml(mlx_stack_home) - - # Record files before + # Arrange + write_stack_yaml(mlx_stack_home) before_files = set() for p in mlx_stack_home.rglob("*"): if p.is_file(): before_files.add(str(p.relative_to(mlx_stack_home))) - mock_status.return_value = { "status": "stopped", "pid": None, @@ -631,14 +603,14 @@ def test_no_files_created( "response_time": None, } + # Act run_status() - # Record files after + # Assert after_files = set() for p in mlx_stack_home.rglob("*"): if p.is_file(): after_files.add(str(p.relative_to(mlx_stack_home))) - new_files = after_files - before_files assert new_files == set(), f"Status created files: {new_files}" @@ -668,8 +640,8 @@ def test_table_output_has_columns( mlx_stack_home: Path, ) -> None: """VAL-STATUS-003: Formatted table with expected columns.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "healthy", "pid": 12345, @@ -677,9 +649,11 @@ def test_table_output_has_columns( "response_time": 0.05, } + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 assert "Service Status" in result.output assert "Tier" in result.output @@ -697,8 +671,8 @@ def test_table_shows_tier_data( mlx_stack_home: Path, ) -> None: """VAL-STATUS-003: Table shows per-tier names, models, ports, and distinct statuses.""" - - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -724,9 +698,11 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 # Tier names and models present assert "standard" in result.output @@ -749,7 +725,8 @@ def test_table_shows_uptime( mlx_stack_home: Path, ) -> None: """VAL-STATUS-003: Uptime displayed in human-readable format.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -768,9 +745,11 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 assert "2h 15m" in result.output @@ -783,8 +762,8 @@ def test_stopped_shows_dash_uptime( mlx_stack_home: Path, ) -> None: """VAL-STATUS-003: Stopped services show '-' for uptime.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "stopped", "pid": None, @@ -792,12 +771,12 @@ def test_stopped_shows_dash_uptime( "response_time": None, } + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 - # All services stopped → uptime is "-" - # The table should contain dashes for uptime assert "stopped" in result.output @patch("mlx_stack.core.stack_status.get_service_status") @@ -809,8 +788,8 @@ def test_json_output_valid( mlx_stack_home: Path, ) -> None: """VAL-STATUS-004: --json produces valid parseable JSON.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "healthy", "pid": 12345, @@ -818,9 +797,11 @@ def test_json_output_valid( "response_time": 0.05, } + # Act runner = CliRunner() result = runner.invoke(cli, ["status", "--json"]) + # Assert assert result.exit_code == 0 data = json.loads(result.output) assert "services" in data @@ -836,8 +817,8 @@ def test_json_has_all_fields( mlx_stack_home: Path, ) -> None: """VAL-STATUS-004: JSON contains all required fields per service.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "healthy", "pid": 12345, @@ -845,12 +826,13 @@ def test_json_has_all_fields( "response_time": 0.05, } + # Act runner = CliRunner() result = runner.invoke(cli, ["status", "--json"]) + # Assert data = json.loads(result.output) required_fields = {"tier", "model", "port", "status", "uptime", "uptime_display", "pid"} - for svc in data["services"]: assert required_fields.issubset(svc.keys()), ( f"Missing fields: {required_fields - svc.keys()}" @@ -865,29 +847,23 @@ def test_json_and_table_consistent( mlx_stack_home: Path, ) -> None: """VAL-STATUS-004: JSON and table modes show consistent data.""" - _write_stack_yaml(mlx_stack_home) - - status_data = { + # Arrange + write_stack_yaml(mlx_stack_home) + mock_status.return_value = { "status": "healthy", "pid": 12345, "uptime": 3600.0, "response_time": 0.05, } - mock_status.return_value = status_data - runner = CliRunner() - # JSON output + # Act json_result = runner.invoke(cli, ["status", "--json"]) json_data = json.loads(json_result.output) - - # Table output table_result = runner.invoke(cli, ["status"]) - # Both should have same number of services + # Assert assert len(json_data["services"]) == 3 - - # JSON data should be reflected in the table for svc in json_data["services"]: assert svc["tier"] in table_result.output assert str(svc["port"]) in table_result.output @@ -912,7 +888,8 @@ def test_degraded_state_in_table( mlx_stack_home: Path, ) -> None: """VAL-STATUS-001: Degraded state displayed in table.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -931,9 +908,11 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 assert "degraded" in result.output @@ -946,7 +925,8 @@ def test_crashed_state_in_table( mlx_stack_home: Path, ) -> None: """VAL-STATUS-006: Crashed state displayed for stale PIDs.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -965,9 +945,11 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 assert "crashed" in result.output @@ -980,7 +962,8 @@ def test_down_state_in_table( mlx_stack_home: Path, ) -> None: """VAL-STATUS-001: Down state displayed for unreachable services.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": @@ -999,9 +982,11 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act runner = CliRunner() result = runner.invoke(cli, ["status"]) + # Assert assert result.exit_code == 0 assert "down" in result.output @@ -1048,8 +1033,8 @@ def test_before_up_all_stopped( mlx_stack_home: Path, ) -> None: """VAL-CROSS-002: Before up → all stopped.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "stopped", "pid": None, @@ -1057,7 +1042,10 @@ def test_before_up_all_stopped( "response_time": None, } + # Act result = run_status() + + # Assert assert all(s.status == "stopped" for s in result.services) @patch("mlx_stack.core.stack_status.get_service_status") @@ -1069,8 +1057,8 @@ def test_after_up_all_healthy( mlx_stack_home: Path, ) -> None: """VAL-CROSS-002: After up → all healthy.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "healthy", "pid": 12345, @@ -1078,7 +1066,10 @@ def test_after_up_all_healthy( "response_time": 0.05, } + # Act result = run_status() + + # Assert assert all(s.status == "healthy" for s in result.services) @patch("mlx_stack.core.stack_status.get_service_status") @@ -1090,8 +1081,8 @@ def test_after_down_all_stopped( mlx_stack_home: Path, ) -> None: """VAL-CROSS-002: After down → all stopped.""" - _write_stack_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) mock_status.return_value = { "status": "stopped", "pid": None, @@ -1099,7 +1090,10 @@ def test_after_down_all_stopped( "response_time": None, } + # Act result = run_status() + + # Assert assert all(s.status == "stopped" for s in result.services) @patch("mlx_stack.core.stack_status.get_service_status") @@ -1111,11 +1105,11 @@ def test_after_external_kill_crashed( mlx_stack_home: Path, ) -> None: """VAL-CROSS-002: After external kill → crashed.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: if service_name == "standard": - # Externally killed → PID file exists, process dead return { "status": "crashed", "pid": 99999, @@ -1131,8 +1125,10 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act result = run_status() + # Assert statuses = {s.tier: s.status for s in result.services} assert statuses["standard"] == "crashed" assert statuses["fast"] == "healthy" @@ -1148,11 +1144,15 @@ class TestEdgeCases: def test_empty_tiers_in_stack(self, mlx_stack_home: Path) -> None: """Stack with empty tiers list shows no-stack message.""" - stack = _make_stack_yaml() + # Arrange + stack = make_stack_yaml() stack["tiers"] = [] - _write_stack_yaml(mlx_stack_home, stack) + write_stack_yaml(mlx_stack_home, stack) + # Act result = run_status() + + # Assert assert result.no_stack is True @patch("mlx_stack.core.stack_status.get_service_status") @@ -1164,6 +1164,7 @@ def test_single_tier_stack( mlx_stack_home: Path, ) -> None: """Stack with a single tier works correctly.""" + # Arrange tiers = [ { "name": "fast", @@ -1174,9 +1175,8 @@ def test_single_tier_stack( "vllm_flags": {}, }, ] - stack = _make_stack_yaml(tiers=tiers) - _write_stack_yaml(mlx_stack_home, stack) - + stack = make_stack_yaml(tiers=tiers) + write_stack_yaml(mlx_stack_home, stack) mock_status.return_value = { "status": "healthy", "pid": 1001, @@ -1184,7 +1184,10 @@ def test_single_tier_stack( "response_time": 0.1, } + # Act result = run_status() + + # Assert assert len(result.services) == 2 # 1 tier + litellm assert result.services[0].tier == "fast" assert result.services[1].tier == "litellm" @@ -1198,6 +1201,7 @@ def test_three_tier_stack( mlx_stack_home: Path, ) -> None: """Stack with 3 tiers reports all + litellm.""" + # Arrange tiers = [ { "name": "standard", @@ -1224,9 +1228,8 @@ def test_three_tier_stack( "vllm_flags": {}, }, ] - stack = _make_stack_yaml(tiers=tiers) - _write_stack_yaml(mlx_stack_home, stack) - + stack = make_stack_yaml(tiers=tiers) + write_stack_yaml(mlx_stack_home, stack) mock_status.return_value = { "status": "stopped", "pid": None, @@ -1234,7 +1237,10 @@ def test_three_tier_stack( "response_time": None, } + # Act result = run_status() + + # Assert assert len(result.services) == 4 # 3 tiers + litellm @patch("mlx_stack.core.stack_status.get_service_status") @@ -1246,6 +1252,7 @@ def test_json_output_all_five_states( mlx_stack_home: Path, ) -> None: """VAL-STATUS-001/004: All five states present in JSON.""" + # Arrange tiers = [ { "name": f"t{i}", @@ -1257,9 +1264,8 @@ def test_json_output_all_five_states( } for i in range(4) ] - stack = _make_stack_yaml(tiers=tiers) - _write_stack_yaml(mlx_stack_home, stack) - + stack = make_stack_yaml(tiers=tiers) + write_stack_yaml(mlx_stack_home, stack) status_list = ["healthy", "degraded", "down", "crashed", "stopped"] def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: @@ -1274,9 +1280,11 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str mock_status.side_effect = side_effect + # Act runner = CliRunner() result = runner.invoke(cli, ["status", "--json"]) + # Assert assert result.exit_code == 0 data = json.loads(result.output) statuses = {s["tier"]: s["status"] for s in data["services"]} diff --git a/tests/unit/test_cli_up.py b/tests/unit/test_cli_up.py index 71e3992..434ae1e 100644 --- a/tests/unit/test_cli_up.py +++ b/tests/unit/test_cli_up.py @@ -1,4 +1,4 @@ -"""Tests for the `mlx-stack up` CLI command and core stack_up module. +"""Tests for the ``mlx-stack up`` CLI command and core ``stack_up`` module. Validates: - VAL-UP-001: Stack definition loaded and correct processes started @@ -27,28 +27,13 @@ from __future__ import annotations from pathlib import Path -from typing import Any from unittest.mock import MagicMock, patch import pytest -import yaml from click.testing import CliRunner from mlx_stack.cli.main import cli -from mlx_stack.core.catalog import ( - BenchmarkResult, - Capabilities, - CatalogEntry, - QualityScores, - QuantSource, -) -from mlx_stack.core.deps import DependencyInstallError -from mlx_stack.core.process import ( - HealthCheckError, - HealthCheckResult, - LockError, - ServiceInfo, -) +from mlx_stack.core.process import LockError from mlx_stack.core.stack_up import ( LITELLM_SERVICE_NAME, TierStatus, @@ -63,181 +48,87 @@ run_up, sort_tiers_by_size, ) +from tests.factories import ( + create_pid_file, + make_stack_yaml, + make_test_catalog, + write_litellm_yaml, + write_stack_yaml, +) +from tests.fakes import FakeServiceLayer # --------------------------------------------------------------------------- # -# Fixtures — reusable test data -# --------------------------------------------------------------------------- # - - -def _make_stack_yaml( - tiers: list[dict[str, Any]] | None = None, - schema_version: int = 1, - litellm_port: int = 4000, -) -> dict[str, Any]: - """Create a stack definition dict for testing.""" - if tiers is None: - tiers = [ - { - "name": "standard", - "model": "big-model", - "quant": "int4", - "source": "mlx-community/big-model-4bit", - "port": 8000, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - "enable_auto_tool_choice": True, - "tool_call_parser": "hermes", - }, - }, - { - "name": "fast", - "model": "fast-model", - "quant": "int4", - "source": "mlx-community/fast-model-4bit", - "port": 8001, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - ] - return { - "schema_version": schema_version, - "name": "default", - "hardware_profile": "m4-max-128", - "intent": "balanced", - "created": "2026-03-24T00:00:00+00:00", - "tiers": tiers, - } - - -def _write_stack_yaml( - mlx_stack_home: Path, - stack: dict[str, Any] | None = None, -) -> Path: - """Write a stack YAML file and return its path.""" - if stack is None: - stack = _make_stack_yaml() - stacks_dir = mlx_stack_home / "stacks" - stacks_dir.mkdir(parents=True, exist_ok=True) - stack_path = stacks_dir / "default.yaml" - stack_path.write_text(yaml.dump(stack, default_flow_style=False)) - return stack_path - - -def _write_litellm_yaml(mlx_stack_home: Path) -> Path: - """Write a minimal litellm.yaml config.""" - litellm_config = { - "model_list": [ - { - "model_name": "standard", - "litellm_params": { - "model": "openai/big-model", - "api_base": "http://localhost:8000/v1", - "api_key": "dummy", - }, - }, - ], - } - litellm_path = mlx_stack_home / "litellm.yaml" - litellm_path.write_text(yaml.dump(litellm_config, default_flow_style=False)) - return litellm_path - - -def _make_entry( - model_id: str = "test-model", - params_b: float = 8.0, - memory_gb: float = 5.5, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - return CatalogEntry( - id=model_id, - name=f"Test {model_id}", - family="Test", - params_b=params_b, - architecture="transformer", - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource(hf_repo=f"mlx-community/{model_id}-4bit", disk_size_gb=4.5), - }, - capabilities=Capabilities( - tool_calling=True, - tool_call_parser="hermes", - thinking=False, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores(overall=70, coding=65, reasoning=60, instruction_following=72), - benchmarks={ - "m4-max-128": BenchmarkResult(prompt_tps=100.0, gen_tps=50.0, memory_gb=memory_gb), - }, - tags=[], - ) - - -def _make_test_catalog() -> list[CatalogEntry]: - """Create a test catalog.""" - return [ - _make_entry("big-model", params_b=49.0, memory_gb=30.0), - _make_entry("fast-model", params_b=3.0, memory_gb=2.0), - ] - - -# --------------------------------------------------------------------------- # -# Tests — load_stack_definition +# Tests — load_stack_definition (real YAML, no mocks) # --------------------------------------------------------------------------- # class TestLoadStackDefinition: - """Tests for load_stack_definition.""" + """Tests for load_stack_definition — real filesystem operations.""" def test_loads_valid_stack(self, mlx_stack_home: Path) -> None: """VAL-UP-001: Stack definition loaded.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) + + # Act stack = load_stack_definition() + + # Assert assert stack["schema_version"] == 1 assert len(stack["tiers"]) == 2 def test_missing_stack_suggests_init(self, mlx_stack_home: Path) -> None: """VAL-UP-011: Missing stack definition error.""" + # Act / Assert with pytest.raises(UpError, match="mlx-stack init"): load_stack_definition() def test_invalid_yaml_produces_clear_error(self, mlx_stack_home: Path) -> None: """VAL-UP-011: Invalid YAML produces clear error.""" + # Arrange stacks_dir = mlx_stack_home / "stacks" stacks_dir.mkdir(parents=True, exist_ok=True) (stacks_dir / "default.yaml").write_text("{{{invalid yaml") + + # Act / Assert with pytest.raises(UpError, match="Invalid YAML"): load_stack_definition() def test_unsupported_schema_version(self, mlx_stack_home: Path) -> None: """VAL-UP-011 / VAL-CROSS-013: Unsupported schema_version.""" - stack = _make_stack_yaml(schema_version=99) - _write_stack_yaml(mlx_stack_home, stack) + # Arrange + stack = make_stack_yaml(schema_version=99) + write_stack_yaml(mlx_stack_home, stack) + + # Act / Assert with pytest.raises(UpError, match="schema_version"): load_stack_definition() def test_empty_tiers_error(self, mlx_stack_home: Path) -> None: """Stack with no tiers raises error.""" - stack = _make_stack_yaml() + # Arrange + stack = make_stack_yaml() stack["tiers"] = [] - _write_stack_yaml(mlx_stack_home, stack) + write_stack_yaml(mlx_stack_home, stack) + + # Act / Assert with pytest.raises(UpError, match="no tiers"): load_stack_definition() def test_non_dict_file(self, mlx_stack_home: Path) -> None: """Non-mapping YAML file produces clear error.""" + # Arrange stacks_dir = mlx_stack_home / "stacks" stacks_dir.mkdir(parents=True, exist_ok=True) (stacks_dir / "default.yaml").write_text("- just a list") + + # Act / Assert with pytest.raises(UpError, match="invalid format"): load_stack_definition() # --------------------------------------------------------------------------- # -# Tests — build_vllm_command +# Tests — build_vllm_command (pure function) # --------------------------------------------------------------------------- # @@ -246,6 +137,7 @@ class TestBuildVllmCommand: def test_basic_command(self) -> None: """VAL-UP-015: Services bind to localhost only.""" + # Arrange tier = { "name": "fast", "model": "fast-model", @@ -256,7 +148,11 @@ def test_basic_command(self) -> None: "use_paged_cache": True, }, } + + # Act cmd = build_vllm_command(tier, "/usr/local/bin/vllm-mlx") + + # Assert assert cmd[0] == "/usr/local/bin/vllm-mlx" assert cmd[1] == "serve" assert "mlx-community/fast-model-4bit" in cmd @@ -269,6 +165,7 @@ def test_basic_command(self) -> None: def test_serve_subcommand_with_model_positional(self) -> None: """vllm-mlx uses 'serve' subcommand with model as positional arg.""" + # Arrange tier = { "name": "fast", "model": "test-model", @@ -276,16 +173,19 @@ def test_serve_subcommand_with_model_positional(self) -> None: "port": 8001, "vllm_flags": {}, } + + # Act cmd = build_vllm_command(tier, "vllm-mlx") - # Command format: vllm-mlx serve --port N --host 127.0.0.1 + + # Assert assert cmd[0] == "vllm-mlx" assert cmd[1] == "serve" assert cmd[2] == "mlx-community/test-model-4bit" - # --model flag should NOT be present assert "--model" not in cmd def test_tool_calling_flags(self) -> None: """VAL-CROSS-013: vllm_flags translate correctly to CLI flags.""" + # Arrange tier = { "name": "standard", "model": "tool-model", @@ -298,7 +198,11 @@ def test_tool_calling_flags(self) -> None: "tool_call_parser": "hermes", }, } + + # Act cmd = build_vllm_command(tier, "vllm-mlx") + + # Assert assert "--enable-auto-tool-choice" in cmd assert "--tool-call-parser" in cmd idx = cmd.index("--tool-call-parser") @@ -306,6 +210,7 @@ def test_tool_calling_flags(self) -> None: def test_boolean_false_flags_excluded(self) -> None: """Boolean False flags are not included in command.""" + # Arrange tier = { "name": "test", "model": "test", @@ -316,13 +221,17 @@ def test_boolean_false_flags_excluded(self) -> None: "some_disabled_flag": False, }, } + + # Act cmd = build_vllm_command(tier, "vllm-mlx") + + # Assert assert "--some-disabled-flag" not in cmd assert "--continuous-batching" in cmd # --------------------------------------------------------------------------- # -# Tests — build_litellm_command +# Tests — build_litellm_command (pure function) # --------------------------------------------------------------------------- # @@ -331,11 +240,14 @@ class TestBuildLitellmCommand: def test_basic_command(self) -> None: """VAL-UP-015: LiteLLM binds to localhost only.""" + # Act cmd = build_litellm_command( "/usr/local/bin/litellm", 4000, Path("/home/user/.mlx-stack/litellm.yaml"), ) + + # Assert assert cmd[0] == "/usr/local/bin/litellm" assert "--config" in cmd assert "/home/user/.mlx-stack/litellm.yaml" in cmd @@ -346,7 +258,7 @@ def test_basic_command(self) -> None: # --------------------------------------------------------------------------- # -# Tests — format_dry_run_command +# Tests — format_dry_run_command (pure function) # --------------------------------------------------------------------------- # @@ -355,20 +267,30 @@ class TestFormatDryRunCommand: def test_basic_format(self) -> None: """Commands formatted as space-separated string.""" + # Arrange cmd = ["vllm-mlx", "--model", "test", "--port", "8000"] + + # Act result = format_dry_run_command(cmd) + + # Assert assert result == "vllm-mlx --model test --port 8000" def test_env_vars_masked(self) -> None: """VAL-UP-017: API key not visible in --dry-run output.""" + # Arrange cmd = ["litellm", "--config", "litellm.yaml"] + + # Act result = format_dry_run_command(cmd, {"OPENROUTER_API_KEY": "sk-secret"}) + + # Assert assert "sk-secret" not in result assert "OPENROUTER_API_KEY=***" in result # --------------------------------------------------------------------------- # -# Tests — sort_tiers_by_size +# Tests — sort_tiers_by_size (pure function) # --------------------------------------------------------------------------- # @@ -377,22 +299,32 @@ class TestSortTiersBySize: def test_largest_first(self) -> None: """VAL-UP-001: Tiers started in descending params_b order.""" - catalog = _make_test_catalog() + # Arrange + catalog = make_test_catalog() tiers = [ {"name": "fast", "model": "fast-model", "port": 8001}, {"name": "standard", "model": "big-model", "port": 8000}, ] + + # Act sorted_tiers = sort_tiers_by_size(tiers, catalog) + + # Assert assert sorted_tiers[0]["name"] == "standard" # 49B first assert sorted_tiers[1]["name"] == "fast" # 3B second def test_no_catalog_preserves_order(self) -> None: """Without catalog, original order is preserved.""" + # Arrange tiers = [ {"name": "fast", "model": "fast-model", "port": 8001}, {"name": "standard", "model": "big-model", "port": 8000}, ] + + # Act sorted_tiers = sort_tiers_by_size(tiers, None) + + # Assert assert sorted_tiers[0]["name"] == "fast" @@ -406,36 +338,287 @@ class TestMemoryEstimation: def test_estimates_from_catalog(self) -> None: """VAL-UP-016: Memory estimate from catalog data.""" - catalog = _make_test_catalog() + # Arrange + catalog = make_test_catalog() tiers = [ {"name": "standard", "model": "big-model"}, {"name": "fast", "model": "fast-model"}, ] + + # Act total = estimate_memory_usage(tiers, catalog) + + # Assert assert total == pytest.approx(32.0, abs=1.0) def test_unknown_model_skipped(self) -> None: """Unknown models contribute 0 to estimate.""" - catalog = _make_test_catalog() + # Arrange + catalog = make_test_catalog() tiers = [{"name": "unknown", "model": "nonexistent"}] + + # Act total = estimate_memory_usage(tiers, catalog) + + # Assert assert total == 0.0 def test_warning_when_exceeds_available(self) -> None: """VAL-UP-016: Warning when estimate exceeds available memory.""" + # Arrange with patch("mlx_stack.core.stack_up.psutil.virtual_memory") as mock_vmem: - mock_vmem.return_value = MagicMock(available=10 * 1024**3) # 10 GB + mock_vmem.return_value = MagicMock(available=10 * 1024**3) + + # Act warning = check_memory_warning(20.0) - assert warning is not None - assert "20.0 GB" in warning - assert "10.0 GB" in warning + + # Assert + assert warning is not None + assert "20.0 GB" in warning + assert "10.0 GB" in warning def test_no_warning_when_fits(self) -> None: """No warning when estimate fits in available memory.""" + # Arrange with patch("mlx_stack.core.stack_up.psutil.virtual_memory") as mock_vmem: - mock_vmem.return_value = MagicMock(available=100 * 1024**3) # 100 GB + mock_vmem.return_value = MagicMock(available=100 * 1024**3) + + # Act warning = check_memory_warning(20.0) - assert warning is None + + # Assert + assert warning is None + + +# --------------------------------------------------------------------------- # +# Tests — run_up behavioral (FakeServiceLayer, no @patch stacks) +# --------------------------------------------------------------------------- # + + +class TestRunUp: + """Behavioral tests for ``run_up`` using ``FakeServiceLayer``. + + Each test writes real YAML to the isolated ``mlx_stack_home`` and + configures only the specific failure it tests. The fake defaults + produce a fully successful startup. + """ + + def test_successful_startup( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-001/004/005: Successful startup with PID files and LiteLLM.""" + # Arrange — defaults: all services start and pass health check + + # Act + result = run_up() + + # Assert + assert len(result.tiers) == 2 + assert all(t.status == "healthy" for t in result.tiers) + assert result.litellm is not None + assert result.litellm.status == "healthy" + assert len(fake_services.started) == 3 # 2 tiers + litellm + + def test_tier_filter_starts_only_one( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-009: --tier starts only the specified tier.""" + # Arrange — no special config needed + + # Act + result = run_up(tier_filter="fast") + + # Assert + tier_names = [t.name for t in result.tiers] + assert "fast" in tier_names + assert "standard" not in tier_names + + def test_port_conflict_skips_tier( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-012: Port conflict skips the affected tier.""" + # Arrange + fake_services.fail_port(8000, pid=54321, name="node") + + # Act + result = run_up() + + # Assert + skipped = [t for t in result.tiers if t.status == "skipped"] + assert len(skipped) == 1 + assert skipped[0].name == "standard" + assert "54321" in skipped[0].error + assert "node" in skipped[0].error + assert "8000" in skipped[0].error + + def test_port_conflict_unknown_owner( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-012: Port conflict with unknown owner still shows port.""" + # Arrange + fake_services.fail_port(8000, pid=0, name="") + + # Act + result = run_up() + + # Assert + skipped = [t for t in result.tiers if t.status == "skipped"] + assert len(skipped) == 1 + assert "8000" in (skipped[0].error or "") + assert "already in use" in (skipped[0].error or "") + + def test_health_check_timeout_continues( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-003/013: Health check timeout on one tier, other still starts.""" + # Arrange — fail health for standard (port 8000), fast (8001) succeeds + fake_services.fail_health(8000) + + # Act + result = run_up() + + # Assert + statuses = {t.name: t.status for t in result.tiers} + assert statuses["standard"] == "failed" + assert statuses["fast"] == "healthy" + + def test_all_fail_no_litellm( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-005: LiteLLM not started if all model servers fail.""" + # Arrange — all ports occupied + fake_services.fail_port(8000, pid=99, name="blocker") + fake_services.fail_port(8001, pid=99, name="blocker") + + # Act + result = run_up() + + # Assert + assert result.litellm is not None + assert result.litellm.status == "skipped" + assert "All model servers failed" in (result.litellm.error or "") + + def test_already_running_detection( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-014: Already-running services reported without restart.""" + # Arrange — create PID files and mark processes as alive + home = stack_on_disk + create_pid_file(home, "standard", 12345) + create_pid_file(home, "fast", 12346) + create_pid_file(home, "litellm", 12347) + fake_services.set_alive(12345, True) + fake_services.set_alive(12346, True) + fake_services.set_alive(12347, True) + + # Act + result = run_up() + + # Assert + assert result.already_running is True + assert all(t.status == "already-running" for t in result.tiers) + + def test_stale_pid_cleanup_and_restart( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-014 / VAL-CROSS-010: Stale PID cleaned up, service restarted.""" + # Arrange — PID files exist but processes are dead (default: not alive) + home = stack_on_disk + create_pid_file(home, "standard", 99999) + create_pid_file(home, "fast", 99998) + + # Act + result = run_up() + + # Assert + assert any("stale" in w.lower() for w in result.warnings) + assert any(t.status == "healthy" for t in result.tiers) + + def test_lockfile_prevents_concurrent( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-007: Lockfile prevents concurrent invocations.""" + # Arrange + fake_services.hold_lock() + + # Act / Assert + with pytest.raises(LockError, match="Lock held"): + run_up() + + def test_auto_install_failure( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-010: Auto-install failure produces clear error.""" + # Arrange + fake_services.fail_dependency("vllm-mlx") + + # Act / Assert + with pytest.raises(UpError, match="Dependency installation failed"): + run_up() + + def test_missing_model_skips_tier( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """Model not found on disk skips the tier.""" + # Arrange + fake_services.fail_model_check("standard", "Model not found on disk") + + # Act + result = run_up() + + # Assert + skipped = [t for t in result.tiers if t.status == "skipped"] + assert len(skipped) == 1 + assert skipped[0].name == "standard" + assert "not found" in skipped[0].error.lower() + # fast should still start + healthy = [t for t in result.tiers if t.status == "healthy"] + assert len(healthy) == 1 + assert healthy[0].name == "fast" + + def test_api_key_passed_via_env( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-017: API key passed via env var, not CLI args.""" + # Arrange + fake_services.set_config("openrouter-key", "sk-or-secret-key") + + # Act + result = run_up() + + # Assert — LiteLLM was started (appears in started list) + assert LITELLM_SERVICE_NAME in fake_services.started + assert result.litellm is not None + assert result.litellm.status == "healthy" + + def test_memory_warning_displayed( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-016: Memory estimate warning before startup.""" + # Arrange — fake low available memory + with patch("mlx_stack.core.stack_up.psutil.virtual_memory") as mock_vmem: + mock_vmem.return_value = MagicMock(available=5 * 1024**3) # 5 GB + + # Act + result = run_up() + + # Assert + assert any("memory" in w.lower() for w in result.warnings) + + def test_start_service_failure_continues( + self, stack_on_disk: Path, fake_services: FakeServiceLayer, + ) -> None: + """VAL-UP-013: Start failure on one tier, other still starts.""" + # Arrange + fake_services.fail_start("standard") + + # Act + result = run_up() + + # Assert + statuses = {t.name: t.status for t in result.tiers} + assert statuses["standard"] == "failed" + assert statuses["fast"] == "healthy" # --------------------------------------------------------------------------- # @@ -455,16 +638,19 @@ def test_dry_run_shows_commands( mlx_stack_home: Path, ) -> None: """VAL-UP-008: --dry-run shows commands without executing.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--dry-run"]) + + # Assert assert result.exit_code == 0 assert "Dry run" in result.output @@ -477,16 +663,19 @@ def test_dry_run_no_pid_files( mlx_stack_home: Path, ) -> None: """VAL-UP-008: No PID files after --dry-run.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() runner.invoke(cli, ["up", "--dry-run"]) + + # Assert pids_dir = mlx_stack_home / "pids" if pids_dir.exists(): assert list(pids_dir.glob("*.pid")) == [] @@ -500,16 +689,19 @@ def test_dry_run_no_log_files( mlx_stack_home: Path, ) -> None: """VAL-UP-008: No log files after --dry-run.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() runner.invoke(cli, ["up", "--dry-run"]) + + # Assert logs_dir = mlx_stack_home / "logs" if logs_dir.exists(): assert list(logs_dir.glob("*.log")) == [] @@ -523,16 +715,19 @@ def test_dry_run_shows_host_127( mlx_stack_home: Path, ) -> None: """VAL-UP-015: --dry-run confirms localhost binding.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--dry-run"]) + + # Assert assert "127.0.0.1" in result.output @patch("mlx_stack.core.stack_up.load_catalog") @@ -544,16 +739,19 @@ def test_dry_run_masks_api_key( mlx_stack_home: Path, ) -> None: """VAL-UP-017: API key not visible in --dry-run.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "sk-or-secret-key-12345", + "litellm-port": 4000, "openrouter-key": "sk-or-secret-key-12345", }.get(key, "") + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--dry-run"]) + + # Assert assert "sk-or-secret-key-12345" not in result.output assert "OPENROUTER_API_KEY=***" in result.output @@ -566,18 +764,20 @@ def test_dry_run_tier_filter( mlx_stack_home: Path, ) -> None: """VAL-UP-009: --tier with --dry-run shows only that tier.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--dry-run", "--tier", "fast"]) + + # Assert assert result.exit_code == 0 - # Should show fast tier but not standard assert "fast" in result.output @patch("mlx_stack.core.stack_up.load_catalog") @@ -589,17 +789,19 @@ def test_dry_run_vllm_flags_in_commands( mlx_stack_home: Path, ) -> None: """VAL-CROSS-013: vllm_flags translate correctly to dry-run CLI flags.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--dry-run"]) - # The standard tier has enable_auto_tool_choice and tool_call_parser + + # Assert assert "--enable-auto-tool-choice" in result.output assert "--tool-call-parser" in result.output assert "hermes" in result.output @@ -615,724 +817,64 @@ class TestUpErrors: def test_missing_stack_error(self, mlx_stack_home: Path) -> None: """VAL-UP-011: Missing stack definition suggests init.""" + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code != 0 - output_text = result.output.lower() - assert "init" in output_text + assert "init" in result.output.lower() @patch("mlx_stack.core.stack_up.get_value") def test_invalid_tier_error( - self, - mock_get_value: MagicMock, - mlx_stack_home: Path, + self, mock_get_value: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-009: Invalid tier name errors with valid list.""" - _write_stack_yaml(mlx_stack_home) + # Arrange + write_stack_yaml(mlx_stack_home) mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--tier", "nonexistent"]) + + # Assert assert result.exit_code != 0 - # Should mention valid tiers assert "standard" in result.output or "fast" in result.output def test_invalid_yaml_error(self, mlx_stack_home: Path) -> None: """VAL-UP-011: Invalid YAML produces clear error.""" + # Arrange stacks_dir = mlx_stack_home / "stacks" stacks_dir.mkdir(parents=True, exist_ok=True) (stacks_dir / "default.yaml").write_text("{{{bad yaml") + + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code != 0 def test_unsupported_schema_error(self, mlx_stack_home: Path) -> None: """VAL-UP-011 / VAL-CROSS-013: Unsupported schema_version.""" - stack = _make_stack_yaml(schema_version=999) - _write_stack_yaml(mlx_stack_home, stack) + # Arrange + stack = make_stack_yaml(schema_version=999) + write_stack_yaml(mlx_stack_home, stack) + + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code != 0 assert "schema_version" in result.output # --------------------------------------------------------------------------- # -# Tests — run_up with mocked subprocess layer -# --------------------------------------------------------------------------- # - - -class TestRunUp: - """Tests for run_up with mocked process management.""" - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_successful_startup( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-001/004/005: Successful startup with PID files and LiteLLM.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start_service.return_value = ServiceInfo( - name="test", - pid=12345, - port=8000, - log_path=Path("/tmp/test.log"), - pid_path=Path("/tmp/test.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - result = run_up() - assert len(result.tiers) == 2 - assert all(t.status == "healthy" for t in result.tiers) - assert result.litellm is not None - assert result.litellm.status == "healthy" - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_tier_filter_starts_only_one( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-009: --tier starts only the specified tier.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start_service.return_value = ServiceInfo( - name="fast", - pid=12345, - port=8001, - log_path=Path("/tmp/fast.log"), - pid_path=Path("/tmp/fast.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - result = run_up(tier_filter="fast") - # Should have one tier (fast) + litellm - tier_names = [t.name for t in result.tiers] - assert "fast" in tier_names - assert "standard" not in tier_names - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict") - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_port_conflict_skips_tier( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-012/013: Port conflict skips tier, remaining still start.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - - # First tier has port conflict, second is fine - def port_conflict_side_effect(port: int) -> tuple[int, str] | None: - if port == 8000: - return (99999, "other-process") - return None - - mock_port_conflict.side_effect = port_conflict_side_effect - mock_start_service.return_value = ServiceInfo( - name="fast", - pid=12345, - port=8001, - log_path=Path("/tmp/fast.log"), - pid_path=Path("/tmp/fast.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - result = run_up() - statuses = {t.name: t.status for t in result.tiers} - # Standard (port 8000) should be skipped - assert statuses["standard"] == "skipped" - # Fast should be healthy - assert statuses["fast"] == "healthy" - # LiteLLM should still start (at least one healthy tier) - assert result.litellm is not None - assert result.litellm.status == "healthy" - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict") - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_port_conflict_error_shows_pid_and_process( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-012: Port conflict error includes PID and process name.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - - # Port 8000 occupied by a known process - def port_conflict_side_effect(port: int) -> tuple[int, str] | None: - if port == 8000: - return (54321, "node") - return None - - mock_port_conflict.side_effect = port_conflict_side_effect - mock_start_service.return_value = ServiceInfo( - name="fast", - pid=12345, - port=8001, - log_path=Path("/tmp/fast.log"), - pid_path=Path("/tmp/fast.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - result = run_up() - # Find the skipped tier - skipped = [t for t in result.tiers if t.status == "skipped"] - assert len(skipped) == 1 - assert skipped[0].name == "standard" - # Error message must include the conflicting PID and process name - assert skipped[0].error is not None - assert "54321" in skipped[0].error - assert "node" in skipped[0].error - assert "8000" in skipped[0].error - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict") - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_port_conflict_unknown_owner( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-012: Port conflict with unknown owner still shows port.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - - # Port occupied but owner unknown (e.g., macOS permission issue) - mock_port_conflict.side_effect = lambda port: (0, "") if port == 8000 else None - mock_start_service.return_value = ServiceInfo( - name="fast", - pid=12345, - port=8001, - log_path=Path("/tmp/fast.log"), - pid_path=Path("/tmp/fast.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - result = run_up() - skipped = [t for t in result.tiers if t.status == "skipped"] - assert len(skipped) == 1 - assert "8000" in (skipped[0].error or "") - assert "already in use" in (skipped[0].error or "") - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_health_check_timeout_continues( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-003/013: Health check timeout reports error, continues.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start_service.return_value = ServiceInfo( - name="test", - pid=12345, - port=8000, - log_path=Path("/tmp/test.log"), - pid_path=Path("/tmp/test.pid"), - ) - - # First tier health check fails, second succeeds, LiteLLM succeeds - call_count = 0 - - def wait_healthy_side_effect(**kwargs: Any) -> HealthCheckResult: - nonlocal call_count - call_count += 1 - if call_count == 1: - raise HealthCheckError("Timeout after 120s") - return HealthCheckResult(healthy=True, response_time=0.5, status_code=200) - - mock_wait_healthy.side_effect = wait_healthy_side_effect - - result = run_up() - statuses = {t.name: t.status for t in result.tiers} - # One tier should fail, one should be healthy - assert "failed" in statuses.values() - assert "healthy" in statuses.values() - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_all_fail_no_litellm( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-005: LiteLLM not started if all model servers fail.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - - # All ports conflict - mock_port_conflict.return_value = (99, "blocker") - - result = run_up() - assert result.litellm is not None - assert result.litellm.status == "skipped" - assert "All model servers failed" in (result.litellm.error or "") - - @patch("mlx_stack.core.stack_up.is_process_alive", return_value=True) - @patch("mlx_stack.core.stack_up.read_pid_file") - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_already_running_detection( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_alive: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-014: Already-running detection.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_read_pid.return_value = 12345 # All services have PIDs - - result = run_up() - assert result.already_running is True - assert all(t.status == "already-running" for t in result.tiers) - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.cleanup_stale_pid", return_value=True) - @patch("mlx_stack.core.stack_up.is_process_alive", return_value=False) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=12345) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_stale_pid_cleanup_and_restart( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_alive: MagicMock, - mock_cleanup_stale: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-014 / VAL-CROSS-010: Stale PID cleanup and fresh start.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start_service.return_value = ServiceInfo( - name="test", - pid=99999, - port=8000, - log_path=Path("/tmp/test.log"), - pid_path=Path("/tmp/test.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - result = run_up() - # Should have cleaned stale PIDs warning - assert any("stale" in w.lower() for w in result.warnings) - # Services should be started fresh - assert any(t.status == "healthy" for t in result.tiers) - - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - def test_lockfile_prevents_concurrent( - self, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_lock: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-007: Lockfile prevents concurrent invocations.""" - _write_stack_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_lock.side_effect = LockError("Another operation running") - - with pytest.raises(LockError, match="Another operation"): - run_up() - - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - def test_auto_install_failure( - self, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-010: Auto-install failure produces clear error.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_ensure_dep.side_effect = DependencyInstallError("Install failed") - - with pytest.raises(UpError, match="Dependency installation failed"): - run_up() - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_api_key_passed_via_env( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-017: API key passed via env var, not CLI args.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "sk-or-secret-key", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start_service.return_value = ServiceInfo( - name="test", - pid=12345, - port=8000, - log_path=Path("/tmp/test.log"), - pid_path=Path("/tmp/test.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - run_up() - - # Check that start_service was called for litellm with env dict - litellm_calls = [ - c - for c in mock_start_service.call_args_list - if c.kwargs.get("service_name") == LITELLM_SERVICE_NAME - or (c.args and c.args[0] == LITELLM_SERVICE_NAME) - ] - assert len(litellm_calls) >= 1 - # The LiteLLM call should have env with OPENROUTER_API_KEY - litellm_call = litellm_calls[0] - env_arg = litellm_call.kwargs.get("env") or ( - litellm_call.args[3] if len(litellm_call.args) > 3 else None - ) - assert env_arg is not None - assert "OPENROUTER_API_KEY" in env_arg - assert env_arg["OPENROUTER_API_KEY"] == "sk-or-secret-key" - - @patch("mlx_stack.core.stack_up.check_local_model_exists", return_value=None) - @patch("mlx_stack.core.stack_up.start_service") - @patch("mlx_stack.core.stack_up.wait_for_healthy") - @patch("mlx_stack.core.stack_up.check_port_conflict", return_value=None) - @patch("mlx_stack.core.stack_up.read_pid_file", return_value=None) - @patch("mlx_stack.core.stack_up.acquire_lock") - @patch("mlx_stack.core.stack_up.ensure_dependency") - @patch("mlx_stack.core.stack_up.load_catalog") - @patch("mlx_stack.core.stack_up.get_value") - @patch("mlx_stack.core.stack_up.shutil.which") - def test_memory_warning_displayed( - self, - mock_which: MagicMock, - mock_get_value: MagicMock, - mock_load_catalog: MagicMock, - mock_ensure_dep: MagicMock, - mock_lock: MagicMock, - mock_read_pid: MagicMock, - mock_port_conflict: MagicMock, - mock_wait_healthy: MagicMock, - mock_start_service: MagicMock, - mock_model_exists: MagicMock, - mlx_stack_home: Path, - ) -> None: - """VAL-UP-016: Memory estimate warning before startup.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() - mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", - }.get(key, "") - mock_which.side_effect = lambda x: f"/usr/local/bin/{x}" - mock_lock.return_value.__enter__ = MagicMock(return_value=None) - mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start_service.return_value = ServiceInfo( - name="test", - pid=12345, - port=8000, - log_path=Path("/tmp/test.log"), - pid_path=Path("/tmp/test.pid"), - ) - mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, - response_time=0.5, - status_code=200, - ) - - with patch("mlx_stack.core.stack_up.psutil.virtual_memory") as mock_vmem: - mock_vmem.return_value = MagicMock(available=5 * 1024**3) # Only 5 GB - result = run_up() - - # Should have a memory warning (catalog estimates ~32 GB) - assert any("memory" in w.lower() for w in result.warnings) - - -# --------------------------------------------------------------------------- # -# Tests — CLI output verification +# Tests — CLI output verification (mocks at CLI boundary — correct level) # --------------------------------------------------------------------------- # @@ -1341,26 +883,23 @@ class TestCLIOutput: @patch("mlx_stack.cli.up.run_up") def test_summary_table_displayed( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-006: Summary table shows tier, model, port, status.""" + # Arrange mock_run_up.return_value = UpResult( tiers=[ TierStatus(name="standard", model="big-model", port=8000, status="healthy"), TierStatus(name="fast", model="fast-model", port=8001, status="healthy"), ], - litellm=TierStatus( - name="litellm", - model="proxy", - port=4000, - status="healthy", - ), + litellm=TierStatus(name="litellm", model="proxy", port=4000, status="healthy"), ) + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code == 0 assert "standard" in result.output assert "fast" in result.output @@ -1370,175 +909,127 @@ def test_summary_table_displayed( @patch("mlx_stack.cli.up.run_up") def test_already_running_message( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-014: Already-running shows informational message.""" + # Arrange mock_run_up.return_value = UpResult( tiers=[ - TierStatus( - name="standard", - model="big-model", - port=8000, - status="already-running", - ), + TierStatus(name="standard", model="big-model", port=8000, status="already-running"), ], - litellm=TierStatus( - name="litellm", - model="proxy", - port=4000, - status="already-running", - ), + litellm=TierStatus(name="litellm", model="proxy", port=4000, status="already-running"), already_running=True, ) + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code == 0 assert "already running" in result.output.lower() @patch("mlx_stack.cli.up.run_up") def test_partial_failure_summary( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-013: Summary shows mixed states.""" + # Arrange mock_run_up.return_value = UpResult( tiers=[ - TierStatus( - name="standard", - model="big-model", - port=8000, - status="failed", - error="Health check timeout", - ), - TierStatus( - name="fast", - model="fast-model", - port=8001, - status="healthy", - ), + TierStatus(name="standard", model="big-model", port=8000, status="failed", error="Health check timeout"), + TierStatus(name="fast", model="fast-model", port=8001, status="healthy"), ], - litellm=TierStatus( - name="litellm", - model="proxy", - port=4000, - status="healthy", - ), + litellm=TierStatus(name="litellm", model="proxy", port=4000, status="healthy"), ) + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code == 0 assert "failed" in result.output assert "healthy" in result.output @patch("mlx_stack.cli.up.run_up") def test_all_failed_exit_code( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """All tiers failed produces non-zero exit code.""" + # Arrange mock_run_up.return_value = UpResult( tiers=[ - TierStatus( - name="standard", - model="big-model", - port=8000, - status="failed", - error="Port conflict", - ), + TierStatus(name="standard", model="big-model", port=8000, status="failed", error="Port conflict"), ], - litellm=TierStatus( - name="litellm", - model="proxy", - port=4000, - status="skipped", - error="All model servers failed", - ), + litellm=TierStatus(name="litellm", model="proxy", port=4000, status="skipped", error="All model servers failed"), ) + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code != 0 @patch("mlx_stack.cli.up.run_up") def test_lockfile_error_message( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-007: Lockfile error produces clear message.""" + # Arrange mock_run_up.side_effect = LockError("Another mlx-stack operation is already running") + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code != 0 assert "already running" in result.output.lower() @patch("mlx_stack.cli.up.run_up") def test_warning_displayed( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-016: Memory warning shown in output.""" + # Arrange mock_run_up.return_value = UpResult( tiers=[ TierStatus(name="standard", model="big-model", port=8000, status="healthy"), ], - litellm=TierStatus( - name="litellm", - model="proxy", - port=4000, - status="healthy", - ), + litellm=TierStatus(name="litellm", model="proxy", port=4000, status="healthy"), warnings=["Estimated memory usage (50.0 GB) exceeds available (10.0 GB)"], ) + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code == 0 assert "memory" in result.output.lower() @patch("mlx_stack.cli.up.run_up") def test_port_conflict_in_summary( - self, - mock_run_up: MagicMock, - mlx_stack_home: Path, + self, mock_run_up: MagicMock, mlx_stack_home: Path, ) -> None: """VAL-UP-012: Port conflict error with PID/process shown in CLI summary.""" + # Arrange mock_run_up.return_value = UpResult( tiers=[ - TierStatus( - name="standard", - model="big-model", - port=8000, - status="skipped", - error="Port 8000 already in use by PID 54321 (node)", - ), - TierStatus( - name="fast", - model="fast-model", - port=8001, - status="healthy", - ), + TierStatus(name="standard", model="big-model", port=8000, status="skipped", error="Port 8000 already in use by PID 54321 (node)"), + TierStatus(name="fast", model="fast-model", port=8001, status="healthy"), ], - litellm=TierStatus( - name="litellm", - model="proxy", - port=4000, - status="healthy", - ), + litellm=TierStatus(name="litellm", model="proxy", port=4000, status="healthy"), ) + # Act runner = CliRunner() result = runner.invoke(cli, ["up"]) + + # Assert assert result.exit_code == 0 - # Port conflict error should include PID and process name assert "54321" in result.output assert "node" in result.output assert "8000" in result.output @@ -1546,7 +1037,7 @@ def test_port_conflict_in_summary( # --------------------------------------------------------------------------- # -# Tests — config propagation +# Tests — config propagation (dry-run — only needs catalog + config mocks) # --------------------------------------------------------------------------- # @@ -1562,28 +1053,27 @@ def test_custom_litellm_port( mlx_stack_home: Path, ) -> None: """VAL-CROSS-007: litellm-port config propagates to up.""" - stack = _make_stack_yaml() - _write_stack_yaml(mlx_stack_home, stack) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 5001, - "openrouter-key": "", + "litellm-port": 5001, "openrouter-key": "", }.get(key, "") + # Act result = run_up(dry_run=True) - # LiteLLM should be on port 5001 + + # Assert assert result.litellm is not None assert result.litellm.port == 5001 - - # Dry-run commands should reference port 5001 litellm_cmds = [c for c in result.dry_run_commands if c["service"] == "litellm"] assert len(litellm_cmds) == 1 assert "5001" in litellm_cmds[0]["command"] # --------------------------------------------------------------------------- # -# Tests — init-generated ports match up behavior +# Tests — init → up consistency # --------------------------------------------------------------------------- # @@ -1599,18 +1089,19 @@ def test_ports_match_config( mlx_stack_home: Path, ) -> None: """VAL-CROSS-006: Init port assignments match actual startup ports.""" - stack = _make_stack_yaml() - _write_stack_yaml(mlx_stack_home, stack) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + stack = make_stack_yaml() + write_stack_yaml(mlx_stack_home, stack) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { - "litellm-port": 4000, - "openrouter-key": "", + "litellm-port": 4000, "openrouter-key": "", }.get(key, "") + # Act result = run_up(dry_run=True) - # Extract ports from dry-run commands + # Assert for tier in result.tiers: config_port = None for t in stack["tiers"]: @@ -1622,7 +1113,7 @@ def test_ports_match_config( # --------------------------------------------------------------------------- # -# Tests — lockfile cleanup on interrupt +# Tests — lockfile cleanup on interrupt (real lock, no mocks) # --------------------------------------------------------------------------- # @@ -1633,11 +1124,9 @@ def test_lockfile_released_after_error(self, mlx_stack_home: Path) -> None: """VAL-CROSS-016: Lockfile released even on error.""" from mlx_stack.core.process import acquire_lock - # Acquire and release lock successfully + # Act — acquire and release twice with acquire_lock(): pass - - # Should be able to acquire again (lock was released) with acquire_lock(): pass @@ -1645,12 +1134,13 @@ def test_lockfile_released_on_exception(self, mlx_stack_home: Path) -> None: """VAL-CROSS-016: Lockfile released on exception.""" from mlx_stack.core.process import acquire_lock + # Arrange try: with acquire_lock(): raise RuntimeError("Simulated crash") except RuntimeError: pass - # Should be able to acquire again + # Assert — can re-acquire with acquire_lock(): pass diff --git a/tests/unit/test_cross_area.py b/tests/unit/test_cross_area.py index 40eaf2e..0a86ce2 100644 --- a/tests/unit/test_cross_area.py +++ b/tests/unit/test_cross_area.py @@ -24,100 +24,27 @@ from mlx_stack.cli.main import cli from mlx_stack.core.catalog import ( BenchmarkResult, - Capabilities, CatalogEntry, - QualityScores, QuantSource, ) from mlx_stack.core.hardware import HardwareProfile from mlx_stack.core.pull import ModelInventoryEntry +from tests.factories import make_entry, make_profile # --------------------------------------------------------------------------- # # Shared test data helpers # --------------------------------------------------------------------------- # -def _make_profile( - chip: str = "Apple M4 Max", - gpu_cores: int = 40, - memory_gb: int = 128, - bandwidth_gbps: float = 546.0, - is_estimate: bool = False, -) -> HardwareProfile: - """Create a HardwareProfile for testing.""" - return HardwareProfile( - chip=chip, - gpu_cores=gpu_cores, - memory_gb=memory_gb, - bandwidth_gbps=bandwidth_gbps, - is_estimate=is_estimate, - ) - - -def _make_entry( - model_id: str = "test-model", - name: str = "Test Model", - family: str = "Test", - params_b: float = 8.0, - architecture: str = "transformer", - quality_overall: int = 70, - quality_coding: int = 65, - quality_reasoning: int = 60, - quality_instruction: int = 72, - tool_calling: bool = True, - tool_call_parser: str | None = "hermes", - thinking: bool = False, - reasoning_parser: str | None = None, - benchmarks: dict[str, BenchmarkResult] | None = None, - tags: list[str] | None = None, - disk_size_gb: float = 4.5, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - if benchmarks is None: - benchmarks = { - "m4-pro-32": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), - } - return CatalogEntry( - id=model_id, - name=name, - family=family, - params_b=params_b, - architecture=architecture, - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource( - hf_repo=f"mlx-community/{model_id}-4bit", - disk_size_gb=disk_size_gb, - ), - "int8": QuantSource( - hf_repo=f"mlx-community/{model_id}-8bit", - disk_size_gb=disk_size_gb * 2, - ), - }, - capabilities=Capabilities( - tool_calling=tool_calling, - tool_call_parser=tool_call_parser if tool_calling else None, - thinking=thinking, - reasoning_parser=reasoning_parser, - vision=False, - ), - quality=QualityScores( - overall=quality_overall, - coding=quality_coding, - reasoning=quality_reasoning, - instruction_following=quality_instruction, - ), - benchmarks=benchmarks, - tags=tags or [], - ) - - def _make_test_catalog() -> list[CatalogEntry]: - """Build a diverse test catalog for cross-area tests.""" + """Build a diverse test catalog for cross-area tests. + + Uses multi-benchmark / multi-source entries (m4-pro-32 + m4-max-128 + benchmarks, int4 + int8 sources) needed by the cross-area tests. + """ return [ # High quality model (standard tier candidate) - _make_entry( + make_entry( model_id="high-quality-32b", name="High Quality 32B", family="Quality", @@ -133,9 +60,13 @@ def _make_test_catalog() -> list[CatalogEntry]: }, tags=["quality"], disk_size_gb=18.0, + sources={ + "int4": QuantSource(hf_repo="mlx-community/high-quality-32b-4bit", disk_size_gb=18.0), + "int8": QuantSource(hf_repo="mlx-community/high-quality-32b-8bit", disk_size_gb=36.0), + }, ), # Fast small model (fast tier candidate) - _make_entry( + make_entry( model_id="fast-0.8b", name="Fast 0.8B", family="Fast", @@ -151,9 +82,13 @@ def _make_test_catalog() -> list[CatalogEntry]: }, tags=["fast"], disk_size_gb=0.5, + sources={ + "int4": QuantSource(hf_repo="mlx-community/fast-0.8b-4bit", disk_size_gb=0.5), + "int8": QuantSource(hf_repo="mlx-community/fast-0.8b-8bit", disk_size_gb=1.0), + }, ), # Medium model - _make_entry( + make_entry( model_id="medium-8b", name="Medium 8B", family="Medium", @@ -169,9 +104,13 @@ def _make_test_catalog() -> list[CatalogEntry]: }, tags=["balanced"], disk_size_gb=4.5, + sources={ + "int4": QuantSource(hf_repo="mlx-community/medium-8b-4bit", disk_size_gb=4.5), + "int8": QuantSource(hf_repo="mlx-community/medium-8b-8bit", disk_size_gb=9.0), + }, ), # Longctx model (mamba2-hybrid architecture) - _make_entry( + make_entry( model_id="longctx-32b", name="LongCtx 32B", family="LongCtx", @@ -189,6 +128,10 @@ def _make_test_catalog() -> list[CatalogEntry]: }, tags=["longctx"], disk_size_gb=17.0, + sources={ + "int4": QuantSource(hf_repo="mlx-community/longctx-32b-4bit", disk_size_gb=17.0), + "int8": QuantSource(hf_repo="mlx-community/longctx-32b-8bit", disk_size_gb=34.0), + }, ), ] @@ -258,21 +201,21 @@ def test_init_creates_valid_stack_and_litellm_configs( mlx_stack_home: Path, ) -> None: """Init generates stack+litellm configs with consistent data.""" - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() + # Act runner = CliRunner() result = runner.invoke(cli, ["init", "--accept-defaults"]) - assert result.exit_code == 0 - # Verify stack YAML created and valid + # Assert + assert result.exit_code == 0 stack = _read_stack_yaml(mlx_stack_home) assert stack["schema_version"] == 1 assert stack["intent"] == "balanced" assert len(stack["tiers"]) > 0 - - # Verify LiteLLM config created and valid litellm = _read_litellm_yaml(mlx_stack_home) assert "model_list" in litellm assert len(litellm["model_list"]) == len(stack["tiers"]) @@ -293,7 +236,7 @@ def test_init_then_up_dry_run_uses_consistent_ports( stack definition ports match LiteLLM config api_base ports and the up --dry-run command ports. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile catalog = _make_test_catalog() mock_catalog.return_value = catalog @@ -379,7 +322,7 @@ def test_full_lifecycle_init_pull_up_models_api_down( """ from mlx_stack.core.process import HealthCheckResult, ServiceInfo - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile catalog = _make_test_catalog() mock_init_catalog.return_value = catalog @@ -571,17 +514,15 @@ def test_litellm_port_5000_in_generated_litellm_yaml( Verifies the concrete port value appears in the LiteLLM config general_settings, not just that the command exits 0. """ - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() - runner = CliRunner() - # Set custom litellm-port to 5000 + # Act — set config then init result = runner.invoke(cli, ["config", "set", "litellm-port", "5000"]) assert result.exit_code == 0 - - # Run init result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -624,7 +565,7 @@ def test_memory_budget_pct_60_propagates_to_recommend( With 128 GB memory and 60% budget, the effective budget is 76.8 GB. Asserts the concrete value appears in recommend output. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile mock_load_catalog.return_value = _make_test_catalog() @@ -658,7 +599,7 @@ def test_memory_budget_pct_60_propagates_to_init( With 128 GB and 60%, budget is 76.8 GB. All selected models must fit within 76.8 GB each. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() @@ -754,7 +695,7 @@ def test_litellm_port_propagates_to_up_dry_run( mlx_stack_home: Path, ) -> None: """After config set litellm-port 5001, up --dry-run shows port 5001.""" - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile catalog = _make_test_catalog() mock_init_catalog.return_value = catalog @@ -796,7 +737,7 @@ def test_config_changes_across_init_regeneration( Sets litellm-port, runs init, changes memory-budget-pct, re-runs init --force, and verifies changes are reflected. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() @@ -847,7 +788,7 @@ def test_saved_gen_tps_85_overrides_catalog_77( After bench --save writes gen_tps=85 for medium-8b, recommend --show-all must display 85.0, not 77.0. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile mock_load_catalog.return_value = _make_test_catalog() @@ -956,7 +897,7 @@ def test_saved_benchmarks_affect_scoring_order( If medium-8b gets dramatically higher gen_tps (500) from benchmarks, the scoring and tier assignment should change. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_load_profile.return_value = profile mock_load_catalog.return_value = _make_test_catalog() @@ -1009,22 +950,27 @@ def test_profile_json_parseable_by_all_consumers( mlx_stack_home: Path, ) -> None: """profile.json written by profile is parseable by recommend, init, bench.""" - profile = _make_profile(memory_gb=128) + # Arrange + profile = make_profile(memory_gb=128) _write_profile(mlx_stack_home, profile) - # Verify the file is valid JSON with all required fields + # Act — read raw JSON profile_path = mlx_stack_home / "profile.json" data = json.loads(profile_path.read_text()) + + # Assert — all required fields present assert "chip" in data assert "gpu_cores" in data assert "memory_gb" in data assert "bandwidth_gbps" in data assert "profile_id" in data - # Verify it can be loaded by the hardware module + # Act — load via hardware module from mlx_stack.core.hardware import load_profile loaded = load_profile() + + # Assert — round-trip fidelity assert loaded is not None assert loaded.chip == profile.chip assert loaded.memory_gb == profile.memory_gb @@ -1171,7 +1117,7 @@ def test_vllm_flags_in_dry_run_output( Verifies that the init -> up data flow preserves vllm_flags. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile catalog = _make_test_catalog() mock_catalog.return_value = catalog @@ -1231,7 +1177,7 @@ def test_init_stack_fields_consumed_by_up( vllm_flags — and that the stack has schema_version, hardware_profile, intent, created, and tiers. """ - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() @@ -1274,7 +1220,7 @@ def test_litellm_config_matches_stack_tiers( mlx_stack_home: Path, ) -> None: """LiteLLM config model_list matches stack tier count and ports.""" - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() @@ -1330,7 +1276,7 @@ def test_profile_id_in_stack_matches_profile( mlx_stack_home: Path, ) -> None: """hardware_profile in stack.yaml matches the detected profile_id.""" - profile = _make_profile(memory_gb=128) + profile = make_profile(memory_gb=128) mock_detect.return_value = profile mock_catalog.return_value = _make_test_catalog() diff --git a/tests/unit/test_lifecycle_fixes.py b/tests/unit/test_lifecycle_fixes.py index 05e223a..fc19469 100644 --- a/tests/unit/test_lifecycle_fixes.py +++ b/tests/unit/test_lifecycle_fixes.py @@ -11,20 +11,11 @@ from __future__ import annotations from pathlib import Path -from typing import Any from unittest.mock import MagicMock, patch -import yaml from click.testing import CliRunner from mlx_stack.cli.main import cli -from mlx_stack.core.catalog import ( - BenchmarkResult, - Capabilities, - CatalogEntry, - QualityScores, - QuantSource, -) from mlx_stack.core.process import HealthCheckResult, ServiceInfo from mlx_stack.core.stack_up import ( TierStatus, @@ -32,123 +23,11 @@ check_local_model_exists, run_up, ) - -# --------------------------------------------------------------------------- # -# Test helpers -# --------------------------------------------------------------------------- # - - -def _make_stack_yaml( - tiers: list[dict[str, Any]] | None = None, - schema_version: int = 1, -) -> dict[str, Any]: - """Create a stack definition dict for testing.""" - if tiers is None: - tiers = [ - { - "name": "standard", - "model": "big-model", - "quant": "int4", - "source": "mlx-community/big-model-4bit", - "port": 8000, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - { - "name": "fast", - "model": "fast-model", - "quant": "int4", - "source": "mlx-community/fast-model-4bit", - "port": 8001, - "vllm_flags": { - "continuous_batching": True, - "use_paged_cache": True, - }, - }, - ] - return { - "schema_version": schema_version, - "name": "default", - "hardware_profile": "m4-max-128", - "intent": "balanced", - "created": "2026-03-24T00:00:00+00:00", - "tiers": tiers, - } - - -def _write_stack_yaml( - mlx_stack_home: Path, - stack: dict[str, Any] | None = None, -) -> Path: - """Write a stack YAML file and return its path.""" - if stack is None: - stack = _make_stack_yaml() - stacks_dir = mlx_stack_home / "stacks" - stacks_dir.mkdir(parents=True, exist_ok=True) - stack_path = stacks_dir / "default.yaml" - stack_path.write_text(yaml.dump(stack, default_flow_style=False)) - return stack_path - - -def _write_litellm_yaml(mlx_stack_home: Path) -> Path: - """Write a minimal litellm.yaml config.""" - litellm_config = { - "model_list": [ - { - "model_name": "standard", - "litellm_params": { - "model": "openai/big-model", - "api_base": "http://localhost:8000/v1", - "api_key": "dummy", - }, - }, - ], - } - litellm_path = mlx_stack_home / "litellm.yaml" - litellm_path.write_text(yaml.dump(litellm_config, default_flow_style=False)) - return litellm_path - - -def _make_entry( - model_id: str = "test-model", - params_b: float = 8.0, - memory_gb: float = 5.5, -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - return CatalogEntry( - id=model_id, - name=f"Test {model_id}", - family="Test", - params_b=params_b, - architecture="transformer", - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource(hf_repo=f"mlx-community/{model_id}-4bit", disk_size_gb=4.5), - }, - capabilities=Capabilities( - tool_calling=True, - tool_call_parser="hermes", - thinking=False, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores(overall=70, coding=65, reasoning=60, instruction_following=72), - benchmarks={ - "m4-max-128": BenchmarkResult(prompt_tps=100.0, gen_tps=50.0, memory_gb=memory_gb), - }, - tags=[], - ) - - -def _make_test_catalog() -> list[CatalogEntry]: - """Create a test catalog.""" - return [ - _make_entry("big-model", params_b=49.0, memory_gb=30.0), - _make_entry("fast-model", params_b=3.0, memory_gb=2.0), - ] - +from tests.factories import ( + make_test_catalog, + write_litellm_yaml, + write_stack_yaml, +) # =========================================================================== # # Issue 1: Preflight local-model existence checks @@ -160,50 +39,65 @@ class TestCheckLocalModelExists: def test_model_found_by_id(self, mlx_stack_home: Path) -> None: """Model found when directory matches model ID.""" + # Arrange models_dir = mlx_stack_home / "models" models_dir.mkdir(parents=True, exist_ok=True) (models_dir / "big-model").mkdir() - tier = {"model": "big-model", "source": "mlx-community/big-model-4bit"} + + # Act / Assert assert check_local_model_exists(tier) is None def test_model_found_by_source_dir(self, mlx_stack_home: Path) -> None: """Model found when directory matches source repo name.""" + # Arrange models_dir = mlx_stack_home / "models" models_dir.mkdir(parents=True, exist_ok=True) (models_dir / "big-model-4bit").mkdir() - tier = {"model": "big-model", "source": "mlx-community/big-model-4bit"} + + # Act / Assert assert check_local_model_exists(tier) is None def test_model_missing_returns_diagnostic(self, mlx_stack_home: Path) -> None: """Missing model returns error message with pull suggestion.""" + # Arrange models_dir = mlx_stack_home / "models" models_dir.mkdir(parents=True, exist_ok=True) - tier = { "model": "missing-model", "source": "mlx-community/missing-model-4bit", } + + # Act error = check_local_model_exists(tier) + + # Assert assert error is not None assert "missing-model" in error assert "mlx-stack pull" in error def test_model_missing_no_models_dir(self, mlx_stack_home: Path) -> None: """Missing model when models directory doesn't exist.""" + # Arrange -- no models directory created tier = {"model": "any-model", "source": "mlx-community/any-model-4bit"} + + # Act error = check_local_model_exists(tier) + + # Assert assert error is not None assert "mlx-stack pull" in error def test_model_found_empty_source(self, mlx_stack_home: Path) -> None: """Model found by ID even when source is empty.""" + # Arrange models_dir = mlx_stack_home / "models" models_dir.mkdir(parents=True, exist_ok=True) (models_dir / "my-model").mkdir() - tier = {"model": "my-model", "source": ""} + + # Act / Assert assert check_local_model_exists(tier) is None @@ -233,9 +127,10 @@ def test_missing_model_skips_tier( mlx_stack_home: Path, ) -> None: """Tier with missing model is skipped with pull suggestion.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() models_dir = mlx_stack_home / "models" models_dir.mkdir(parents=True, exist_ok=True) mock_get_value.side_effect = lambda key: { @@ -257,17 +152,15 @@ def test_missing_model_skips_tier( healthy=True, response_time=0.5, status_code=200 ) - # No models on disk → both tiers should be skipped + # Act -- no models on disk, both tiers should be skipped result = run_up() - # All tiers should be skipped with missing model message + # Assert skipped_tiers = [t for t in result.tiers if t.status == "skipped"] assert len(skipped_tiers) == 2 for tier in skipped_tiers: assert "not found locally" in (tier.error or "") assert "mlx-stack pull" in (tier.error or "") - - # LiteLLM should be skipped since no tiers are healthy assert result.litellm is not None assert result.litellm.status == "skipped" @@ -294,10 +187,13 @@ def test_partial_models_present( mlx_stack_home: Path, ) -> None: """One model present, one missing → mixed results.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() models_dir = mlx_stack_home / "models" + models_dir.mkdir(parents=True, exist_ok=True) + (models_dir / "fast-model-4bit").mkdir() # only fast-model on disk mock_get_value.side_effect = lambda key: { "litellm-port": 4000, "openrouter-key": "", @@ -317,22 +213,15 @@ def test_partial_models_present( healthy=True, response_time=0.5, status_code=200 ) - # Create model directory for "fast-model" source - models_dir = mlx_stack_home / "models" - models_dir.mkdir(parents=True, exist_ok=True) - (models_dir / "fast-model-4bit").mkdir() - + # Act result = run_up() + # Assert -- big-model skipped (not on disk), fast-model healthy statuses = {t.name: t.status for t in result.tiers} - # big-model (standard) should be skipped — not on disk assert statuses["standard"] == "skipped" message = next(t.error for t in result.tiers if t.name == "standard") or "" assert "not found locally" in message - # fast-model (fast) should be healthy — on disk assert statuses["fast"] == "healthy" - - # LiteLLM should start since at least one tier is healthy assert result.litellm is not None assert result.litellm.status == "healthy" @@ -345,18 +234,20 @@ def test_dry_run_shows_missing_model_warning( mlx_stack_home: Path, ) -> None: """Dry-run still shows commands even for missing models.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - mock_load_catalog.return_value = _make_test_catalog() + # Arrange -- no models on disk + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) + mock_load_catalog.return_value = make_test_catalog() mock_get_value.side_effect = lambda key: { "litellm-port": 4000, "openrouter-key": "", }.get(key, "") - # No models on disk — dry-run still shows commands + # Act runner = CliRunner() result = runner.invoke(cli, ["up", "--dry-run"]) - # Dry-run should succeed — it doesn't check for model existence + + # Assert -- dry-run succeeds without checking model existence assert result.exit_code == 0 assert "Dry run" in result.output @@ -365,11 +256,12 @@ def test_cli_missing_model_shows_pull_suggestion( mlx_stack_home: Path, ) -> None: """CLI output shows pull suggestion for missing models.""" - _write_stack_yaml(mlx_stack_home) - _write_litellm_yaml(mlx_stack_home) - + # Arrange + write_stack_yaml(mlx_stack_home) + write_litellm_yaml(mlx_stack_home) runner = CliRunner() + # Act with ( patch("mlx_stack.cli.up.run_up") as mock_run_up, ): @@ -394,6 +286,7 @@ def test_cli_missing_model_shows_pull_suggestion( ) result = runner.invoke(cli, ["up"]) + # Assert assert result.exit_code != 0 assert "pull" in result.output.lower() @@ -408,9 +301,14 @@ class TestReadOnlyNoDataHomeCreation: def test_status_no_create(self, clean_mlx_stack_home: Path) -> None: """status command does not create ~/.mlx-stack/.""" + # Arrange assert not clean_mlx_stack_home.exists() runner = CliRunner() + + # Act result = runner.invoke(cli, ["status"]) + + # Assert assert result.exit_code == 0 assert not clean_mlx_stack_home.exists() @@ -487,9 +385,14 @@ class TestStateWritingCommandsStillCreateDataHome: def test_config_set_creates_dir(self, clean_mlx_stack_home: Path) -> None: """config set creates ~/.mlx-stack/ (needs it to store config).""" + # Arrange assert not clean_mlx_stack_home.exists() runner = CliRunner() + + # Act result = runner.invoke(cli, ["config", "set", "default-quant", "int4"]) + + # Assert assert result.exit_code == 0 assert clean_mlx_stack_home.exists() @@ -498,9 +401,11 @@ def test_profile_creates_dir(self, clean_mlx_stack_home: Path) -> None: Profile calls save_profile which calls ensure_data_home internally. """ + # Arrange assert not clean_mlx_stack_home.exists() runner = CliRunner() - # Profile needs real hardware — mock detect, but let save run real + + # Act -- mock detect_hardware but let save_profile write to disk with ( patch("mlx_stack.cli.profile.detect_hardware") as mock_detect, ): @@ -514,5 +419,6 @@ def test_profile_creates_dir(self, clean_mlx_stack_home: Path) -> None: is_estimate=False, ) runner.invoke(cli, ["profile"]) - # Profile command should create the data directory + + # Assert assert clean_mlx_stack_home.exists() diff --git a/tests/unit/test_log_rotation.py b/tests/unit/test_log_rotation.py index cbeb1a7..1fb1ff7 100644 --- a/tests/unit/test_log_rotation.py +++ b/tests/unit/test_log_rotation.py @@ -7,12 +7,12 @@ from __future__ import annotations -import gzip from pathlib import Path import pytest from mlx_stack.core.log_rotation import LogRotationError, rotate_log +from tests.factories import read_gz # --------------------------------------------------------------------------- # # Helpers @@ -20,7 +20,12 @@ def _create_log(path: Path, size_mb: float = 0, content: str | None = None) -> None: - """Create a log file with the given size or specific content.""" + """Create a log file with the given size or specific content. + + Unlike ``create_log_file`` from factories (which takes *logs_dir* and + *service* separately), this helper accepts a full *path* because the + rotation tests create files in arbitrary ``tmp_path`` locations. + """ if content is not None: path.write_text(content) else: @@ -32,12 +37,6 @@ def _create_log(path: Path, size_mb: float = 0, content: str | None = None) -> N path.touch() -def _read_gz(path: Path) -> bytes: - """Read and decompress a gzip file.""" - with gzip.open(str(path), "rb") as f: - return f.read() - - # --------------------------------------------------------------------------- # # Basic rotation # --------------------------------------------------------------------------- # @@ -48,20 +47,20 @@ class TestBasicRotation: def test_rotates_file_exceeding_threshold(self, tmp_path: Path) -> None: """File above threshold is rotated: archive created, original truncated.""" + # Arrange log = tmp_path / "service.log" _create_log(log, size_mb=2) + # Act result = rotate_log(log, max_size_mb=1, max_files=5) + # Assert assert result is True - # Original should be truncated to zero assert log.exists() assert log.stat().st_size == 0 - # Archive .1.gz should exist archive = tmp_path / "service.log.1.gz" assert archive.exists() - # Archive content should match original content - data = _read_gz(archive) + data = read_gz(archive) assert len(data) == 2 * 1024 * 1024 def test_returns_true_on_rotation(self, tmp_path: Path) -> None: @@ -71,15 +70,18 @@ def test_returns_true_on_rotation(self, tmp_path: Path) -> None: def test_archive_is_valid_gzip(self, tmp_path: Path) -> None: """The archive file is valid gzip that can be decompressed.""" + # Arrange log = tmp_path / "service.log" content = "hello world\n" * 100000 _create_log(log, content=content) + # Act rotate_log(log, max_size_mb=0, max_files=5) + # Assert archive = tmp_path / "service.log.1.gz" assert archive.exists() - decompressed = _read_gz(archive).decode("utf-8") + decompressed = read_gz(archive).decode("utf-8") assert decompressed == content def test_original_file_truncated_to_zero(self, tmp_path: Path) -> None: @@ -119,24 +121,22 @@ class TestArchiveShifting: def test_existing_archive_shifted_up(self, tmp_path: Path) -> None: """Existing .1.gz is shifted to .2.gz before new .1.gz is created.""" + # Arrange log = tmp_path / "service.log" - - # First rotation _create_log(log, content="first content\n" * 100000) rotate_log(log, max_size_mb=0, max_files=5) - # Second rotation + # Act _create_log(log, content="second content\n" * 100000) rotate_log(log, max_size_mb=0, max_files=5) - # .1.gz should have second content, .2.gz should have first content + # Assert -- .1.gz has second content, .2.gz has first content archive1 = tmp_path / "service.log.1.gz" archive2 = tmp_path / "service.log.2.gz" assert archive1.exists() assert archive2.exists() - - data1 = _read_gz(archive1).decode("utf-8") - data2 = _read_gz(archive2).decode("utf-8") + data1 = read_gz(archive1).decode("utf-8") + data2 = read_gz(archive2).decode("utf-8") assert "second content" in data1 assert "first content" in data2 @@ -154,8 +154,8 @@ def test_multiple_rotations_sequential_numbering(self, tmp_path: Path) -> None: assert archive.exists(), f"Archive {n} missing" # .1 = most recent (rotation 3), .4 = oldest (rotation 0) - data1 = _read_gz(tmp_path / "service.log.1.gz").decode("utf-8") - data4 = _read_gz(tmp_path / "service.log.4.gz").decode("utf-8") + data1 = read_gz(tmp_path / "service.log.1.gz").decode("utf-8") + data4 = read_gz(tmp_path / "service.log.4.gz").decode("utf-8") assert "rotation 3" in data1 assert "rotation 0" in data4 @@ -169,7 +169,7 @@ def test_one_is_most_recent(self, tmp_path: Path) -> None: _create_log(log, content="new data\n" * 100000) rotate_log(log, max_size_mb=0, max_files=5) - data = _read_gz(tmp_path / "service.log.1.gz").decode("utf-8") + data = read_gz(tmp_path / "service.log.1.gz").decode("utf-8") assert "new data" in data @@ -198,11 +198,11 @@ def test_oldest_deleted_when_exceeded(self, tmp_path: Path) -> None: assert not (tmp_path / "service.log.6.gz").exists() # .1 should be most recent (rotation 5) - data1 = _read_gz(tmp_path / "service.log.1.gz").decode("utf-8") + data1 = read_gz(tmp_path / "service.log.1.gz").decode("utf-8") assert "rotation 5" in data1 # .5 should be oldest retained (rotation 1, since rotation 0 was deleted) - data5 = _read_gz(tmp_path / "service.log.5.gz").decode("utf-8") + data5 = read_gz(tmp_path / "service.log.5.gz").decode("utf-8") assert "rotation 1" in data5 def test_max_files_one(self, tmp_path: Path) -> None: @@ -217,7 +217,7 @@ def test_max_files_one(self, tmp_path: Path) -> None: assert (tmp_path / "service.log.1.gz").exists() assert not (tmp_path / "service.log.2.gz").exists() - data = _read_gz(tmp_path / "service.log.1.gz").decode("utf-8") + data = read_gz(tmp_path / "service.log.1.gz").decode("utf-8") assert "rotation 2" in data def test_max_files_exact_boundary(self, tmp_path: Path) -> None: diff --git a/tests/unit/test_log_viewer.py b/tests/unit/test_log_viewer.py index df16189..ff689b6 100644 --- a/tests/unit/test_log_viewer.py +++ b/tests/unit/test_log_viewer.py @@ -7,7 +7,6 @@ from __future__ import annotations import contextlib -import gzip import threading import time from datetime import UTC @@ -27,28 +26,7 @@ rotate_all_logs, rotate_service_log, ) - -# --------------------------------------------------------------------------- # -# Helpers -# --------------------------------------------------------------------------- # - - -def _create_log(logs_dir: Path, service: str, content: str = "") -> Path: - """Create a log file for a service.""" - logs_dir.mkdir(parents=True, exist_ok=True) - log_path = logs_dir / f"{service}.log" - log_path.write_text(content) - return log_path - - -def _create_archive(logs_dir: Path, service: str, number: int, content: str) -> Path: - """Create a gzip archive for a service.""" - logs_dir.mkdir(parents=True, exist_ok=True) - archive_path = logs_dir / f"{service}.log.{number}.gz" - with gzip.open(str(archive_path), "wb") as f: - f.write(content.encode("utf-8")) - return archive_path - +from tests.factories import create_archive, create_log_file # --------------------------------------------------------------------------- # # LogFileInfo @@ -140,11 +118,15 @@ def test_no_logs_directory(self, mlx_stack_home: Path) -> None: def test_lists_log_files(self, mlx_stack_home: Path) -> None: """Lists .log files with correct metadata.""" + # Arrange logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "line1\nline2\n") - _create_log(logs_dir, "standard", "some content\n") + create_log_file(logs_dir, "fast", "line1\nline2\n") + create_log_file(logs_dir, "standard", "some content\n") + # Act files = list_log_files() + + # Assert assert len(files) == 2 assert files[0].service == "fast" assert files[1].service == "standard" @@ -153,11 +135,15 @@ def test_lists_log_files(self, mlx_stack_home: Path) -> None: def test_excludes_gz_files(self, mlx_stack_home: Path) -> None: """Archived .gz files are not included in listing.""" + # Arrange logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "content") - _create_archive(logs_dir, "fast", 1, "old content") + create_log_file(logs_dir, "fast", "content") + create_archive(logs_dir, "fast", 1, "old content") + # Act files = list_log_files() + + # Assert assert len(files) == 1 assert files[0].name == "fast.log" @@ -165,7 +151,7 @@ def test_excludes_non_log_files(self, mlx_stack_home: Path) -> None: """Non-.log files are excluded.""" logs_dir = mlx_stack_home / "logs" logs_dir.mkdir(parents=True, exist_ok=True) - _create_log(logs_dir, "fast", "content") + create_log_file(logs_dir, "fast", "content") (logs_dir / "notes.txt").write_text("not a log") files = list_log_files() @@ -178,7 +164,7 @@ class TestGetLogPath: def test_returns_path_for_existing_log(self, mlx_stack_home: Path) -> None: """Returns path when log file exists.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "content") + create_log_file(logs_dir, "fast", "content") path = get_log_path("fast") assert path is not None @@ -201,12 +187,16 @@ class TestGetAvailableServices: def test_returns_sorted_services(self, mlx_stack_home: Path) -> None: """Returns sorted list of services with log files.""" + # Arrange logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "standard", "content") - _create_log(logs_dir, "fast", "content") - _create_log(logs_dir, "litellm", "content") + create_log_file(logs_dir, "standard", "content") + create_log_file(logs_dir, "fast", "content") + create_log_file(logs_dir, "litellm", "content") + # Act services = get_available_services() + + # Assert assert services == ["fast", "litellm", "standard"] def test_returns_empty_for_no_logs(self, mlx_stack_home: Path) -> None: @@ -420,7 +410,7 @@ class TestReadArchive: def test_reads_valid_archive(self, tmp_path: Path) -> None: """Reads and decompresses a valid gzip archive.""" - archive = _create_archive(tmp_path, "fast", 1, "archived content") + archive = create_archive(tmp_path, "fast", 1, "archived content") result = read_archive(archive) assert result == "archived content" @@ -438,29 +428,28 @@ class TestReadAllLogs: def test_chronological_order(self, mlx_stack_home: Path) -> None: """Shows archives oldest first, then current log.""" + # Arrange -- higher archive number = older logs_dir = mlx_stack_home / "logs" + create_archive(logs_dir, "fast", 3, "oldest content") + create_archive(logs_dir, "fast", 2, "middle content") + create_archive(logs_dir, "fast", 1, "newest archived content") + create_log_file(logs_dir, "fast", "current content") - # Create archives (higher number = older) - _create_archive(logs_dir, "fast", 3, "oldest content") - _create_archive(logs_dir, "fast", 2, "middle content") - _create_archive(logs_dir, "fast", 1, "newest archived content") - _create_log(logs_dir, "fast", "current content") - + # Act result = read_all_logs("fast") - lines = result.splitlines() - # Find the positions of section headers + # Assert -- section headers appear in chronological order + lines = result.splitlines() oldest_idx = next(i for i, line in enumerate(lines) if "fast.log.3.gz" in line) middle_idx = next(i for i, line in enumerate(lines) if "fast.log.2.gz" in line) newest_idx = next(i for i, line in enumerate(lines) if "fast.log.1.gz" in line) current_idx = next(i for i, line in enumerate(lines) if "Current:" in line) - assert oldest_idx < middle_idx < newest_idx < current_idx def test_archives_only(self, mlx_stack_home: Path) -> None: """Works with only archived files (no current log).""" logs_dir = mlx_stack_home / "logs" - _create_archive(logs_dir, "fast", 1, "archived only") + create_archive(logs_dir, "fast", 1, "archived only") result = read_all_logs("fast") assert "archived only" in result @@ -468,7 +457,7 @@ def test_archives_only(self, mlx_stack_home: Path) -> None: def test_current_only(self, mlx_stack_home: Path) -> None: """Works with only current log (no archives).""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "current only") + create_log_file(logs_dir, "fast", "current only") result = read_all_logs("fast") assert "current only" in result @@ -512,7 +501,7 @@ def test_rotates_eligible_log(self, mlx_stack_home: Path) -> None: def test_skips_small_log(self, mlx_stack_home: Path) -> None: """Does not rotate a log below the threshold.""" logs_dir = mlx_stack_home / "logs" - _create_log(logs_dir, "fast", "small content") + create_log_file(logs_dir, "fast", "small content") from unittest.mock import patch diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 59bd7b6..86c9261 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -10,13 +10,6 @@ import json from pathlib import Path -from mlx_stack.core.catalog import ( - BenchmarkResult, - Capabilities, - CatalogEntry, - QualityScores, - QuantSource, -) from mlx_stack.core.pull import ( ModelInventoryEntry, add_to_inventory, @@ -25,48 +18,6 @@ save_inventory, ) -# --------------------------------------------------------------------------- # -# Fixtures -# --------------------------------------------------------------------------- # - - -def _make_entry( - model_id: str = "qwen3.5-8b", - name: str = "Qwen 3.5 8B", -) -> CatalogEntry: - """Create a CatalogEntry for testing.""" - return CatalogEntry( - id=model_id, - name=name, - family="Qwen 3.5", - params_b=8.0, - architecture="transformer", - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource( - hf_repo=f"mlx-community/{model_id}-4bit", - disk_size_gb=4.5, - ), - "int8": QuantSource( - hf_repo=f"mlx-community/{model_id}-8bit", - disk_size_gb=8.5, - ), - }, - capabilities=Capabilities( - tool_calling=True, - tool_call_parser="hermes", - thinking=False, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores(overall=68, coding=65, reasoning=62, instruction_following=72), - benchmarks={ - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), - }, - tags=["balanced"], - ) - - # =========================================================================== # # Inventory file I/O tests # =========================================================================== # @@ -81,13 +32,17 @@ def test_empty_inventory_on_fresh_install(self, mlx_stack_home: Path) -> None: def test_save_creates_file(self, mlx_stack_home: Path) -> None: """save_inventory creates models.json.""" + # Act entries = [{"model_id": "test", "quant": "int4", "name": "Test"}] save_inventory(entries) + + # Assert inv_path = mlx_stack_home / "models.json" assert inv_path.exists() def test_round_trip_preserves_data(self, mlx_stack_home: Path) -> None: """Save → load round-trip preserves all fields.""" + # Arrange entry = { "model_id": "qwen3.5-8b", "name": "Qwen 3.5 8B", @@ -98,8 +53,12 @@ def test_round_trip_preserves_data(self, mlx_stack_home: Path) -> None: "disk_size_gb": 4.5, "downloaded_at": "2025-01-01T00:00:00+00:00", } + + # Act save_inventory([entry]) loaded = load_inventory() + + # Assert assert len(loaded) == 1 for key, value in entry.items(): assert loaded[0][key] == value @@ -147,6 +106,7 @@ class TestInventoryOperations: def test_add_new_entry(self, mlx_stack_home: Path) -> None: """Adding a new entry creates it.""" + # Arrange entry = ModelInventoryEntry( model_id="qwen3.5-8b", name="Qwen 3.5 8B", @@ -157,7 +117,11 @@ def test_add_new_entry(self, mlx_stack_home: Path) -> None: disk_size_gb=4.5, downloaded_at="2025-01-01T00:00:00+00:00", ) + + # Act add_to_inventory(entry) + + # Assert loaded = load_inventory() assert len(loaded) == 1 @@ -200,6 +164,7 @@ def test_different_quants_preserved(self, mlx_stack_home: Path) -> None: def test_find_existing_entry(self, mlx_stack_home: Path) -> None: """Finding an existing entry returns it.""" + # Arrange entry = ModelInventoryEntry( model_id="qwen3.5-8b", name="Qwen 3.5 8B", @@ -212,7 +177,10 @@ def test_find_existing_entry(self, mlx_stack_home: Path) -> None: ) add_to_inventory(entry) + # Act found = find_in_inventory("qwen3.5-8b", "int4") + + # Assert assert found is not None assert found["name"] == "Qwen 3.5 8B" diff --git a/tests/unit/test_ops_cross_area.py b/tests/unit/test_ops_cross_area.py index 4a8c8a6..2593c62 100644 --- a/tests/unit/test_ops_cross_area.py +++ b/tests/unit/test_ops_cross_area.py @@ -39,6 +39,7 @@ restart_service, rotate_service_logs, ) +from tests.factories import read_gz # --------------------------------------------------------------------------- # # Shared fixtures @@ -118,7 +119,12 @@ def config_file(mlx_stack_home: Path) -> Path: def _create_log(path: Path, size_mb: float = 0, content: str | None = None) -> None: - """Create a log file with the given size or specific content.""" + """Create a log file with the given size or specific content. + + Unlike ``create_log_file`` from factories (which takes *logs_dir* and + *service* separately), this helper accepts a full *path* because the + cross-area tests create files in arbitrary locations. + """ if content is not None: path.write_text(content) else: @@ -129,12 +135,6 @@ def _create_log(path: Path, size_mb: float = 0, content: str | None = None) -> N path.touch() -def _read_gz(path: Path) -> bytes: - """Read and decompress a gzip file.""" - with gzip.open(str(path), "rb") as f: - return f.read() - - # =========================================================================== # # VAL-CROSS-OPS-001: Watchdog auto-rotation integration # =========================================================================== # @@ -152,27 +152,22 @@ def test_poll_cycle_rotates_over_threshold_logs( self, mlx_stack_home: Path, logs_dir: Path, stack_definition: dict[str, Any] ) -> None: """Watchdog poll cycle rotates logs exceeding the configured threshold.""" - # Create a log file over the threshold (use 1MB threshold for speed) + # Arrange -- one log above 1 MB threshold, one below fast_log = logs_dir / "fast.log" _create_log(fast_log, size_mb=2) - - # Create a log file under threshold — should NOT be rotated standard_log = logs_dir / "standard.log" _create_log(standard_log, content="small log\n") - state = WatchdogState() + # Act with ( patch("mlx_stack.core.watchdog.run_status") as mock_status, patch("mlx_stack.core.watchdog.get_value") as mock_get_value, ): - # All services healthy — no restarts needed mock_result = MagicMock() mock_result.no_stack = False mock_result.services = [] mock_status.return_value = mock_result - - # Configure rotation thresholds mock_get_value.side_effect = lambda key: { "log-max-size-mb": 1, "log-max-files": 5, @@ -186,15 +181,13 @@ def test_poll_cycle_rotates_over_threshold_logs( restart_delay=10, ) - # fast.log was above 1MB → should have been rotated + # Assert -- fast.log rotated, standard.log untouched assert result.rotations_performed >= 1 - assert fast_log.stat().st_size == 0 # truncated + assert fast_log.stat().st_size == 0 archive = logs_dir / "fast.log.1.gz" assert archive.exists() - data = _read_gz(archive) + data = read_gz(archive) assert len(data) == 2 * 1024 * 1024 - - # standard.log should NOT have been rotated (below threshold) assert standard_log.stat().st_size > 0 assert not (logs_dir / "standard.log.1.gz").exists() @@ -207,23 +200,21 @@ def test_concurrent_watchdog_and_manual_rotation_is_idempotent( the file should already be below threshold, so manual rotation is a no-op (idempotent). """ + # Arrange log = logs_dir / "fast.log" _create_log(log, size_mb=2) - # First rotation: simulates watchdog + # Act -- first rotation (watchdog), then second rotation (manual) result_1 = rotate_log(log, max_size_mb=1, max_files=5) + result_2 = rotate_log(log, max_size_mb=1, max_files=5) + + # Assert assert result_1 is True + assert result_2 is False # File is now 0 bytes -- skip assert log.stat().st_size == 0 archive_1 = logs_dir / "fast.log.1.gz" assert archive_1.exists() - - # Second rotation: simulates manual --rotate (concurrent) - result_2 = rotate_log(log, max_size_mb=1, max_files=5) - assert result_2 is False # File is now 0 bytes — skip - - # Archive should still be intact - assert archive_1.exists() - data = _read_gz(archive_1) + data = read_gz(archive_1) assert len(data) == 2 * 1024 * 1024 def test_concurrent_rotation_with_new_content( diff --git a/tests/unit/test_robustness_fixes.py b/tests/unit/test_robustness_fixes.py index be6ddcb..0405a96 100644 --- a/tests/unit/test_robustness_fixes.py +++ b/tests/unit/test_robustness_fixes.py @@ -17,7 +17,6 @@ from unittest.mock import MagicMock, patch import pytest -import yaml from mlx_stack.core.process import ( ProcessError, @@ -30,59 +29,19 @@ from mlx_stack.core.stack_up import ( run_up, ) +from tests.factories import create_pid_file, make_stack_yaml, write_stack_yaml -# --------------------------------------------------------------------------- # -# Helpers -# --------------------------------------------------------------------------- # - - -def _make_stack_yaml( - tiers: list[dict[str, Any]] | None = None, -) -> dict[str, Any]: - """Create a stack definition dict for testing.""" - if tiers is None: - tiers = [ - { - "name": "fast", - "model": "fast-model", - "quant": "int4", - "source": "mlx-community/fast-model-4bit", - "port": 8001, - "vllm_flags": {"continuous_batching": True}, - }, - ] - return { - "schema_version": 1, - "name": "default", - "hardware_profile": "m4-max-128", - "intent": "balanced", - "created": "2026-03-24T00:00:00+00:00", - "tiers": tiers, - } - - -def _write_stack_yaml(mlx_stack_home: Path, stack: dict[str, Any] | None = None) -> Path: - """Write a stack YAML file and return its path.""" - if stack is None: - stack = _make_stack_yaml() - stacks_dir = mlx_stack_home / "stacks" - stacks_dir.mkdir(parents=True, exist_ok=True) - stack_path = stacks_dir / "default.yaml" - stack_path.write_text(yaml.dump(stack, default_flow_style=False)) - return stack_path - - -def _create_pid_file( - mlx_stack_home: Path, - service_name: str, - pid: int | str = 12345, -) -> Path: - """Create a PID file in the pids directory.""" - pids_dir = mlx_stack_home / "pids" - pids_dir.mkdir(parents=True, exist_ok=True) - pid_path = pids_dir / f"{service_name}.pid" - pid_path.write_text(str(pid)) - return pid_path +# Single-tier stack used by all robustness tests. +_FAST_TIER_ONLY: list[dict[str, Any]] = [ + { + "name": "fast", + "model": "fast-model", + "quant": "int4", + "source": "mlx-community/fast-model-4bit", + "port": 8001, + "vllm_flags": {"continuous_batching": True}, + }, +] # =========================================================================== # @@ -103,14 +62,15 @@ def test_process_killed_on_pid_write_failure( mlx_stack_home: Path, ) -> None: """Process is terminated when PID file cannot be written.""" + # Arrange mock_proc = MagicMock() mock_proc.pid = 54321 mock_popen.return_value = mock_proc + # Act / Assert with pytest.raises(ProcessError, match="Could not write PID file"): start_service("fast", cmd=["vllm-mlx", "--port", "8001"], port=8001) - # The spawned process must have been terminated mock_proc.terminate.assert_called_once() @patch("mlx_stack.core.process.subprocess.Popen") @@ -122,15 +82,16 @@ def test_process_killed_even_if_terminate_fails( mlx_stack_home: Path, ) -> None: """Process is force-killed if terminate() fails.""" + # Arrange mock_proc = MagicMock() mock_proc.pid = 54321 mock_proc.terminate.side_effect = OSError("already dead") mock_popen.return_value = mock_proc + # Act / Assert with pytest.raises(ProcessError, match="Could not write PID file"): start_service("fast", cmd=["vllm-mlx", "--port", "8001"], port=8001) - # Should fall through to kill() mock_proc.kill.assert_called_once() @patch("mlx_stack.core.process.subprocess.Popen") @@ -142,10 +103,12 @@ def test_error_message_includes_pid( mlx_stack_home: Path, ) -> None: """Error message includes the orphaned process PID.""" + # Arrange mock_proc = MagicMock() mock_proc.pid = 54321 mock_popen.return_value = mock_proc + # Act / Assert with pytest.raises(ProcessError) as exc_info: start_service("fast", cmd=["vllm-mlx"], port=8001) @@ -161,15 +124,16 @@ def test_log_file_closed_on_pid_write_failure( mlx_stack_home: Path, ) -> None: """Log file handle is properly closed on PID write failure.""" + # Arrange mock_proc = MagicMock() mock_proc.pid = 12345 mock_popen.return_value = mock_proc + # Act with pytest.raises(ProcessError): start_service("fast", cmd=["vllm-mlx"], port=8001) - # Verify no leaked file descriptors — the log file should exist - # (was opened for writing) but the handle should have been closed + # Assert -- log file exists (was opened) but handle was closed log_path = mlx_stack_home / "logs" / "fast.log" assert log_path.exists() @@ -185,14 +149,13 @@ class TestStackUpCorruptPidHandling: def test_corrupt_tier_pid_cleaned_up_gracefully(self, mlx_stack_home: Path) -> None: """Corrupt tier PID file is cleaned up without traceback.""" - _write_stack_yaml(mlx_stack_home) - # Create a corrupt PID file - _create_pid_file(mlx_stack_home, "fast", "not-a-number") - - # Write litellm config so LiteLLM can be started + # Arrange + write_stack_yaml(mlx_stack_home, make_stack_yaml(tiers=_FAST_TIER_ONLY)) + create_pid_file(mlx_stack_home, "fast", "not-a-number") litellm_config = mlx_stack_home / "litellm.yaml" litellm_config.write_text("model_list: []\n") + # Act with ( patch("mlx_stack.core.stack_up.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_up.ensure_dependency"), @@ -205,31 +168,26 @@ def test_corrupt_tier_pid_cleaned_up_gracefully(self, mlx_stack_home: Path) -> N ): mock_lock.return_value.__enter__ = MagicMock() mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start.return_value = MagicMock(pid=99999) mock_health.return_value = MagicMock(healthy=True) result = run_up() - # The corrupt PID should have been cleaned up, and the tier started fresh + # Assert -- corrupt PID cleaned up and tier started fresh pids_dir = mlx_stack_home / "pids" corrupt_file = pids_dir / "fast.pid" - # The corrupt file should have been cleaned up (removed by our fix) - # and a new one written by start_service mock assert not corrupt_file.exists() or corrupt_file.read_text() != "not-a-number" - - # Should have warnings about stale PID cleanup assert any("stale" in w.lower() or "Cleaned up" in w for w in result.warnings) def test_corrupt_litellm_pid_cleaned_up_gracefully(self, mlx_stack_home: Path) -> None: """Corrupt LiteLLM PID file is cleaned up without traceback.""" - _write_stack_yaml(mlx_stack_home) - # Create a corrupt LiteLLM PID file - _create_pid_file(mlx_stack_home, "litellm", "garbage-data") - + # Arrange + write_stack_yaml(mlx_stack_home, make_stack_yaml(tiers=_FAST_TIER_ONLY)) + create_pid_file(mlx_stack_home, "litellm", "garbage-data") litellm_config = mlx_stack_home / "litellm.yaml" litellm_config.write_text("model_list: []\n") + # Act -- should NOT raise UpError; corrupt PID is handled gracefully with ( patch("mlx_stack.core.stack_up.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_up.ensure_dependency"), @@ -242,14 +200,12 @@ def test_corrupt_litellm_pid_cleaned_up_gracefully(self, mlx_stack_home: Path) - ): mock_lock.return_value.__enter__ = MagicMock() mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start.return_value = MagicMock(pid=99999) mock_health.return_value = MagicMock(healthy=True) - # This should NOT raise UpError — corrupt PID is handled gracefully result = run_up() - # Corrupt PID cleaned up + # Assert assert any("stale" in w.lower() or "Cleaned up" in w for w in result.warnings) def test_corrupt_pid_no_traceback_via_cli(self, mlx_stack_home: Path) -> None: @@ -258,12 +214,13 @@ def test_corrupt_pid_no_traceback_via_cli(self, mlx_stack_home: Path) -> None: from mlx_stack.cli.main import cli - _write_stack_yaml(mlx_stack_home) - _create_pid_file(mlx_stack_home, "fast", "corrupt!!!") - + # Arrange + write_stack_yaml(mlx_stack_home, make_stack_yaml(tiers=_FAST_TIER_ONLY)) + create_pid_file(mlx_stack_home, "fast", "corrupt!!!") litellm_config = mlx_stack_home / "litellm.yaml" litellm_config.write_text("model_list: []\n") + # Act with ( patch("mlx_stack.core.stack_up.acquire_lock") as mock_lock, patch("mlx_stack.core.stack_up.ensure_dependency"), @@ -276,13 +233,13 @@ def test_corrupt_pid_no_traceback_via_cli(self, mlx_stack_home: Path) -> None: ): mock_lock.return_value.__enter__ = MagicMock() mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start.return_value = MagicMock(pid=99999) mock_health.return_value = MagicMock(healthy=True) runner = CliRunner() result = runner.invoke(cli, ["up"]) + # Assert assert "Traceback" not in result.output @@ -307,7 +264,7 @@ def test_sigkill_confirmed_dead_returns_confirmed( mock_sleep: MagicMock, ) -> None: """Process confirmed dead after SIGKILL → confirmed=True.""" - # Process stays alive through grace, dies after SIGKILL + # Arrange -- process stays alive through grace, dies after SIGKILL mock_alive.side_effect = [True, True, True, False] mock_monotonic.side_effect = [ 0.0, # deadline = 10.0 @@ -315,7 +272,10 @@ def test_sigkill_confirmed_dead_returns_confirmed( 11.0, # past grace → SIGKILL ] + # Act graceful, confirmed = _terminate_process(123, grace_period=10) + + # Assert assert graceful is False assert confirmed is True @@ -331,14 +291,17 @@ def test_sigkill_process_still_alive_returns_not_confirmed( mock_sleep: MagicMock, ) -> None: """Process still alive after SIGKILL → confirmed=False.""" - # Process never dies (survives SIGKILL — e.g. zombie or kernel hold) + # Arrange -- process never dies (zombie or kernel hold) mock_alive.return_value = True mock_monotonic.side_effect = [ 0.0, # deadline = 10.0 11.0, # past grace → SIGKILL ] + # Act graceful, confirmed = _terminate_process(123, grace_period=10) + + # Assert assert graceful is False assert confirmed is False @@ -351,13 +314,15 @@ def test_pid_file_not_removed_when_termination_unconfirmed( mlx_stack_home: Path, ) -> None: """PID file is NOT removed when process termination is not confirmed.""" + # Arrange write_pid_file("stubborn", 12345) + # Act result = stop_service("stubborn") + # Assert assert result is not None assert result.graceful is False - # PID file should still exist since process wasn't confirmed dead pid = read_pid_file("stubborn") assert pid == 12345 @@ -370,13 +335,15 @@ def test_pid_file_removed_when_termination_confirmed( mlx_stack_home: Path, ) -> None: """PID file IS removed when process termination is confirmed.""" + # Arrange write_pid_file("killed", 12345) + # Act result = stop_service("killed") + # Assert assert result is not None assert result.graceful is False - # PID file should be removed since termination was confirmed assert read_pid_file("killed") is None @patch("mlx_stack.core.process._terminate_process", return_value=(True, True)) @@ -388,10 +355,13 @@ def test_pid_file_removed_on_graceful_confirmed_shutdown( mlx_stack_home: Path, ) -> None: """PID file removed on graceful + confirmed shutdown.""" + # Arrange write_pid_file("graceful", 12345) + # Act result = stop_service("graceful") + # Assert assert result is not None assert result.graceful is True assert read_pid_file("graceful") is None diff --git a/tests/unit/test_scoring.py b/tests/unit/test_scoring.py index 55cc359..128c9f4 100644 --- a/tests/unit/test_scoring.py +++ b/tests/unit/test_scoring.py @@ -21,10 +21,7 @@ from mlx_stack.core.catalog import ( BenchmarkResult, - Capabilities, CatalogEntry, - QualityScores, - QuantSource, ) from mlx_stack.core.hardware import HardwareProfile from mlx_stack.core.scoring import ( @@ -48,93 +45,31 @@ score_and_filter, score_model, ) +from tests.factories import make_entry, make_profile # --------------------------------------------------------------------------- # # Fixtures — reusable test data # --------------------------------------------------------------------------- # - -def _make_entry( - model_id: str = "test-model", - name: str = "Test Model", - family: str = "Test", - params_b: float = 8.0, - architecture: str = "transformer", - quality_overall: int = 70, - quality_coding: int = 65, - quality_reasoning: int = 60, - quality_instruction: int = 72, - tool_calling: bool = True, - tool_call_parser: str | None = "hermes", - thinking: bool = False, - benchmarks: dict[str, BenchmarkResult] | None = None, - tags: list[str] | None = None, - gated: bool = False, -) -> CatalogEntry: - """Helper to create a CatalogEntry for testing.""" - if benchmarks is None: - benchmarks = { - "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), - "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), - "m5-max-128": BenchmarkResult(prompt_tps=155.0, gen_tps=85.0, memory_gb=5.5), - } - return CatalogEntry( - id=model_id, - name=name, - family=family, - params_b=params_b, - architecture=architecture, - min_mlx_lm_version="0.22.0", - sources={ - "int4": QuantSource(hf_repo="test/repo-4bit", disk_size_gb=4.5), - "int8": QuantSource(hf_repo="test/repo-8bit", disk_size_gb=8.5), - }, - capabilities=Capabilities( - tool_calling=tool_calling, - tool_call_parser=tool_call_parser if tool_calling else None, - thinking=thinking, - reasoning_parser=None, - vision=False, - ), - quality=QualityScores( - overall=quality_overall, - coding=quality_coding, - reasoning=quality_reasoning, - instruction_following=quality_instruction, - ), - benchmarks=benchmarks, - tags=tags or ["balanced"], - gated=gated, - ) - - -def _make_profile( - chip: str = "Apple M4 Max", - gpu_cores: int = 40, - memory_gb: int = 128, - bandwidth_gbps: float = 546.0, - is_estimate: bool = False, -) -> HardwareProfile: - """Helper to create a HardwareProfile for testing.""" - return HardwareProfile( - chip=chip, - gpu_cores=gpu_cores, - memory_gb=memory_gb, - bandwidth_gbps=bandwidth_gbps, - is_estimate=is_estimate, - ) +# Default benchmarks for this test module's basic_entry and most calls that +# rely on the 3-profile convention (m4-pro-48, m4-max-128, m5-max-128). +_SCORING_BENCHMARKS: dict[str, BenchmarkResult] = { + "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), + "m4-max-128": BenchmarkResult(prompt_tps=140.0, gen_tps=77.0, memory_gb=5.5), + "m5-max-128": BenchmarkResult(prompt_tps=155.0, gen_tps=85.0, memory_gb=5.5), +} @pytest.fixture def m4_max_128_profile() -> HardwareProfile: """M4 Max 128 GB profile — matches catalog benchmark key 'm4-max-128'.""" - return _make_profile() + return make_profile() @pytest.fixture def m4_pro_48_profile() -> HardwareProfile: """M4 Pro 48 GB profile — matches catalog benchmark key 'm4-pro-48'.""" - return _make_profile( + return make_profile( chip="Apple M4 Pro", gpu_cores=18, memory_gb=48, @@ -145,7 +80,7 @@ def m4_pro_48_profile() -> HardwareProfile: @pytest.fixture def unknown_profile() -> HardwareProfile: """Unknown hardware profile — no catalog benchmark match.""" - return _make_profile( + return make_profile( chip="Apple M6", gpu_cores=60, memory_gb=256, @@ -157,7 +92,7 @@ def unknown_profile() -> HardwareProfile: @pytest.fixture def small_memory_profile() -> HardwareProfile: """Small memory profile (32 GB) for tier count tests.""" - return _make_profile( + return make_profile( chip="Apple M4 Pro", gpu_cores=18, memory_gb=32, @@ -168,7 +103,7 @@ def small_memory_profile() -> HardwareProfile: @pytest.fixture def basic_entry() -> CatalogEntry: """A basic catalog entry with standard benchmarks.""" - return _make_entry() + return make_entry(benchmarks=_SCORING_BENCHMARKS, tags=["balanced"]) @pytest.fixture @@ -176,7 +111,7 @@ def sample_catalog() -> list[CatalogEntry]: """A representative catalog for testing scoring and tier assignment.""" return [ # High quality model - _make_entry( + make_entry( model_id="premium-72b", name="Premium 72B", params_b=72.0, @@ -192,7 +127,7 @@ def sample_catalog() -> list[CatalogEntry]: tags=["premium", "quality"], ), # Fast small model - _make_entry( + make_entry( model_id="fast-0.8b", name="Fast 0.8B", params_b=0.8, @@ -209,7 +144,7 @@ def sample_catalog() -> list[CatalogEntry]: tags=["fast-inference"], ), # Mid-tier model - _make_entry( + make_entry( model_id="mid-8b", name="Mid 8B", params_b=8.0, @@ -226,7 +161,7 @@ def sample_catalog() -> list[CatalogEntry]: tags=["balanced", "agent-ready"], ), # Longctx model with mamba2-hybrid architecture - _make_entry( + make_entry( model_id="longctx-32b", name="LongCtx 32B", params_b=32.0, @@ -244,7 +179,7 @@ def sample_catalog() -> list[CatalogEntry]: tags=["reasoning", "long-context"], ), # Model without tool calling - _make_entry( + make_entry( model_id="no-tools-27b", name="No Tools 27B", params_b=27.0, @@ -261,7 +196,7 @@ def sample_catalog() -> list[CatalogEntry]: tags=["vision", "quality"], ), # Very large model that exceeds most budgets - _make_entry( + make_entry( model_id="huge-100b", name="Huge 100B", params_b=100.0, @@ -455,7 +390,10 @@ class TestBenchmarkResolution: def test_direct_match( self, basic_entry: CatalogEntry, m4_max_128_profile: HardwareProfile ) -> None: + # Act gen_tps, memory_gb, is_estimated = _resolve_benchmark(basic_entry, m4_max_128_profile) + + # Assert assert gen_tps == 77.0 assert memory_gb == 5.5 assert is_estimated is False @@ -463,7 +401,10 @@ def test_direct_match( def test_direct_match_different_profile( self, basic_entry: CatalogEntry, m4_pro_48_profile: HardwareProfile ) -> None: + # Act gen_tps, memory_gb, is_estimated = _resolve_benchmark(basic_entry, m4_pro_48_profile) + + # Assert assert gen_tps == 52.0 assert memory_gb == 5.5 assert is_estimated is False @@ -471,12 +412,17 @@ def test_direct_match_different_profile( def test_saved_benchmarks_override( self, basic_entry: CatalogEntry, m4_max_128_profile: HardwareProfile ) -> None: + # Arrange saved = { "test-model": {"gen_tps": 90.0, "memory_gb": 5.8}, } + + # Act gen_tps, memory_gb, is_estimated = _resolve_benchmark( basic_entry, m4_max_128_profile, saved_benchmarks=saved ) + + # Assert assert gen_tps == 90.0 assert memory_gb == 5.8 assert is_estimated is False @@ -493,13 +439,13 @@ def test_bandwidth_ratio_estimation( def test_bandwidth_ratio_scales_correctly(self) -> None: """Bandwidth-ratio estimation should scale gen_tps proportionally.""" - entry = _make_entry( + entry = make_entry( benchmarks={ "m4-max-128": BenchmarkResult(prompt_tps=100.0, gen_tps=50.0, memory_gb=5.0), }, ) # Profile with double the bandwidth of reference - profile_2x = _make_profile( + profile_2x = make_profile( chip="Apple M99", bandwidth_gbps=1092.0, # 2x m4-max-128's 546 ) @@ -543,7 +489,7 @@ def test_malformed_saved_benchmark_none_value( assert is_estimated is False def test_model_without_benchmarks_raises(self, unknown_profile: HardwareProfile) -> None: - entry = _make_entry(benchmarks={}) + entry = make_entry(benchmarks={}) with pytest.raises(ScoringError, match="no benchmark data"): _resolve_benchmark(entry, unknown_profile) @@ -559,10 +505,14 @@ class TestScoreModel: def test_basic_scoring( self, basic_entry: CatalogEntry, m4_max_128_profile: HardwareProfile ) -> None: + # Arrange weights = INTENT_WEIGHTS["balanced"] budget_gb = 51.2 + + # Act scored = score_model(basic_entry, m4_max_128_profile, weights, budget_gb) + # Assert assert isinstance(scored, ScoredModel) assert scored.entry.id == "test-model" assert 0 < scored.composite_score <= 1.0 @@ -603,14 +553,14 @@ def test_high_bandwidth_hardware_speed_score_clamped(self) -> None: estimated gen_tps can exceed 200 (the normalization reference). The speed_score must still be clamped to [0, 1]. """ - entry = _make_entry( + entry = make_entry( model_id="fast-model", benchmarks={ "m4-max-128": BenchmarkResult(prompt_tps=480.0, gen_tps=185.0, memory_gb=0.8), }, ) # Very high bandwidth hardware — 4x the reference m4-max-128 (546 GB/s) - high_bw_profile = _make_profile( + high_bw_profile = make_profile( chip="Apple M99 Ultra", gpu_cores=128, memory_gb=512, @@ -632,8 +582,8 @@ def test_tool_calling_model_scores_higher_for_agent_fleet( self, m4_max_128_profile: HardwareProfile ) -> None: """Models with tool_calling should score higher under agent-fleet intent.""" - with_tools = _make_entry(model_id="with-tools", tool_calling=True) - without_tools = _make_entry( + with_tools = make_entry(model_id="with-tools", tool_calling=True) + without_tools = make_entry( model_id="no-tools", tool_calling=False, tool_call_parser=None, @@ -732,7 +682,7 @@ def test_saved_benchmarks_used( self, m4_max_128_profile: HardwareProfile, ) -> None: - entry = _make_entry() + entry = make_entry() saved = {"test-model": {"gen_tps": 200.0, "memory_gb": 5.5}} scored = score_and_filter( [entry], @@ -807,7 +757,7 @@ def test_small_memory_fewer_tiers(self, small_memory_profile: HardwareProfile) - """Small memory systems (budget < 16 GB) should get 1-2 tiers.""" # 32 GB * 40% = 12.8 GB budget — below 16 GB threshold catalog = [ - _make_entry( + make_entry( model_id="small-3b", params_b=3.0, quality_overall=55, @@ -815,7 +765,7 @@ def test_small_memory_fewer_tiers(self, small_memory_profile: HardwareProfile) - "m4-pro-48": BenchmarkResult(prompt_tps=180.0, gen_tps=88.0, memory_gb=2.5), }, ), - _make_entry( + make_entry( model_id="mid-8b", params_b=8.0, quality_overall=68, @@ -823,7 +773,7 @@ def test_small_memory_fewer_tiers(self, small_memory_profile: HardwareProfile) - "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), }, ), - _make_entry( + make_entry( model_id="longctx-8b", architecture="mamba2-hybrid", params_b=8.0, @@ -862,8 +812,8 @@ def test_empty_scored_models(self) -> None: assert tiers == [] def test_single_model(self) -> None: - entry = _make_entry() - profile = _make_profile() + entry = make_entry() + profile = make_profile() weights = INTENT_WEIGHTS["balanced"] scored = [score_model(entry, profile, weights, 51.2)] tiers = assign_tiers(scored, 51.2) @@ -876,8 +826,8 @@ def test_no_longctx_architecture_available( ) -> None: """If no mamba2-hybrid models exist, only 2 tiers are assigned.""" catalog = [ - _make_entry(model_id="high-q", quality_overall=91), - _make_entry( + make_entry(model_id="high-q", quality_overall=91), + make_entry( model_id="fast-model", quality_overall=42, benchmarks={ @@ -1018,7 +968,7 @@ def test_saved_benchmarks_override_catalog( self, m4_max_128_profile: HardwareProfile, ) -> None: - entry = _make_entry() + entry = make_entry() saved = {"test-model": {"gen_tps": 200.0, "memory_gb": 5.5}} result = recommend([entry], m4_max_128_profile, saved_benchmarks=saved) assert len(result.all_scored) == 1 @@ -1048,14 +998,14 @@ def test_deterministic( def test_small_budget_gives_fewer_tiers(self, small_memory_profile: HardwareProfile) -> None: """On small memory, recommendation produces fewer tiers.""" catalog = [ - _make_entry( + make_entry( model_id="small-model", quality_overall=65, benchmarks={ "m4-pro-48": BenchmarkResult(prompt_tps=180.0, gen_tps=88.0, memory_gb=2.5), }, ), - _make_entry( + make_entry( model_id="mid-model", quality_overall=70, benchmarks={ @@ -1100,7 +1050,7 @@ def test_real_catalog_recommendation(self) -> None: catalog = load_catalog() assert len(catalog) == 15 - profile = _make_profile() # M4 Max 128 GB + profile = make_profile() # M4 Max 128 GB result = recommend(catalog, profile) # Should have some models and tiers @@ -1116,7 +1066,7 @@ def test_real_catalog_balanced_vs_agent_fleet(self) -> None: from mlx_stack.core.catalog import load_catalog catalog = load_catalog() - profile = _make_profile() + profile = make_profile() balanced = recommend(catalog, profile, intent="balanced") fleet = recommend(catalog, profile, intent="agent-fleet") @@ -1135,7 +1085,7 @@ def test_real_catalog_m4_pro_48(self) -> None: from mlx_stack.core.catalog import load_catalog catalog = load_catalog() - profile = _make_profile( + profile = make_profile( chip="Apple M4 Pro", gpu_cores=18, memory_gb=48, @@ -1158,7 +1108,7 @@ def test_real_catalog_unknown_hardware(self) -> None: from mlx_stack.core.catalog import load_catalog catalog = load_catalog() - profile = _make_profile( + profile = make_profile( chip="Apple M6", memory_gb=256, bandwidth_gbps=1000.0, @@ -1175,7 +1125,7 @@ def test_real_catalog_deterministic(self) -> None: from mlx_stack.core.catalog import load_catalog catalog = load_catalog() - profile = _make_profile() + profile = make_profile() r1 = recommend(catalog, profile) r2 = recommend(catalog, profile) @@ -1197,7 +1147,7 @@ def test_real_catalog_intents_produce_different_tier_assignments(self) -> None: from mlx_stack.core.catalog import load_catalog catalog = load_catalog() - profile = _make_profile() # M4 Max 128 GB + profile = make_profile() # M4 Max 128 GB balanced = recommend(catalog, profile, intent="balanced") fleet = recommend(catalog, profile, intent="agent-fleet") @@ -1247,7 +1197,7 @@ def test_different_intents_different_standard_tier( are similar, making the quality vs. tool_calling weight difference the decisive factor. """ - model_a = _make_entry( + model_a = make_entry( model_id="quality-leader", name="Quality Leader", quality_overall=95, @@ -1260,7 +1210,7 @@ def test_different_intents_different_standard_tier( "m4-max-128": BenchmarkResult(prompt_tps=40.0, gen_tps=25.0, memory_gb=20.0), }, ) - model_b = _make_entry( + model_b = make_entry( model_id="tool-caller", name="Tool Caller", quality_overall=50, @@ -1300,7 +1250,7 @@ def test_standard_tier_uses_composite_not_raw_quality( ) -> None: """Standard tier should pick the highest composite score, not raw quality.""" # Model with highest quality but terrible speed and efficiency - high_q = _make_entry( + high_q = make_entry( model_id="slow-quality", quality_overall=95, benchmarks={ @@ -1308,7 +1258,7 @@ def test_standard_tier_uses_composite_not_raw_quality( }, ) # Model with moderate quality but good speed and efficiency - balanced_model = _make_entry( + balanced_model = make_entry( model_id="balanced-model", quality_overall=70, benchmarks={ @@ -1335,8 +1285,8 @@ class TestExcludeGated: def test_exclude_gated_filters_gated_models(self, m4_max_128_profile: HardwareProfile) -> None: """Gated models are excluded when exclude_gated=True.""" - open_model = _make_entry(model_id="open-model", name="Open Model") - gated_model = _make_entry(model_id="gated-model", name="Gated Model", gated=True) + open_model = make_entry(model_id="open-model", name="Open Model") + gated_model = make_entry(model_id="gated-model", name="Gated Model", gated=True) scored = score_and_filter( [open_model, gated_model], @@ -1351,8 +1301,8 @@ def test_exclude_gated_filters_gated_models(self, m4_max_128_profile: HardwarePr def test_exclude_gated_false_includes_all(self, m4_max_128_profile: HardwareProfile) -> None: """All models included when exclude_gated=False (default).""" - open_model = _make_entry(model_id="open-model", name="Open Model") - gated_model = _make_entry(model_id="gated-model", name="Gated Model", gated=True) + open_model = make_entry(model_id="open-model", name="Open Model") + gated_model = make_entry(model_id="gated-model", name="Gated Model", gated=True) scored = score_and_filter( [open_model, gated_model], @@ -1367,8 +1317,8 @@ def test_exclude_gated_false_includes_all(self, m4_max_128_profile: HardwareProf def test_recommend_exclude_gated(self, m4_max_128_profile: HardwareProfile) -> None: """Gated models excluded from tier assignments via recommend().""" - open_model = _make_entry(model_id="open-model", name="Open Model") - gated_model = _make_entry( + open_model = make_entry(model_id="open-model", name="Open Model") + gated_model = make_entry( model_id="gated-model", name="Gated Model", quality_overall=99, diff --git a/tests/unit/test_watchdog.py b/tests/unit/test_watchdog.py index 7b551c6..814412e 100644 --- a/tests/unit/test_watchdog.py +++ b/tests/unit/test_watchdog.py @@ -15,7 +15,6 @@ from unittest.mock import MagicMock, patch import pytest -import yaml from mlx_stack.core.watchdog import ( DEFAULT_INTERVAL, @@ -37,6 +36,7 @@ run_watchdog, setup_signal_handlers, ) +from tests.factories import create_log_file, create_pid_file, write_stack_yaml # --------------------------------------------------------------------------- # # Fixtures @@ -46,9 +46,6 @@ @pytest.fixture def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: """Create a test stack definition and return it.""" - stacks_dir = mlx_stack_home / "stacks" - stacks_dir.mkdir(parents=True, exist_ok=True) - stack = { "schema_version": 1, "name": "default", @@ -81,9 +78,7 @@ def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: ], } - stack_path = stacks_dir / "default.yaml" - stack_path.write_text(yaml.dump(stack)) - + write_stack_yaml(mlx_stack_home, stack) return stack @@ -121,18 +116,30 @@ def test_below_threshold_not_flapping(self) -> None: assert check_flapping(tracker, max_restarts=5) is False def test_at_threshold_is_flapping(self) -> None: + # Arrange tracker = ServiceTracker() now = time.monotonic() tracker.restart_timestamps = [now - i for i in range(5)] - assert check_flapping(tracker, max_restarts=5) is True + + # Act + result = check_flapping(tracker, max_restarts=5) + + # Assert + assert result is True assert tracker.is_flapping is True def test_old_restarts_pruned(self) -> None: """Restarts outside the rolling window are pruned.""" + # Arrange tracker = ServiceTracker() old = time.monotonic() - 700 # outside 600s window tracker.restart_timestamps = [old, old - 1, old - 2, old - 3, old - 4] - assert check_flapping(tracker, max_restarts=5) is False + + # Act + result = check_flapping(tracker, max_restarts=5) + + # Assert + assert result is False assert len(tracker.restart_timestamps) == 0 def test_custom_window(self) -> None: @@ -158,6 +165,7 @@ def test_flapping_too_recent_returns_false(self) -> None: assert reset_flap_state(tracker, stable_period=300.0) is False def test_flapping_stable_resets(self) -> None: + # Arrange tracker = ServiceTracker( is_flapping=True, last_restart_time=time.monotonic() - 400, @@ -165,7 +173,12 @@ def test_flapping_stable_resets(self) -> None: restart_count=5, consecutive_failures=3, ) - assert reset_flap_state(tracker, stable_period=300.0) is True + + # Act + result = reset_flap_state(tracker, stable_period=300.0) + + # Assert + assert result is True assert tracker.is_flapping is False assert tracker.restart_timestamps == [] assert tracker.restart_count == 0 @@ -215,6 +228,7 @@ class TestRestartService: def test_restart_tier_success( self, mlx_stack_home: Path, stack_definition: dict[str, Any], pids_dir: Path ) -> None: + # Arrange tracker = ServiceTracker() with ( @@ -222,12 +236,11 @@ def test_restart_tier_success( patch("mlx_stack.core.watchdog.start_service") as mock_start, patch("mlx_stack.core.watchdog.remove_pid_file"), ): - # Make acquire_lock a context manager mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) - mock_start.return_value = MagicMock() + # Act result = restart_service( service_name="fast", stack=stack_definition, @@ -235,11 +248,13 @@ def test_restart_tier_success( vllm_binary="/usr/bin/vllm-mlx", ) + # Assert assert result is True def test_restart_unknown_tier_fails( self, mlx_stack_home: Path, stack_definition: dict[str, Any], pids_dir: Path ) -> None: + # Arrange tracker = ServiceTracker() with ( @@ -249,6 +264,7 @@ def test_restart_unknown_tier_fails( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) + # Act result = restart_service( service_name="nonexistent", stack=stack_definition, @@ -256,6 +272,7 @@ def test_restart_unknown_tier_fails( vllm_binary="/usr/bin/vllm-mlx", ) + # Assert assert result is False def test_restart_lock_failure( @@ -263,8 +280,10 @@ def test_restart_lock_failure( ) -> None: from mlx_stack.core.process import LockError + # Arrange tracker = ServiceTracker() + # Act with ( patch("mlx_stack.core.watchdog.acquire_lock", side_effect=LockError("locked")), patch("mlx_stack.core.watchdog.remove_pid_file"), @@ -276,11 +295,13 @@ def test_restart_lock_failure( vllm_binary="/usr/bin/vllm-mlx", ) + # Assert assert result is False def test_restart_litellm( self, mlx_stack_home: Path, stack_definition: dict[str, Any], pids_dir: Path ) -> None: + # Arrange tracker = ServiceTracker() with ( @@ -297,6 +318,7 @@ def test_restart_litellm( "openrouter-key": "", }.get(key, "") + # Act result = restart_service( service_name="litellm", stack=stack_definition, @@ -304,6 +326,7 @@ def test_restart_litellm( litellm_binary="/usr/bin/litellm", ) + # Assert assert result is True @@ -319,24 +342,27 @@ def test_no_pid_file(self, mlx_stack_home: Path) -> None: assert check_existing_watchdog() is None def test_stale_pid_file(self, mlx_stack_home: Path, pids_dir: Path) -> None: - # Write a PID file with a dead PID - pid_path = pids_dir / "watchdog.pid" - pid_path.write_text("99999999") + # Arrange + pid_path = create_pid_file(mlx_stack_home, "watchdog", pid=99999999) + # Act with patch("mlx_stack.core.watchdog.is_process_alive", return_value=False): result = check_existing_watchdog() + # Assert assert result is None # Stale PID file should be cleaned up assert not pid_path.exists() def test_running_watchdog(self, mlx_stack_home: Path, pids_dir: Path) -> None: - pid_path = pids_dir / "watchdog.pid" - pid_path.write_text("12345") + # Arrange + create_pid_file(mlx_stack_home, "watchdog", pid=12345) + # Act with patch("mlx_stack.core.watchdog.is_process_alive", return_value=True): result = check_existing_watchdog() + # Assert assert result == 12345 @@ -349,23 +375,29 @@ class TestSignalHandling: """Tests for setup_signal_handlers.""" def test_sigterm_sets_shutdown_flag(self) -> None: + # Arrange state = WatchdogState() setup_signal_handlers(state) assert state.shutdown_requested is False - # Send SIGTERM to ourselves + # Act os.kill(os.getpid(), signal.SIGTERM) + + # Assert assert state.shutdown_requested is True def test_sigint_sets_shutdown_flag(self) -> None: + # Arrange state = WatchdogState() setup_signal_handlers(state) assert state.shutdown_requested is False - # Manually call the handler (SIGINT would be caught by pytest) + # Act — manually call the handler (SIGINT would be caught by pytest) handler = signal.getsignal(signal.SIGINT) if callable(handler): handler(signal.SIGINT, None) + + # Assert assert state.shutdown_requested is True @@ -378,41 +410,42 @@ class TestRotateServiceLogs: """Tests for rotate_service_logs.""" def test_no_logs_directory(self, mlx_stack_home: Path) -> None: + # Act with patch("mlx_stack.core.watchdog.get_value", return_value=50): count = rotate_service_logs() + + # Assert assert count == 0 def test_rotates_eligible_files(self, mlx_stack_home: Path, logs_dir: Path) -> None: - # Create a log file exceeding threshold - log_file = logs_dir / "fast.log" - log_file.write_bytes(b"x" * (1 * 1024 * 1024)) # 1 MB + # Arrange — create a log file exceeding threshold (1 MB) + create_log_file(logs_dir, "fast", size_mb=1) + # Act with patch("mlx_stack.core.watchdog.get_value") as mock_get: mock_get.side_effect = lambda key: { "log-max-size-mb": 0, # Will cause threshold to be 0 "log-max-files": 5, }.get(key, 50) - - # rotate_log checks if size >= threshold_bytes - # With threshold 0 * 1024 * 1024 = 0, but file > 0, should rotate count = rotate_service_logs() - # Since max_size_mb=0 means threshold=0 bytes, any non-empty file rotates - assert count >= 0 # The actual rotation depends on implementation + # Assert — max_size_mb=0 means threshold=0 bytes, any non-empty file rotates + assert count >= 0 def test_skips_non_log_files(self, mlx_stack_home: Path, logs_dir: Path) -> None: - # Create a non-.log file + # Arrange other_file = logs_dir / "fast.log.1.gz" other_file.write_bytes(b"compressed data") + # Act with patch("mlx_stack.core.watchdog.get_value") as mock_get: mock_get.side_effect = lambda key: { "log-max-size-mb": 50, "log-max-files": 5, }.get(key, 50) - count = rotate_service_logs() + # Assert assert count == 0 @@ -427,10 +460,10 @@ class TestPollCycle: def test_basic_poll_no_crashes( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: - state = WatchdogState() - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + # Arrange + state = WatchdogState() mock_status = StatusResult( services=[ ServiceStatus( @@ -456,6 +489,7 @@ def test_basic_poll_no_crashes( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -468,6 +502,7 @@ def test_basic_poll_no_crashes( restart_delay=10, ) + # Assert assert result.restarts_attempted == 0 assert len(result.statuses) == 2 assert state.cycle_count == 1 @@ -475,10 +510,10 @@ def test_basic_poll_no_crashes( def test_poll_with_crashed_service_triggers_restart( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: - state = WatchdogState() - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + # Arrange + state = WatchdogState() mock_status = StatusResult( services=[ ServiceStatus( @@ -504,6 +539,7 @@ def test_poll_with_crashed_service_triggers_restart( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -517,6 +553,7 @@ def test_poll_with_crashed_service_triggers_restart( restart_delay=10, ) + # Assert assert result.restarts_attempted == 1 assert result.restarts_succeeded == 1 assert len(state.restart_log) == 1 @@ -527,10 +564,10 @@ def test_poll_does_not_restart_stopped_service( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: """Stopped services (no PID file) should NOT be restarted.""" - state = WatchdogState() - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + # Arrange + state = WatchdogState() mock_status = StatusResult( services=[ ServiceStatus( @@ -546,6 +583,7 @@ def test_poll_does_not_restart_stopped_service( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -559,21 +597,21 @@ def test_poll_does_not_restart_stopped_service( restart_delay=10, ) + # Assert assert result.restarts_attempted == 0 def test_poll_flapping_service_not_restarted( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: """Flapping services should not be restarted.""" + from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + + # Arrange state = WatchdogState() - # Pre-mark service as flapping state.service_trackers["fast"] = ServiceTracker( is_flapping=True, last_restart_time=time.monotonic(), ) - - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult - mock_status = StatusResult( services=[ ServiceStatus( @@ -589,6 +627,7 @@ def test_poll_flapping_service_not_restarted( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -602,6 +641,7 @@ def test_poll_flapping_service_not_restarted( restart_delay=10, ) + # Assert assert result.restarts_attempted == 0 assert "fast" in result.flapping_services @@ -609,15 +649,14 @@ def test_poll_respects_restart_delay( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: """Should not restart if delay has not elapsed.""" + from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + + # Arrange state = WatchdogState() - # Set last restart time to now (delay not elapsed yet) state.service_trackers["fast"] = ServiceTracker( last_restart_time=time.monotonic(), consecutive_failures=0, ) - - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult - mock_status = StatusResult( services=[ ServiceStatus( @@ -633,6 +672,7 @@ def test_poll_respects_restart_delay( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -646,16 +686,17 @@ def test_poll_respects_restart_delay( restart_delay=10, ) + # Assert assert result.restarts_attempted == 0 def test_poll_with_failed_restart( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: """Failed restart should increment consecutive_failures.""" - state = WatchdogState() - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + # Arrange + state = WatchdogState() mock_status = StatusResult( services=[ ServiceStatus( @@ -671,6 +712,7 @@ def test_poll_with_failed_restart( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -684,6 +726,7 @@ def test_poll_with_failed_restart( restart_delay=10, ) + # Assert assert result.restarts_attempted == 1 assert result.restarts_succeeded == 0 assert state.service_trackers["fast"].consecutive_failures == 1 @@ -691,12 +734,13 @@ def test_poll_with_failed_restart( def test_poll_log_rotation_counted( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: - state = WatchdogState() - from mlx_stack.core.stack_status import StatusResult + # Arrange + state = WatchdogState() mock_status = StatusResult(services=[]) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=3), @@ -709,20 +753,22 @@ def test_poll_log_rotation_counted( restart_delay=10, ) + # Assert assert result.rotations_performed == 3 def test_poll_no_stack_returns_early( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: - state = WatchdogState() - from mlx_stack.core.stack_status import StatusResult + # Arrange + state = WatchdogState() mock_status = StatusResult( no_stack=True, message="No stack configured", ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -735,20 +781,21 @@ def test_poll_no_stack_returns_early( restart_delay=10, ) + # Assert assert result.restarts_attempted == 0 def test_poll_healthy_resets_consecutive_failures( self, mlx_stack_home: Path, stack_definition: dict[str, Any] ) -> None: """Healthy service should reset consecutive_failures.""" + from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult + + # Arrange state = WatchdogState() state.service_trackers["fast"] = ServiceTracker( consecutive_failures=3, last_restart_time=time.monotonic() - 100, ) - - from mlx_stack.core.stack_status import ServiceHealth, ServiceStatus, StatusResult - mock_status = StatusResult( services=[ ServiceStatus( @@ -764,6 +811,7 @@ def test_poll_healthy_resets_consecutive_failures( ] ) + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -776,6 +824,7 @@ def test_poll_healthy_resets_consecutive_failures( restart_delay=10, ) + # Assert assert state.service_trackers["fast"].consecutive_failures == 0 @@ -788,16 +837,17 @@ class TestRunWatchdog: """Tests for run_watchdog.""" def test_no_stack_raises_error(self, mlx_stack_home: Path) -> None: + # Act / Assert with pytest.raises(WatchdogError, match="No stack configuration found"): run_watchdog() def test_already_running_raises_error( self, mlx_stack_home: Path, stack_definition: dict[str, Any], pids_dir: Path ) -> None: - # Write a watchdog PID file with a "running" process - pid_path = pids_dir / "watchdog.pid" - pid_path.write_text(str(os.getpid())) # Current process is alive + # Arrange — write a watchdog PID file with a "running" process + create_pid_file(mlx_stack_home, "watchdog", pid=os.getpid()) + # Act / Assert with pytest.raises(WatchdogError, match="already running"): run_watchdog() @@ -807,15 +857,16 @@ def test_watchdog_loop_with_immediate_shutdown( """Test that the watchdog loop runs and exits on shutdown.""" from mlx_stack.core.stack_status import StatusResult + # Arrange mock_status = StatusResult(services=[]) call_count = 0 def status_callback(result: PollResult, state: WatchdogState) -> None: nonlocal call_count call_count += 1 - # Set shutdown after first cycle state.shutdown_requested = True + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), @@ -825,9 +876,9 @@ def status_callback(result: PollResult, state: WatchdogState) -> None: status_callback=status_callback, ) + # Assert assert call_count == 1 assert state.cycle_count == 1 - # PID file should be cleaned up assert not (pids_dir / "watchdog.pid").exists() def test_watchdog_cleanup_on_exit( @@ -836,17 +887,20 @@ def test_watchdog_cleanup_on_exit( """Test that watchdog cleans up PID file on exit.""" from mlx_stack.core.stack_status import StatusResult + # Arrange mock_status = StatusResult(services=[]) def status_callback(result: PollResult, state: WatchdogState) -> None: state.shutdown_requested = True + # Act with ( patch("mlx_stack.core.watchdog.run_status", return_value=mock_status), patch("mlx_stack.core.watchdog.rotate_service_logs", return_value=0), ): run_watchdog(interval=1, status_callback=status_callback) + # Assert assert not (pids_dir / "watchdog.pid").exists() @@ -859,6 +913,7 @@ class TestDaemonize: """Tests for daemonize (mocked).""" def test_daemonize_calls_fork_and_setsid(self, mlx_stack_home: Path) -> None: + # Act with ( patch("os.fork", side_effect=[0, 0]) as mock_fork, patch("os.setsid") as mock_setsid, @@ -869,11 +924,13 @@ def test_daemonize_calls_fork_and_setsid(self, mlx_stack_home: Path) -> None: ): daemonize() + # Assert assert mock_fork.call_count == 2 mock_setsid.assert_called_once() def test_daemonize_first_fork_parent_exits(self, mlx_stack_home: Path) -> None: """When first fork returns >0, os._exit(0) should be called.""" + # Arrange exit_called = False def fake_exit(code: int) -> None: @@ -881,6 +938,7 @@ def fake_exit(code: int) -> None: exit_called = True raise SystemExit(code) + # Act / Assert with ( patch("os.fork", return_value=123), patch("os._exit", side_effect=fake_exit), @@ -891,6 +949,7 @@ def fake_exit(code: int) -> None: assert exit_called def test_daemonize_first_fork_failure(self, mlx_stack_home: Path) -> None: + # Act / Assert with ( patch("os.fork", side_effect=OSError("fork failed")), pytest.raises(WatchdogError, match="First fork failed"), @@ -907,10 +966,13 @@ class TestRemoveWatchdogPid: """Tests for remove_watchdog_pid.""" def test_removes_existing_pid_file(self, mlx_stack_home: Path, pids_dir: Path) -> None: - pid_path = pids_dir / "watchdog.pid" - pid_path.write_text("12345") + # Arrange + pid_path = create_pid_file(mlx_stack_home, "watchdog", pid=12345) + # Act remove_watchdog_pid() + + # Assert assert not pid_path.exists() def test_no_pid_file_is_noop(self, mlx_stack_home: Path) -> None: From f3fb4c6bd15f502f91f0277f742db0c8afd812e5 Mon Sep 17 00:00:00 2001 From: Wes Eklund Date: Sat, 4 Apr 2026 10:11:06 -0400 Subject: [PATCH 2/3] fix(tests): add None guards for TierStatus.error to satisfy pyright Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/unit/test_cli_up.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_cli_up.py b/tests/unit/test_cli_up.py index 434ae1e..3b11682 100644 --- a/tests/unit/test_cli_up.py +++ b/tests/unit/test_cli_up.py @@ -447,6 +447,7 @@ def test_port_conflict_skips_tier( skipped = [t for t in result.tiers if t.status == "skipped"] assert len(skipped) == 1 assert skipped[0].name == "standard" + assert skipped[0].error is not None assert "54321" in skipped[0].error assert "node" in skipped[0].error assert "8000" in skipped[0].error @@ -570,6 +571,7 @@ def test_missing_model_skips_tier( skipped = [t for t in result.tiers if t.status == "skipped"] assert len(skipped) == 1 assert skipped[0].name == "standard" + assert skipped[0].error is not None assert "not found" in skipped[0].error.lower() # fast should still start healthy = [t for t in result.tiers if t.status == "healthy"] From a791172b458600bca5a19ac90013c409ddf40068 Mon Sep 17 00:00:00 2001 From: Wes Eklund Date: Sat, 4 Apr 2026 10:13:07 -0400 Subject: [PATCH 3/3] chore: include pyright in make lint for shift-left type checking Running `make lint` now executes both ruff and pyright so type errors are caught locally before push, not just in CI. Co-Authored-By: Claude Opus 4.6 (1M context) --- Makefile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 97eba5a..e30d924 100644 --- a/Makefile +++ b/Makefile @@ -4,11 +4,12 @@ install: uv sync --dev -## Lint source and tests +## Lint source and tests (ruff + pyright) lint: uv run ruff check src/ tests/ + uv run python -m pyright -## Run type checker across the full project +## Run type checker only (alias kept for CI compatibility) typecheck: uv run python -m pyright