From 85313f5559243ba801b886eae7ec38b770a154a3 Mon Sep 17 00:00:00 2001 From: Wes Eklund Date: Fri, 3 Apr 2026 10:24:06 -0400 Subject: [PATCH] feat: expand ruff lint rules with tier 1+2 quality rulesets Add 12 new ruff rule sets for code quality, bug prevention, and modernization. Fix all auto-fixable violations and manually resolve remaining issues including dead code, ambiguous unicode, raw regex strings, loop-to-comprehension conversions, and control flow simplification. New rule sets: UP (pyupgrade), B (bugbear), SIM (simplify), C4 (comprehensions), RUF (ruff-specific), PIE (misc cleanup), RET (return simplification), PERF (performance), PT (pytest style), FURB (modernization), FLY (f-string conversion). Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 35 +- src/mlx_stack/cli/bench.py | 19 +- src/mlx_stack/cli/config.py | 4 +- src/mlx_stack/cli/init.py | 10 +- src/mlx_stack/cli/install.py | 4 +- src/mlx_stack/cli/logs.py | 8 +- src/mlx_stack/cli/models.py | 8 +- src/mlx_stack/cli/profile.py | 8 +- src/mlx_stack/cli/pull.py | 15 +- src/mlx_stack/cli/recommend.py | 29 +- src/mlx_stack/cli/setup.py | 61 ++-- src/mlx_stack/cli/up.py | 16 +- src/mlx_stack/cli/watch.py | 20 +- src/mlx_stack/core/benchmark.py | 39 +-- src/mlx_stack/core/catalog.py | 49 +-- src/mlx_stack/core/config.py | 18 +- src/mlx_stack/core/deps.py | 21 +- src/mlx_stack/core/discovery.py | 24 +- src/mlx_stack/core/hardware.py | 2 +- src/mlx_stack/core/launchd.py | 25 +- src/mlx_stack/core/litellm_gen.py | 4 +- src/mlx_stack/core/log_rotation.py | 13 +- src/mlx_stack/core/log_viewer.py | 15 +- src/mlx_stack/core/models.py | 71 ++-- src/mlx_stack/core/onboarding.py | 103 +++--- src/mlx_stack/core/process.py | 29 +- src/mlx_stack/core/pull.py | 35 +- src/mlx_stack/core/scoring.py | 50 +-- src/mlx_stack/core/stack_down.py | 18 +- src/mlx_stack/core/stack_init.py | 33 +- src/mlx_stack/core/stack_status.py | 53 +-- src/mlx_stack/core/stack_up.py | 194 +++++------ src/mlx_stack/core/watchdog.py | 17 +- tests/conftest.py | 4 +- tests/integration/conftest.py | 51 +-- tests/integration/report.py | 6 +- tests/integration/test_catalog_validation.py | 19 +- .../integration/test_harness_compatibility.py | 20 +- tests/integration/test_inference_e2e.py | 23 +- tests/integration/test_launchd_e2e.py | 40 +-- tests/integration/test_model_smoke.py | 7 +- tests/integration/test_stack_integration.py | 23 +- tests/unit/test_benchmark.py | 94 +++--- tests/unit/test_catalog.py | 42 ++- tests/unit/test_cli.py | 12 +- tests/unit/test_cli_bench.py | 83 +++-- tests/unit/test_cli_down.py | 91 +++-- tests/unit/test_cli_init.py | 318 +++++++++++------- tests/unit/test_cli_install.py | 52 +-- tests/unit/test_cli_logs.py | 8 +- tests/unit/test_cli_models.py | 78 +++-- tests/unit/test_cli_profile.py | 112 ++---- tests/unit/test_cli_pull.py | 8 +- tests/unit/test_cli_recommend.py | 23 +- tests/unit/test_cli_setup.py | 87 +++-- tests/unit/test_cli_status.py | 31 +- tests/unit/test_cli_up.py | 150 ++++++--- tests/unit/test_cli_watch.py | 37 +- tests/unit/test_config.py | 11 +- tests/unit/test_cross_area.py | 82 ++--- tests/unit/test_deps.py | 51 +-- tests/unit/test_discovery.py | 6 +- tests/unit/test_hardware.py | 13 +- tests/unit/test_launchd.py | 44 +-- tests/unit/test_lifecycle_fixes.py | 16 +- tests/unit/test_litellm_gen.py | 15 +- tests/unit/test_log_rotation.py | 2 +- tests/unit/test_log_viewer.py | 42 +-- tests/unit/test_onboarding.py | 146 ++++++-- tests/unit/test_ops_cross_area.py | 187 ++++++---- tests/unit/test_paths.py | 4 +- tests/unit/test_process.py | 43 +-- tests/unit/test_robustness_fixes.py | 6 +- tests/unit/test_scoring.py | 213 ++++-------- tests/unit/test_watchdog.py | 117 ++++--- 75 files changed, 1769 insertions(+), 1698 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2c66b2..0f7ba79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,40 @@ line-length = 100 src = ["src", "tests"] [tool.ruff.lint] -select = ["E", "F", "I", "W"] +select = [ + # Tier 1 — high value, very safe + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "W", # pycodestyle warnings + "UP", # pyupgrade — modern Python syntax + "B", # bugbear — common bug patterns + "SIM", # simplify — reduce complexity + "C4", # flake8-comprehensions + "RUF", # Ruff-specific rules + # Tier 2 — strong value, minor tuning + "PIE", # misc cleanup + "RET", # return simplification + "PERF", # performance anti-patterns + "PT", # pytest style + # "C90", # mccabe complexity — enable after refactoring complex functions + "FURB", # modernization + "FLY", # f-string conversion +] +ignore = [ + "E501", # line length (formatter handles this) + "SIM108", # ternary operator (opinionated) + "SIM117", # nested with statements (clearer in test mocking patterns) + "PT018", # composite assertions (splitting weakens error messages) + "PT019", # fixture without value (usefixtures less readable) + "PT017", # assert in except (valid test pattern) +] + +[tool.ruff.lint.mccabe] +max-complexity = 10 + +[tool.ruff.lint.per-file-ignores] +"src/mlx_stack/_version.py" = ["RUF022"] # auto-generated by hatch-vcs [tool.pyright] pythonVersion = "3.13" diff --git a/src/mlx_stack/cli/bench.py b/src/mlx_stack/cli/bench.py index c6ba33b..cf4b45b 100644 --- a/src/mlx_stack/cli/bench.py +++ b/src/mlx_stack/cli/bench.py @@ -168,23 +168,16 @@ def _display_results(result: BenchmarkResult_, out: Console, save: bool = False) out.print(Text("Tool Calling", style="bold cyan")) tc = result.tool_call_result if tc.success: - out.print( - f" [green]✓ Valid tool call[/green] — " - f"round-trip: {tc.round_trip_time:.2f}s" - ) + out.print(f" [green]✓ Valid tool call[/green] — round-trip: {tc.round_trip_time:.2f}s") else: - out.print( - f" [red]✗ Tool call failed[/red] — {tc.error}" - ) + out.print(f" [red]✗ Tool call failed[/red] — {tc.error}") out.print() elif not result.tool_call_result: # Check if model supports tool calling from entry if not result.catalog_data_available: pass # Skip silently if no catalog data else: - out.print( - "[dim]Tool calling: skipped (model does not support tool calling)[/dim]" - ) + out.print("[dim]Tool calling: skipped (model does not support tool calling)[/dim]") out.print() # Iteration details @@ -212,6 +205,8 @@ def _display_results(result: BenchmarkResult_, out: Console, save: bool = False) # Save confirmation if save: - out.print("[green]✓ Results saved.[/green] " - "These will be used by 'recommend' and 'init' for scoring.") + out.print( + "[green]✓ Results saved.[/green] " + "These will be used by 'recommend' and 'init' for scoring." + ) out.print() diff --git a/src/mlx_stack/cli/config.py b/src/mlx_stack/cli/config.py index fe77654..cfc73ba 100644 --- a/src/mlx_stack/cli/config.py +++ b/src/mlx_stack/cli/config.py @@ -138,9 +138,7 @@ def config_reset(yes: bool, force: bool) -> None: # Check if stdin is a TTY for interactive confirmation try: if click.get_text_stream("stdin").isatty(): - confirmed = click.confirm( - "Reset all configuration to defaults?", default=False - ) + confirmed = click.confirm("Reset all configuration to defaults?", default=False) else: console.print( "[bold red]Error:[/bold red] Reset requires --yes or --force flag " diff --git a/src/mlx_stack/cli/init.py b/src/mlx_stack/cli/init.py index a017a24..9d9bf23 100644 --- a/src/mlx_stack/cli/init.py +++ b/src/mlx_stack/cli/init.py @@ -64,13 +64,10 @@ def _display_summary(result: dict) -> None: budget_gb = result["memory_budget_gb"] total_memory_gb = result.get("total_memory_gb", 0.0) out.print( - f"[dim]Hardware: {profile.chip} ({profile.memory_gb} GB) · " - f"Budget: {budget_gb:.1f} GB[/dim]" + f"[dim]Hardware: {profile.chip} ({profile.memory_gb} GB) · Budget: {budget_gb:.1f} GB[/dim]" ) if total_memory_gb > 0: - out.print( - f"[dim]Total estimated memory: {total_memory_gb:.1f} GB[/dim]" - ) + out.print(f"[dim]Total estimated memory: {total_memory_gb:.1f} GB[/dim]") # Warnings (e.g., memory budget exceeded with --add) init_warnings = result.get("warnings", []) @@ -83,8 +80,7 @@ def _display_summary(result: dict) -> None: if stack.get("cloud_fallback"): out.print() out.print( - "[bold green]☁ Cloud Fallback[/bold green] " - "Premium tier via OpenRouter configured" + "[bold green]☁ Cloud Fallback[/bold green] Premium tier via OpenRouter configured" ) # Missing models warning diff --git a/src/mlx_stack/cli/install.py b/src/mlx_stack/cli/install.py index 6d7ae72..6d6b73c 100644 --- a/src/mlx_stack/cli/install.py +++ b/src/mlx_stack/cli/install.py @@ -42,9 +42,7 @@ def _display_status(status: AgentStatus) -> None: if not status.installed: out.print(Text("Status: not installed", style="dim")) elif status.running and status.pid is not None: - out.print( - Text(f"Status: installed and running (PID {status.pid})", style="green") - ) + out.print(Text(f"Status: installed and running (PID {status.pid})", style="green")) else: out.print(Text("Status: installed but not running", style="yellow")) diff --git a/src/mlx_stack/cli/logs.py b/src/mlx_stack/cli/logs.py index 7f1738a..0652a84 100644 --- a/src/mlx_stack/cli/logs.py +++ b/src/mlx_stack/cli/logs.py @@ -6,6 +6,7 @@ from __future__ import annotations +import contextlib import sys import click @@ -78,7 +79,7 @@ def _display_rotation_results(results: list) -> None: out.print(f"[green]✓[/green] {result.service}: rotated") any_rotated = True else: - out.print(f"[dim]–[/dim] {result.service}: no rotation needed") + out.print(f"[dim]-[/dim] {result.service}: no rotation needed") if not results: out.print(Text("No log files found to rotate.", style="yellow")) @@ -213,11 +214,8 @@ def logs( # Handle --follow mode if follow: num = tail_lines if tail_lines is not None else DEFAULT_TAIL_LINES - try: + with contextlib.suppress(KeyboardInterrupt): follow_log(log_path, num_lines=num, output_callback=click.echo) - except KeyboardInterrupt: - # Belt-and-suspenders: ensure clean exit - pass return # Default: show tail of log diff --git a/src/mlx_stack/cli/models.py b/src/mlx_stack/cli/models.py index 81f6beb..095a646 100644 --- a/src/mlx_stack/cli/models.py +++ b/src/mlx_stack/cli/models.py @@ -80,7 +80,7 @@ def _display_local_models() -> None: indicator_style = "bold green" if model.is_active else "" # Display name: prefer catalog name, fall back to directory name - display_name = model.catalog_name if model.catalog_name else model.name + display_name = model.catalog_name or model.name # Size size_str = format_size(model.disk_size_bytes) @@ -217,7 +217,7 @@ def _display_catalog( local_style = "bold green" if cm.is_local else "" # Parameters - params_str = f"{cm.params_b:.1f}B" if cm.params_b >= 1.0 else f"{cm.params_b:.1f}B" + params_str = f"{cm.params_b:.1f}B" # Quantizations quants_str = ", ".join(cm.quants) @@ -270,7 +270,9 @@ def _display_catalog( @click.option("--family", default=None, help="Filter catalog by model family (e.g., 'qwen3.5').") @click.option("--tag", default=None, help="Filter catalog by tag (e.g., 'agent-ready').") @click.option( - "--tool-calling", "tool_calling", is_flag=True, + "--tool-calling", + "tool_calling", + is_flag=True, help="Filter catalog to tool-calling-capable models only.", ) def models( diff --git a/src/mlx_stack/cli/profile.py b/src/mlx_stack/cli/profile.py index cfc750a..5e31bca 100644 --- a/src/mlx_stack/cli/profile.py +++ b/src/mlx_stack/cli/profile.py @@ -52,12 +52,8 @@ def profile() -> None: if hw.is_estimate: out.print() - out.print( - "[yellow]⚠ Bandwidth is estimated for unknown chip.[/yellow]" - ) - out.print( - " Run [bold]mlx-stack bench --save[/bold] to calibrate with real measurements." - ) + out.print("[yellow]⚠ Bandwidth is estimated for unknown chip.[/yellow]") + out.print(" Run [bold]mlx-stack bench --save[/bold] to calibrate with real measurements.") out.print() from mlx_stack.core.paths import get_profile_path diff --git a/src/mlx_stack/cli/pull.py b/src/mlx_stack/cli/pull.py index f8b5acd..4416f3a 100644 --- a/src/mlx_stack/cli/pull.py +++ b/src/mlx_stack/cli/pull.py @@ -115,20 +115,13 @@ def _run_post_download_bench(model_id: str, quant: str, out: Console) -> None: from mlx_stack.core.benchmark import BenchmarkError, run_benchmark result = run_benchmark(target=model_id, save=True) - out.print( - f" Prompt TPS: {result.prompt_tps_mean:.1f} ± {result.prompt_tps_std:.1f} tok/s" - ) - out.print( - f" Gen TPS: {result.gen_tps_mean:.1f} ± {result.gen_tps_std:.1f} tok/s" - ) + out.print(f" Prompt TPS: {result.prompt_tps_mean:.1f} ± {result.prompt_tps_std:.1f} tok/s") + out.print(f" Gen TPS: {result.gen_tps_mean:.1f} ± {result.gen_tps_std:.1f} tok/s") out.print() - out.print( - "[dim]Results saved for use by 'recommend' and 'init' scoring.[/dim]" - ) + out.print("[dim]Results saved for use by 'recommend' and 'init' scoring.[/dim]") except BenchmarkError as exc: out.print( - f"[yellow]Benchmark failed: {exc}[/yellow]\n" - f"Run 'mlx-stack bench {model_id}' to retry." + f"[yellow]Benchmark failed: {exc}[/yellow]\nRun 'mlx-stack bench {model_id}' to retry." ) except Exception as exc: out.print( diff --git a/src/mlx_stack/cli/recommend.py b/src/mlx_stack/cli/recommend.py index 4331b33..274cfd7 100644 --- a/src/mlx_stack/cli/recommend.py +++ b/src/mlx_stack/cli/recommend.py @@ -67,9 +67,7 @@ def parse_budget(raw: str) -> float: value = float(match.group(1)) if value <= 0: - msg = ( - f"Invalid budget '{raw}'. Budget must be a positive value." - ) + msg = f"Invalid budget '{raw}'. Budget must be a positive value." raise click.BadParameter(msg, param_hint="'--budget'") return value @@ -96,8 +94,7 @@ def _resolve_profile() -> HardwareProfile: # Auto-detect (in-memory only — recommend is display-only, no file writes) console.print("[dim]No saved profile found — detecting hardware...[/dim]") try: - profile = detect_hardware() - return profile + return detect_hardware() except HardwareError as exc: console.print(f"[bold red]Error:[/bold red] {exc}") raise SystemExit(1) from None @@ -206,12 +203,8 @@ def _display_tier_table(result: RecommendationResult) -> None: has_estimates = any(t.model.is_estimated for t in result.tiers) if has_estimates: out.print() - out.print( - "[yellow]⚠ Some performance values are estimated from bandwidth ratio.[/yellow]" - ) - out.print( - " Run [bold]mlx-stack bench --save[/bold] to calibrate with real measurements." - ) + out.print("[yellow]⚠ Some performance values are estimated from bandwidth ratio.[/yellow]") + out.print(" Run [bold]mlx-stack bench --save[/bold] to calibrate with real measurements.") out.print() out.print("[dim]This is a recommendation only — no files were written.[/dim]") @@ -268,20 +261,15 @@ def _display_all_models(result: RecommendationResult) -> None: if openrouter_key: out.print() out.print( - "[bold green]☁ Cloud Fallback[/bold green] " - "Premium tier via OpenRouter also available." + "[bold green]☁ Cloud Fallback[/bold green] Premium tier via OpenRouter also available." ) # Estimated warning has_estimates = any(m.is_estimated for m in result.all_scored) if has_estimates: out.print() - out.print( - "[yellow]⚠ Some performance values are estimated from bandwidth ratio.[/yellow]" - ) - out.print( - " Run [bold]mlx-stack bench --save[/bold] to calibrate with real measurements." - ) + out.print("[yellow]⚠ Some performance values are estimated from bandwidth ratio.[/yellow]") + out.print(" Run [bold]mlx-stack bench --save[/bold] to calibrate with real measurements.") out.print() out.print("[dim]This is a recommendation only — no files were written.[/dim]") @@ -329,8 +317,7 @@ def recommend(budget: str | None, intent: str | None, show_all: bool) -> None: elif intent not in VALID_INTENTS: valid = ", ".join(sorted(VALID_INTENTS)) console.print( - f"[bold red]Error:[/bold red] Invalid intent '{intent}'. " - f"Valid intents: {valid}" + f"[bold red]Error:[/bold red] Invalid intent '{intent}'. Valid intents: {valid}" ) raise SystemExit(1) diff --git a/src/mlx_stack/cli/setup.py b/src/mlx_stack/cli/setup.py index 16eaff5..e22d29e 100644 --- a/src/mlx_stack/cli/setup.py +++ b/src/mlx_stack/cli/setup.py @@ -170,9 +170,7 @@ def _prompt_model_selection( Input like '1:int8,3' = model 1 as int8, model 3 as default quant. """ if accept_defaults: - return [ - (i, s.model.quant) for i, s in enumerate(scored) if s.is_recommended - ] + return [(i, s.model.quant) for i, s in enumerate(scored) if s.is_recommended] out.print() raw = click.prompt( @@ -184,9 +182,7 @@ def _prompt_model_selection( if not raw.strip(): # Accept defaults - return [ - (i, s.model.quant) for i, s in enumerate(scored) if s.is_recommended - ] + return [(i, s.model.quant) for i, s in enumerate(scored) if s.is_recommended] # Parse input selections: list[tuple[int, str]] = [] @@ -245,8 +241,8 @@ def _display_final_status(tiers: list[Any], litellm_port: int) -> None: out.print( f" curl http://localhost:{litellm_port}/v1/chat/completions \\\n" f" -H 'Content-Type: application/json' \\\n" - f" -d '{{\"model\":\"{tiers[0].tier_name}\"," - f"\"messages\":[{{\"role\":\"user\",\"content\":\"Hello!\"}}]}}'" + f' -d \'{{"model":"{tiers[0].tier_name}",' + f'"messages":[{{"role":"user","content":"Hello!"}}]}}\'' ) out.print() out.print(" [dim]Manage your stack:[/dim]") @@ -343,10 +339,7 @@ def setup( raise SystemExit(1) from None if not all_models: - console.print( - "[bold red]Error:[/bold red] No models found. " - "Check your network connection." - ) + console.print("[bold red]Error:[/bold red] No models found. Check your network connection.") raise SystemExit(1) from None scored = score_and_filter(all_models, intent, budget_gb) @@ -400,25 +393,29 @@ def setup( thinking=s.model.thinking, has_benchmark=s.model.has_benchmark, ) - selected_models.append(ScoredDiscoveredModel( - model=new_model, - composite_score=s.composite_score, - speed_score=s.speed_score, - quality_score=s.quality_score, - tool_calling_score=s.tool_calling_score, - memory_efficiency_score=s.memory_efficiency_score, - is_recommended=True, - )) + selected_models.append( + ScoredDiscoveredModel( + model=new_model, + composite_score=s.composite_score, + speed_score=s.speed_score, + quality_score=s.quality_score, + tool_calling_score=s.tool_calling_score, + memory_efficiency_score=s.memory_efficiency_score, + is_recommended=True, + ) + ) else: - selected_models.append(ScoredDiscoveredModel( - model=s.model, - composite_score=s.composite_score, - speed_score=s.speed_score, - quality_score=s.quality_score, - tool_calling_score=s.tool_calling_score, - memory_efficiency_score=s.memory_efficiency_score, - is_recommended=True, - )) + selected_models.append( + ScoredDiscoveredModel( + model=s.model, + composite_score=s.composite_score, + speed_score=s.speed_score, + quality_score=s.quality_score, + tool_calling_score=s.tool_calling_score, + memory_efficiency_score=s.memory_efficiency_score, + is_recommended=True, + ) + ) # ── Step 4: Tier assignment ────────────────────────────────────────── tiers = assign_tiers(selected_models) @@ -443,7 +440,7 @@ def setup( raise SystemExit(0) from None try: - stack_path, litellm_path = generate_config( + stack_path, _litellm_path = generate_config( profile=profile, intent=intent, tier_mappings=tiers, @@ -458,7 +455,7 @@ def setup( out.print(" " + "─" * 40) models_to_pull = [t.model for t in tiers] - for i, model in enumerate(models_to_pull, 1): + for i, _model in enumerate(models_to_pull, 1): out.print(f" [bold][{i}/{len(models_to_pull)}][/bold]", end=" ") try: diff --git a/src/mlx_stack/cli/up.py b/src/mlx_stack/cli/up.py index a050e47..ff40cb3 100644 --- a/src/mlx_stack/cli/up.py +++ b/src/mlx_stack/cli/up.py @@ -55,9 +55,7 @@ def _display_summary(result: UpResult) -> None: out.print() if result.already_running: - out.print( - Text("All services are already running.", style="bold yellow") - ) + out.print(Text("All services are already running.", style="bold yellow")) out.print() # Warnings @@ -115,14 +113,10 @@ def _display_summary(result: UpResult) -> None: out.print() # Next steps for healthy stacks - any_healthy = any( - t.status in ("healthy", "already-running") for t in result.tiers - ) + any_healthy = any(t.status in ("healthy", "already-running") for t in result.tiers) if any_healthy: litellm_port = result.litellm.port if result.litellm else 4000 - out.print( - f"[dim]Endpoint: http://localhost:{litellm_port}/v1[/dim]" - ) + out.print(f"[dim]Endpoint: http://localhost:{litellm_port}/v1[/dim]") out.print() @@ -156,8 +150,6 @@ def up(dry_run: bool, tier_filter: str | None) -> None: _display_summary(result) # Exit with non-zero if all tiers failed - any_success = any( - t.status in ("healthy", "already-running", "dry-run") for t in result.tiers - ) + any_success = any(t.status in ("healthy", "already-running", "dry-run") for t in result.tiers) if not any_success and not result.dry_run: raise SystemExit(1) diff --git a/src/mlx_stack/cli/watch.py b/src/mlx_stack/cli/watch.py index 64b3207..83eac9d 100644 --- a/src/mlx_stack/cli/watch.py +++ b/src/mlx_stack/cli/watch.py @@ -7,7 +7,7 @@ from __future__ import annotations import sys -from datetime import datetime, timezone +from datetime import UTC, datetime import click from rich.console import Console @@ -53,7 +53,7 @@ def _format_status_table(result: PollResult, state: WatchdogState) -> None: state: Current watchdog state. """ out = Console() - now = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC") + now = datetime.now(tz=UTC).strftime("%Y-%m-%d %H:%M:%S UTC") out.print() out.print(Text(f"[Cycle {state.cycle_count}] {now}", style="bold cyan")) @@ -97,9 +97,7 @@ def _format_status_table(result: PollResult, state: WatchdogState) -> None: Text( f" Restarts: {result.restarts_succeeded}/{result.restarts_attempted} succeeded", style=( - "yellow" - if result.restarts_succeeded < result.restarts_attempted - else "green" + "yellow" if result.restarts_succeeded < result.restarts_attempted else "green" ), ) ) @@ -112,9 +110,7 @@ def _format_restart_event(record: RestartRecord) -> None: record: The restart record. """ out = Console() - ts = datetime.fromtimestamp(record.timestamp, tz=timezone.utc).strftime( - "%Y-%m-%d %H:%M:%S UTC" - ) + ts = datetime.fromtimestamp(record.timestamp, tz=UTC).strftime("%Y-%m-%d %H:%M:%S UTC") status = "✓" if record.success else "✗" style = "green" if record.success else "red" @@ -151,9 +147,7 @@ def _validate_positive_int( click.BadParameter: If value is not positive. """ if value < 1: - raise click.BadParameter( - f"Must be a positive integer (got {value})." - ) + raise click.BadParameter(f"Must be a positive integer (got {value}).") return value @@ -224,9 +218,7 @@ def watch( """ try: if daemon: - console.print( - Text("Starting watchdog in daemon mode...", style="bold cyan") - ) + console.print(Text("Starting watchdog in daemon mode...", style="bold cyan")) run_watchdog( interval=interval, diff --git a/src/mlx_stack/core/benchmark.py b/src/mlx_stack/core/benchmark.py index af86a26..e2d46e6 100644 --- a/src/mlx_stack/core/benchmark.py +++ b/src/mlx_stack/core/benchmark.py @@ -11,6 +11,7 @@ from __future__ import annotations +import contextlib import json import math import os @@ -207,9 +208,7 @@ def _generate_prompt(token_count: int) -> str: ) words_needed = int(token_count / 1.3) + 10 words = base_phrase.split() - repeated = [] - for i in range(words_needed): - repeated.append(words[i % len(words)]) + repeated = [words[i % len(words)] for i in range(words_needed)] return " ".join(repeated) @@ -385,10 +384,7 @@ def _run_single_iteration( with httpx.stream("POST", url, json=payload, timeout=300.0) as response: if response.status_code != 200: body = response.read().decode("utf-8", errors="replace")[:200] - msg = ( - f"API request failed with status {response.status_code}: " - f"{body}" - ) + msg = f"API request failed with status {response.status_code}: {body}" raise BenchmarkRunError(msg) for line in response.iter_lines(): @@ -595,12 +591,8 @@ def _compare_against_catalog( bench = entry.benchmarks[profile_id] classifications: list[MetricClassification] = [] - classifications.append( - _classify_metric("prompt_tps", prompt_tps_mean, bench.prompt_tps) - ) - classifications.append( - _classify_metric("gen_tps", gen_tps_mean, bench.gen_tps) - ) + classifications.append(_classify_metric("prompt_tps", prompt_tps_mean, bench.prompt_tps)) + classifications.append(_classify_metric("gen_tps", gen_tps_mean, bench.gen_tps)) return classifications @@ -807,17 +799,17 @@ def _start_temp_instance( vllm_binary = shutil.which("vllm-mlx") if vllm_binary is None: - msg = ( - "vllm-mlx binary not found on PATH after installation. " - "Try: uv tool install vllm-mlx" - ) + msg = "vllm-mlx binary not found on PATH after installation. Try: uv tool install vllm-mlx" raise BenchmarkError(msg) cmd = [ vllm_binary, - "serve", model_source, - "--port", str(port), - "--host", "127.0.0.1", + "serve", + model_source, + "--port", + str(port), + "--host", + "127.0.0.1", ] # Add tool-calling flags if the model supports it @@ -869,10 +861,8 @@ def _cleanup_temp_instance(service_name: str) -> None: Args: service_name: The service name from PID management. """ - try: + with contextlib.suppress(Exception): stop_service(service_name, grace_period=5.0) - except Exception: - pass # Double-check: try reading PID and kill directly if still alive try: @@ -1097,8 +1087,7 @@ def run_benchmark( temp_service = resolved.temp_service_name try: - result = _execute_benchmark(resolved, save) - return result + return _execute_benchmark(resolved, save) except Exception: # Ensure cleanup on any failure if temp_service: diff --git a/src/mlx_stack/core/catalog.py b/src/mlx_stack/core/catalog.py index de3f726..cbc4cac 100644 --- a/src/mlx_stack/core/catalog.py +++ b/src/mlx_stack/core/catalog.py @@ -95,7 +95,7 @@ class Capabilities: @dataclass(frozen=True) class QualityScores: - """Model quality scores (0–100 scale).""" + """Model quality scores (0-100 scale).""" overall: int coding: int @@ -192,20 +192,14 @@ def _validate_entry(data: dict[str, Any], filename: str) -> None: caps = data["capabilities"] for cap_field in _REQUIRED_CAPABILITIES: if cap_field not in caps: - msg = ( - f"Catalog file '{filename}': capabilities missing " - f"required field '{cap_field}'" - ) + msg = f"Catalog file '{filename}': capabilities missing required field '{cap_field}'" raise CatalogError(msg) # Validate quality scores quality = data["quality"] for q_field in _REQUIRED_QUALITY_FIELDS: if q_field not in quality: - msg = ( - f"Catalog file '{filename}': quality missing " - f"required field '{q_field}'" - ) + msg = f"Catalog file '{filename}': quality missing required field '{q_field}'" raise CatalogError(msg) q_value = quality[q_field] if not isinstance(q_value, (int, float)): @@ -219,10 +213,7 @@ def _validate_entry(data: dict[str, Any], filename: str) -> None: benchmarks = data["benchmarks"] for hw_key, bench_data in benchmarks.items(): if not isinstance(bench_data, dict): - msg = ( - f"Catalog file '{filename}': benchmark entry '{hw_key}' " - f"must be a mapping" - ) + msg = f"Catalog file '{filename}': benchmark entry '{hw_key}' must be a mapping" raise CatalogError(msg) for req_field in ("prompt_tps", "gen_tps", "memory_gb"): if req_field not in bench_data: @@ -275,9 +266,7 @@ def _parse_entry(data: dict[str, Any]) -> CatalogEntry: convert_from=bool(source_data.get("convert_from", False)), ) except (ValueError, TypeError) as exc: - msg = ( - f"Catalog entry '{model_id}': invalid value in source '{quant}': {exc}" - ) + msg = f"Catalog entry '{model_id}': invalid value in source '{quant}': {exc}" raise CatalogError(msg) from None # Parse capabilities @@ -317,10 +306,7 @@ def _parse_entry(data: dict[str, Any]) -> CatalogEntry: memory_gb=float(bench_data["memory_gb"]), ) except (ValueError, TypeError) as exc: - msg = ( - f"Catalog entry '{model_id}': invalid value in " - f"benchmark '{hw_key}': {exc}" - ) + msg = f"Catalog entry '{model_id}': invalid value in benchmark '{hw_key}': {exc}" raise CatalogError(msg) from None try: @@ -371,9 +357,11 @@ def load_catalog() -> list[CatalogEntry]: yaml_files: list[Any] = [] try: - for item in catalog_pkg.iterdir(): - if hasattr(item, "name") and item.name.endswith(".yaml"): - yaml_files.append(item) + yaml_files.extend( + item + for item in catalog_pkg.iterdir() + if hasattr(item, "name") and item.name.endswith(".yaml") + ) except (OSError, TypeError) as exc: msg = f"Could not read catalog directory: {exc}" raise CatalogError(msg) from None @@ -398,10 +386,7 @@ def load_catalog() -> list[CatalogEntry]: if not isinstance(data, dict): actual_type = type(data).__name__ - msg = ( - f"Catalog file '{filename}' must contain a YAML mapping, " - f"got {actual_type}" - ) + msg = f"Catalog file '{filename}' must contain a YAML mapping, got {actual_type}" raise CatalogError(msg) from None _validate_entry(data, filename) @@ -459,10 +444,7 @@ def load_catalog_from_directory(directory: str) -> list[CatalogEntry]: if not isinstance(data, dict): actual_type = type(data).__name__ - msg = ( - f"Catalog file '{filename}' must contain a YAML mapping, " - f"got {actual_type}" - ) + msg = f"Catalog file '{filename}' must contain a YAML mapping, got {actual_type}" raise CatalogError(msg) from None _validate_entry(data, filename) @@ -543,10 +525,7 @@ def query_by_capability( valid_caps = {"tool_calling", "thinking", "vision"} for cap_name in capabilities: if cap_name not in valid_caps: - msg = ( - f"Invalid capability filter '{cap_name}' " - f"(valid: {', '.join(sorted(valid_caps))})" - ) + msg = f"Invalid capability filter '{cap_name}' (valid: {', '.join(sorted(valid_caps))})" raise ValueError(msg) results: list[CatalogEntry] = [] diff --git a/src/mlx_stack/core/config.py b/src/mlx_stack/core/config.py index 8337c4c..53c5bdb 100644 --- a/src/mlx_stack/core/config.py +++ b/src/mlx_stack/core/config.py @@ -441,14 +441,16 @@ def get_all_config() -> list[dict[str, Any]]: is_default = key not in data value = data.get(key, default) - result.append({ - "name": key, - "value": value, - "default": default, - "is_default": is_default, - "description": key_def.description, - "masked_value": mask_value(key, value), - }) + result.append( + { + "name": key, + "value": value, + "default": default, + "is_default": is_default, + "description": key_def.description, + "masked_value": mask_value(key, value), + } + ) return result diff --git a/src/mlx_stack/core/deps.py b/src/mlx_stack/core/deps.py index 33f672d..2de83e4 100644 --- a/src/mlx_stack/core/deps.py +++ b/src/mlx_stack/core/deps.py @@ -158,8 +158,7 @@ def _install_tool(tool: str, version: str) -> None: uv_path = shutil.which("uv") if uv_path is None: msg = ( - "uv is not available on PATH. " - "Install it from https://docs.astral.sh/uv/ and try again." + "uv is not available on PATH. Install it from https://docs.astral.sh/uv/ and try again." ) raise DependencyError(msg) @@ -177,17 +176,10 @@ def _install_tool(tool: str, version: str) -> None: timeout=300, ) except subprocess.TimeoutExpired: - msg = ( - f"Installation timed out: {cmd_str}\n\n" - f"Install manually with: {cmd_str}" - ) + msg = f"Installation timed out: {cmd_str}\n\nInstall manually with: {cmd_str}" raise DependencyInstallError(msg) from None except OSError as exc: - msg = ( - f"Failed to run: {cmd_str}\n" - f"Error: {exc}\n\n" - f"Install manually with: {cmd_str}" - ) + msg = f"Failed to run: {cmd_str}\nError: {exc}\n\nInstall manually with: {cmd_str}" raise DependencyInstallError(msg) from None if result.returncode != 0: @@ -288,7 +280,7 @@ def ensure_dependency(tool: str) -> DependencyStatus: f"This may be because the uv tool bin directory is not in your PATH.\n\n" f"Try running:\n" f" {cmd_str}\n" - f" export PATH=\"$HOME/.local/bin:$PATH\"" + f' export PATH="$HOME/.local/bin:$PATH"' ) raise DependencyInstallError(msg) @@ -314,10 +306,7 @@ def ensure_all_dependencies() -> list[DependencyStatus]: DependencyError: If any dependency cannot be installed. DependencyInstallError: If any auto-install fails. """ - results: list[DependencyStatus] = [] - for tool in PINNED_VERSIONS: - results.append(ensure_dependency(tool)) - return results + return [ensure_dependency(tool) for tool in PINNED_VERSIONS] def _warn_version_mismatch(tool: str, status: DependencyStatus) -> None: diff --git a/src/mlx_stack/core/discovery.py b/src/mlx_stack/core/discovery.py index dfa11b4..c69d916 100644 --- a/src/mlx_stack/core/discovery.py +++ b/src/mlx_stack/core/discovery.py @@ -60,15 +60,15 @@ class DiscoveredModel: # --------------------------------------------------------------------------- # _QUANT_PATTERNS: list[tuple[re.Pattern, str]] = [ - (re.compile(r"-4bit(?:s)?$", re.I), "int4"), - (re.compile(r"-8bit(?:s)?$", re.I), "int8"), - (re.compile(r"-bf16$", re.I), "bf16"), - (re.compile(r"-fp16$", re.I), "fp16"), - (re.compile(r"-6bit$", re.I), "int6"), - (re.compile(r"-5bit$", re.I), "int5"), + (re.compile(r"-4bit(?:s)?$", re.IGNORECASE), "int4"), + (re.compile(r"-8bit(?:s)?$", re.IGNORECASE), "int8"), + (re.compile(r"-bf16$", re.IGNORECASE), "bf16"), + (re.compile(r"-fp16$", re.IGNORECASE), "fp16"), + (re.compile(r"-6bit$", re.IGNORECASE), "int6"), + (re.compile(r"-5bit$", re.IGNORECASE), "int5"), ] -_PARAMS_PATTERN = re.compile(r"(\d+(?:\.\d+)?)B(?!it)", re.I) +_PARAMS_PATTERN = re.compile(r"(\d+(?:\.\d+)?)B(?!it)", re.IGNORECASE) def infer_quant_from_repo(repo_name: str) -> str: @@ -193,7 +193,7 @@ def query_hf_models( """ try: api = HfApi() - models = list( + return list( api.list_models( author=author, pipeline_tag=pipeline_tag, @@ -201,7 +201,6 @@ def query_hf_models( limit=limit, ) ) - return models except Exception as exc: msg = f"HuggingFace API query failed: {exc}" raise DiscoveryError(msg) from None @@ -354,12 +353,15 @@ def discover_models( downloads = getattr(hf_model, "downloads", 0) or 0 model = _build_discovered_model( - repo, downloads, benchmark_data, hardware_profile_id, + repo, + downloads, + benchmark_data, + hardware_profile_id, ) models.append(model) # Add benchmark-only models not found in HF results - for repo, bench in benchmark_data.get("models", {}).items(): + for repo in benchmark_data.get("models", {}): quant = infer_quant_from_repo(repo) if quant != default_quant: continue diff --git a/src/mlx_stack/core/hardware.py b/src/mlx_stack/core/hardware.py index 7919d3f..8d81aad 100644 --- a/src/mlx_stack/core/hardware.py +++ b/src/mlx_stack/core/hardware.py @@ -170,7 +170,7 @@ def detect_memory_gb() -> int: msg = f"Unexpected hw.memsize value: {raw!r}" raise HardwareError(msg) from None - return memsize_bytes // (1024 ** 3) + return memsize_bytes // (1024**3) def detect_gpu_cores() -> int: diff --git a/src/mlx_stack/core/launchd.py b/src/mlx_stack/core/launchd.py index bc8e25c..bc73142 100644 --- a/src/mlx_stack/core/launchd.py +++ b/src/mlx_stack/core/launchd.py @@ -15,6 +15,7 @@ from __future__ import annotations +import contextlib import os import plistlib import shutil @@ -102,10 +103,7 @@ def check_platform() -> None: PlatformError: If not running on macOS (darwin). """ if sys.platform != "darwin": - msg = ( - "launchd integration is only available on macOS. " - f"Current platform: {sys.platform}" - ) + msg = f"launchd integration is only available on macOS. Current platform: {sys.platform}" raise PlatformError(msg) @@ -120,10 +118,7 @@ def check_init_prerequisite() -> None: """ stack_path = get_stacks_dir() / "default.yaml" if not stack_path.exists(): - msg = ( - "No stack configuration found. " - "Run 'mlx-stack init' first." - ) + msg = "No stack configuration found. Run 'mlx-stack init' first." raise PrerequisiteError(msg) @@ -152,10 +147,7 @@ def _resolve_mlx_stack_binary() -> str: if candidate.exists(): return str(candidate) - msg = ( - "Could not find the mlx-stack binary on PATH. " - "Ensure mlx-stack is properly installed." - ) + msg = "Could not find the mlx-stack binary on PATH. Ensure mlx-stack is properly installed." raise LaunchdError(msg) @@ -186,10 +178,7 @@ def _build_environment_variables(mlx_stack_binary: str) -> dict[str, str]: ] # Ensure binary_dir is first, then add standard paths not already present - path_components = [binary_dir] - for p in standard_paths: - if p != binary_dir: - path_components.append(p) + path_components = [binary_dir, *(p for p in standard_paths if p != binary_dir)] env["PATH"] = ":".join(path_components) @@ -475,10 +464,8 @@ def install_agent(mlx_stack_binary: str | None = None) -> tuple[Path, bool]: was_reinstall = plist_path.exists() if was_reinstall: # Bootout old agent before writing new plist - try: + with contextlib.suppress(LaunchdError): unload_agent(plist_path) - except LaunchdError: - pass # Best-effort unload of old agent # Write new plist write_plist(plist_data, plist_path) diff --git a/src/mlx_stack/core/litellm_gen.py b/src/mlx_stack/core/litellm_gen.py index 571d5a4..5c94216 100644 --- a/src/mlx_stack/core/litellm_gen.py +++ b/src/mlx_stack/core/litellm_gen.py @@ -149,9 +149,7 @@ def generate_litellm_config( # Cloud fallback has_cloud = bool(openrouter_key) if has_cloud: - model_list.append( - _build_cloud_entry("premium", "openrouter/openai/gpt-4o") - ) + model_list.append(_build_cloud_entry("premium", "openrouter/openai/gpt-4o")) model_list.append( _build_cloud_entry("premium", "openrouter/anthropic/claude-sonnet-4-20250514") ) diff --git a/src/mlx_stack/core/log_rotation.py b/src/mlx_stack/core/log_rotation.py index c167997..3cb51c8 100644 --- a/src/mlx_stack/core/log_rotation.py +++ b/src/mlx_stack/core/log_rotation.py @@ -149,11 +149,7 @@ def _delete_excess_archives(base: Path, stem: str, max_files: int) -> None: import re pattern = re.compile(rf"^{re.escape(stem)}\.(\d+)\.gz$") - archives: list[Path] = [] - - for path in base.iterdir(): - if pattern.match(path.name) and path.is_file(): - archives.append(path) + archives = [path for path in base.iterdir() if pattern.match(path.name) and path.is_file()] # We need room for the new archive that will become .1.gz after # shifting, so keep at most max_files - 1 existing archives. @@ -211,9 +207,8 @@ def _copy_and_compress(src: Path, dst_gz: Path) -> None: shutil.copy2(str(src), str(tmp_path)) # Compress the copy - with open(tmp_path, "rb") as f_in: - with gzip.open(str(dst_gz), "wb") as f_out: - shutil.copyfileobj(f_in, f_out) + with open(tmp_path, "rb") as f_in, gzip.open(str(dst_gz), "wb") as f_out: + shutil.copyfileobj(f_in, f_out) finally: # Clean up temporary file if tmp_path.exists(): @@ -228,4 +223,4 @@ def _truncate_file(path: Path) -> None: Args: path: Path to the file to truncate. """ - open(path, "w").close() # noqa: SIM115 + open(path, "w").close() diff --git a/src/mlx_stack/core/log_viewer.py b/src/mlx_stack/core/log_viewer.py index 295c36d..84073a5 100644 --- a/src/mlx_stack/core/log_viewer.py +++ b/src/mlx_stack/core/log_viewer.py @@ -13,7 +13,7 @@ import time from collections.abc import Callable from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from mlx_stack.core.config import get_value @@ -50,12 +50,11 @@ def size_display(self) -> str: """Return human-readable file size.""" if self.size_bytes < 1024: return f"{self.size_bytes} B" - elif self.size_bytes < 1024 * 1024: + if self.size_bytes < 1024 * 1024: return f"{self.size_bytes / 1024:.1f} KB" - elif self.size_bytes < 1024 * 1024 * 1024: + if self.size_bytes < 1024 * 1024 * 1024: return f"{self.size_bytes / (1024 * 1024):.1f} MB" - else: - return f"{self.size_bytes / (1024 * 1024 * 1024):.1f} GB" + return f"{self.size_bytes / (1024 * 1024 * 1024):.1f} GB" @property def modified_display(self) -> str: @@ -101,9 +100,7 @@ def list_log_files() -> list[LogFileInfo]: name=path.name, service=service, size_bytes=stat.st_size, - modified=datetime.fromtimestamp( - stat.st_mtime, tz=timezone.utc - ), + modified=datetime.fromtimestamp(stat.st_mtime, tz=UTC), ) results.append(info) except OSError: @@ -234,7 +231,7 @@ def follow_log( if current_size > position: try: - with open(log_path, "r", encoding="utf-8", errors="replace") as f: + with open(log_path, encoding="utf-8", errors="replace") as f: f.seek(position) new_content = f.read() if new_content: diff --git a/src/mlx_stack/core/models.py b/src/mlx_stack/core/models.py index b2b2330..b2d76f5 100644 --- a/src/mlx_stack/core/models.py +++ b/src/mlx_stack/core/models.py @@ -129,12 +129,14 @@ def _get_active_stack_models(stack: dict[str, Any] | None) -> list[dict[str, str for tier in tiers: model_id = tier.get("model", "") if model_id: - result.append({ - "model_id": model_id, - "quant": tier.get("quant", ""), - "source": tier.get("source", ""), - "tier": tier.get("name", ""), - }) + result.append( + { + "model_id": model_id, + "quant": tier.get("quant", ""), + "source": tier.get("source", ""), + "tier": tier.get("name", ""), + } + ) return result @@ -242,7 +244,7 @@ def _match_to_catalog(dirname: str, catalog: list[CatalogEntry]) -> CatalogEntry """ lower = dirname.lower() for entry in catalog: - for _quant, source in entry.sources.items(): + for source in entry.sources.values(): # Extract repo name from hf_repo (after the /) repo_name = ( source.hf_repo.rsplit("/", 1)[-1] if "/" in source.hf_repo else source.hf_repo @@ -310,24 +312,23 @@ def scan_local_models( stack_source = stack_entry.get("source", "") stack_quant = stack_entry.get("quant", "") stack_model_id = stack_entry.get("model_id", "") - source_dir = ( - stack_source.rsplit("/", 1)[-1] - if "/" in stack_source - else stack_source - ) + source_dir = stack_source.rsplit("/", 1)[-1] if "/" in stack_source else stack_source # Primary match: source directory name matches AND quant matches - if source_dir and source_dir == dirname: - # Source dir match found — verify quant compatibility - if not stack_quant or quant == "unknown" or stack_quant == quant: - is_active = True - break + if ( + source_dir + and source_dir == dirname + and (not stack_quant or quant == "unknown" or stack_quant == quant) + ): + is_active = True + break # Secondary match: model_id matches dirname AND quant matches - if stack_model_id == dirname: - if not stack_quant or quant == "unknown" or stack_quant == quant: - is_active = True - break + if stack_model_id == dirname and ( + not stack_quant or quant == "unknown" or stack_quant == quant + ): + is_active = True + break # Tertiary match: catalog entry ID matches stack model ID, # AND the local model's source matches the catalog source for @@ -406,27 +407,27 @@ def get_remote_stack_models( model_id = stack_entry["model_id"] stack_source = stack_entry.get("source", "") stack_quant = stack_entry.get("quant", "int4") - source_dir = ( - stack_source.rsplit("/", 1)[-1] - if "/" in stack_source - else stack_source - ) + source_dir = stack_source.rsplit("/", 1)[-1] if "/" in stack_source else stack_source # Check if locally available using source+quant-aware matching is_local = False for lm in local_models: # Match by source directory name + quant - if source_dir and source_dir == lm.name: - if not stack_quant or lm.quant == "unknown" or stack_quant == lm.quant: - is_local = True - break + if ( + source_dir + and source_dir == lm.name + and (not stack_quant or lm.quant == "unknown" or stack_quant == lm.quant) + ): + is_local = True + break # Match by model_id as dirname + quant - if model_id == lm.name: - if not stack_quant or lm.quant == "unknown" or stack_quant == lm.quant: - is_local = True - break + if model_id == lm.name and ( + not stack_quant or lm.quant == "unknown" or stack_quant == lm.quant + ): + is_local = True + break # Match by catalog entry with quant-aware source matching cat_entry = _find_catalog_entry(catalog, model_id) @@ -518,7 +519,7 @@ def list_catalog_models( break # Also check if dirname matches any catalog entry's source repo for entry in catalog: - for _quant, source in entry.sources.items(): + for source in entry.sources.values(): repo_name = ( source.hf_repo.rsplit("/", 1)[-1] if "/" in source.hf_repo else source.hf_repo ) diff --git a/src/mlx_stack/core/onboarding.py b/src/mlx_stack/core/onboarding.py index 735ba3f..989e1bd 100644 --- a/src/mlx_stack/core/onboarding.py +++ b/src/mlx_stack/core/onboarding.py @@ -12,7 +12,7 @@ import math from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -167,14 +167,16 @@ def score_and_filter( + weights.memory_efficiency * mem_eff ) - scored.append(ScoredDiscoveredModel( - model=model, - composite_score=round(composite, 4), - speed_score=round(speed, 4), - quality_score=round(quality, 4), - tool_calling_score=tool, - memory_efficiency_score=round(mem_eff, 4), - )) + scored.append( + ScoredDiscoveredModel( + model=model, + composite_score=round(composite, 4), + speed_score=round(speed, 4), + quality_score=round(quality, 4), + tool_calling_score=tool, + memory_efficiency_score=round(mem_eff, 4), + ) + ) scored.sort(key=lambda s: s.composite_score, reverse=True) return scored @@ -214,13 +216,11 @@ def select_defaults( # Pick fast: highest gen_tps that isn't the standard pick if len(scored_models) > 1: - fast_candidates = [ - (i, s) for i, s in enumerate(scored_models) - if i not in selected_indices - ] + fast_candidates = [(i, s) for i, s in enumerate(scored_models) if i not in selected_indices] if fast_candidates: fast_candidates.sort( - key=lambda x: x[1].model.gen_tps or 0.0, reverse=True, + key=lambda x: x[1].model.gen_tps or 0.0, + reverse=True, ) fast_idx, fast_model = fast_candidates[0] fast_mem = fast_model.model.memory_gb or 0.0 @@ -231,15 +231,17 @@ def select_defaults( result: list[ScoredDiscoveredModel] = [] for i, s in enumerate(scored_models): if i in selected_indices: - result.append(ScoredDiscoveredModel( - model=s.model, - composite_score=s.composite_score, - speed_score=s.speed_score, - quality_score=s.quality_score, - tool_calling_score=s.tool_calling_score, - memory_efficiency_score=s.memory_efficiency_score, - is_recommended=True, - )) + result.append( + ScoredDiscoveredModel( + model=s.model, + composite_score=s.composite_score, + speed_score=s.speed_score, + quality_score=s.quality_score, + tool_calling_score=s.tool_calling_score, + memory_efficiency_score=s.memory_efficiency_score, + is_recommended=True, + ) + ) else: result.append(s) @@ -283,10 +285,12 @@ def assign_tiers( # Additional models for extra in remaining[1:]: n = len(mappings) - 1 - mappings.append(TierMapping( - tier_name=f"added-{n}", - model=extra.model, - )) + mappings.append( + TierMapping( + tier_name=f"added-{n}", + model=extra.model, + ) + ) return mappings @@ -332,14 +336,16 @@ def generate_config( if mapping.model.tool_calling: vllm_flags["enable_auto_tool_choice"] = True - tiers_config.append({ - "name": mapping.tier_name, - "model": mapping.model.display_name, - "quant": mapping.model.quant, - "source": mapping.model.hf_repo, - "port": port, - "vllm_flags": vllm_flags, - }) + tiers_config.append( + { + "name": mapping.tier_name, + "model": mapping.model.display_name, + "quant": mapping.model.quant, + "source": mapping.model.hf_repo, + "port": port, + "vllm_flags": vllm_flags, + } + ) port += 1 # Build stack YAML @@ -348,7 +354,7 @@ def generate_config( "name": "default", "hardware_profile": profile.profile_id, "intent": intent, - "created": datetime.now(timezone.utc).isoformat(), + "created": datetime.now(UTC).isoformat(), "tiers": tiers_config, } @@ -362,8 +368,7 @@ def generate_config( # Build LiteLLM config litellm_tiers = [ - {"name": t["name"], "model": t["model"], "port": t["port"]} - for t in tiers_config + {"name": t["name"], "model": t["model"], "port": t["port"]} for t in tiers_config ] openrouter_key = str(get_value("openrouter-key") or "") litellm_config = generate_litellm_config( @@ -427,19 +432,21 @@ def pull_setup_models( hf_repo=model.hf_repo, local_path=str(local_path), disk_size_gb=model.memory_gb or 0.0, - downloaded_at=datetime.now(timezone.utc).isoformat(), + downloaded_at=datetime.now(UTC).isoformat(), ) add_to_inventory(entry) - results.append(PullResult( - model_id=model.display_name, - name=model.display_name, - quant=model.quant, - source_type="mlx-community", - local_path=local_path, - already_existed=already_exists, - disk_size_gb=model.memory_gb or 0.0, - )) + results.append( + PullResult( + model_id=model.display_name, + name=model.display_name, + quant=model.quant, + source_type="mlx-community", + local_path=local_path, + already_existed=already_exists, + disk_size_gb=model.memory_gb or 0.0, + ) + ) return results diff --git a/src/mlx_stack/core/process.py b/src/mlx_stack/core/process.py index 0f9fc9a..5bdf45a 100644 --- a/src/mlx_stack/core/process.py +++ b/src/mlx_stack/core/process.py @@ -12,15 +12,17 @@ from __future__ import annotations +import contextlib import fcntl import os import signal import subprocess import time -from contextlib import contextmanager +from collections.abc import Iterator +from contextlib import contextmanager, suppress from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterator +from typing import Any import httpx import psutil @@ -187,10 +189,7 @@ def read_pid_file(service_name: str) -> int | None: try: return int(content) except ValueError: - msg = ( - f"PID file for '{service_name}' contains non-numeric content: " - f"{content!r}" - ) + msg = f"PID file for '{service_name}' contains non-numeric content: {content!r}" raise ProcessError(msg) from None @@ -205,10 +204,8 @@ def remove_pid_file(service_name: str) -> bool: """ pid_path = get_pids_dir() / f"{service_name}.pid" if pid_path.exists(): - try: + with contextlib.suppress(OSError): pid_path.unlink() - except OSError: - pass # Best-effort removal return True return False @@ -325,10 +322,8 @@ def acquire_lock() -> Iterator[None]: try: yield finally: - try: + with suppress(OSError): fcntl.flock(fd, fcntl.LOCK_UN) - except OSError: - pass os.close(fd) @@ -436,10 +431,7 @@ def wait_for_healthy( if last_result is None: last_result = HealthCheckResult(healthy=False, response_time=None, status_code=None) - msg = ( - f"Health check timed out after {total_timeout}s waiting for " - f"http://{host}:{port}{path}" - ) + msg = f"Health check timed out after {total_timeout}s waiting for http://{host}:{port}{path}" raise HealthCheckError(msg) @@ -527,7 +519,6 @@ def check_port_conflict(port: int) -> tuple[int, str] | None: return (0, "") - def detect_port_conflict(port: int) -> None: """Raise PortConflictError if a port is already in use. @@ -617,10 +608,8 @@ def start_service( proc.terminate() proc.wait(timeout=5) except Exception: - try: + with suppress(Exception): proc.kill() - except Exception: - pass log_file.close() msg = ( f"Could not write PID file for '{service_name}' after " diff --git a/src/mlx_stack/core/pull.py b/src/mlx_stack/core/pull.py index c92273f..526fae9 100644 --- a/src/mlx_stack/core/pull.py +++ b/src/mlx_stack/core/pull.py @@ -9,12 +9,13 @@ from __future__ import annotations +import contextlib import json import shutil import subprocess import time from dataclasses import asdict, dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -164,7 +165,8 @@ def add_to_inventory(entry: ModelInventoryEntry) -> None: # Remove existing entry for same model_id + quant entries = [ - e for e in entries + e + for e in entries if not (e.get("model_id") == entry.model_id and e.get("quant") == entry.quant) ] @@ -216,18 +218,14 @@ def resolve_source( """ if quant not in entry.sources: available = ", ".join(sorted(entry.sources.keys())) - msg = ( - f"Quantization '{quant}' is not available for {entry.name}. " - f"Available: {available}" - ) + msg = f"Quantization '{quant}' is not available for {entry.name}. Available: {available}" raise PullError(msg) source = entry.sources[quant] if source.convert_from: return source, "converted" - else: - return source, "mlx-community" + return source, "mlx-community" # --------------------------------------------------------------------------- # @@ -372,10 +370,7 @@ def download_model( except DownloadError as exc: last_error = exc if attempt < max_retries: - console.print( - f"[yellow]Download attempt {attempt} failed. " - f"Retrying...[/yellow]" - ) + console.print(f"[yellow]Download attempt {attempt} failed. Retrying...[/yellow]") time.sleep(2) # Brief pause before retry else: break @@ -463,11 +458,7 @@ def convert_model( ) except FileNotFoundError: _cleanup_partial(local_dir) - msg = ( - "mlx_lm not found. Install it with:\n" - " pip install mlx_lm\n" - "Or: uv pip install mlx_lm" - ) + msg = "mlx_lm not found. Install it with:\n pip install mlx_lm\nOr: uv pip install mlx_lm" raise ConversionError(msg) from None except subprocess.TimeoutExpired: _cleanup_partial(local_dir) @@ -499,10 +490,8 @@ def _cleanup_partial(local_dir: Path) -> None: local_dir: The directory to remove. """ if local_dir.exists(): - try: + with contextlib.suppress(OSError): shutil.rmtree(local_dir) - except OSError: - pass # --------------------------------------------------------------------------- # @@ -678,14 +667,12 @@ def pull_model( hf_repo=source.hf_repo, local_path=str(local_path), disk_size_gb=source.disk_size_gb, - downloaded_at=datetime.now(timezone.utc).isoformat(), + downloaded_at=datetime.now(UTC).isoformat(), ) add_to_inventory(inv) console.print() - console.print( - f"[bold green]✓ {entry.name} ({quant}) is ready.[/bold green]" - ) + console.print(f"[bold green]✓ {entry.name} ({quant}) is ready.[/bold green]") console.print(f" Location: {local_path}") return PullResult( diff --git a/src/mlx_stack/core/scoring.py b/src/mlx_stack/core/scoring.py index 6ba29e4..5a29cd0 100644 --- a/src/mlx_stack/core/scoring.py +++ b/src/mlx_stack/core/scoring.py @@ -203,8 +203,7 @@ def _resolve_benchmark( except (ValueError, TypeError): # Malformed saved benchmark data — fall through to catalog lookup logger.warning( - "Ignoring malformed saved benchmark for model '%s': " - "invalid numeric values", + "Ignoring malformed saved benchmark for model '%s': invalid numeric values", entry.id, ) @@ -374,9 +373,7 @@ def score_model( Raises: ScoringError: If benchmark data cannot be resolved. """ - gen_tps, memory_gb, is_estimated = _resolve_benchmark( - entry, profile, quant, saved_benchmarks - ) + gen_tps, memory_gb, is_estimated = _resolve_benchmark(entry, profile, quant, saved_benchmarks) speed_score = _normalize_gen_tps_log(gen_tps) quality_score = _normalize_quality(entry.quality.overall) @@ -506,11 +503,13 @@ def assign_tiers( key=lambda m: (-m.composite_score, m.entry.id), ) standard_model = standard_candidates[0] - assignments.append(TierAssignment( - tier=TIER_STANDARD, - model=standard_model, - quant=_DEFAULT_QUANT, - )) + assignments.append( + TierAssignment( + tier=TIER_STANDARD, + model=standard_model, + quant=_DEFAULT_QUANT, + ) + ) used_model_ids.add(standard_model.entry.id) # --- Fast tier: highest gen_tps, different from standard --- @@ -520,11 +519,13 @@ def assign_tiers( ) if fast_candidates: fast_model = fast_candidates[0] - assignments.append(TierAssignment( - tier=TIER_FAST, - model=fast_model, - quant=_DEFAULT_QUANT, - )) + assignments.append( + TierAssignment( + tier=TIER_FAST, + model=fast_model, + quant=_DEFAULT_QUANT, + ) + ) used_model_ids.add(fast_model.entry.id) # --- Longctx tier: architecturally diverse, only for larger budgets --- @@ -541,11 +542,13 @@ def assign_tiers( ) if longctx_candidates: longctx_model = longctx_candidates[0] - assignments.append(TierAssignment( - tier=TIER_LONGCTX, - model=longctx_model, - quant=_DEFAULT_QUANT, - )) + assignments.append( + TierAssignment( + tier=TIER_LONGCTX, + model=longctx_model, + quant=_DEFAULT_QUANT, + ) + ) used_model_ids.add(longctx_model.entry.id) return assignments @@ -604,7 +607,12 @@ def recommend( # Score and filter scored = score_and_filter( - catalog, profile, intent, budget_gb, quant, saved_benchmarks, + catalog, + profile, + intent, + budget_gb, + quant, + saved_benchmarks, exclude_gated=exclude_gated, ) diff --git a/src/mlx_stack/core/stack_down.py b/src/mlx_stack/core/stack_down.py index 04a7333..69ea996 100644 --- a/src/mlx_stack/core/stack_down.py +++ b/src/mlx_stack/core/stack_down.py @@ -95,6 +95,7 @@ def _get_tier_names_from_stack(stack_name: str = "default") -> list[str]: catalog = None if catalog is not None: + def sort_key(tier: dict[str, Any]) -> tuple[float, str]: model_id = tier.get("model", "") entry = get_entry_by_id(catalog, model_id) @@ -226,10 +227,7 @@ def run_down( valid_tiers = _get_valid_tier_names(stack_name) if valid_tiers and tier_filter not in valid_tiers: valid_list = ", ".join(sorted(valid_tiers)) - msg = ( - f"Unknown tier '{tier_filter}'. " - f"Valid tiers: {valid_list}" - ) + msg = f"Unknown tier '{tier_filter}'. Valid tiers: {valid_list}" raise DownError(msg) # --- Check if anything is running --- @@ -277,11 +275,13 @@ def _run_shutdown( svc_result = _stop_single_service(tier_filter) result.services.append(svc_result) else: - result.services.append(ServiceStopResult( - name=tier_filter, - pid=None, - status="not-running", - )) + result.services.append( + ServiceStopResult( + name=tier_filter, + pid=None, + status="not-running", + ) + ) return result # --- Full shutdown: determine order --- diff --git a/src/mlx_stack/core/stack_init.py b/src/mlx_stack/core/stack_init.py index d16dcd5..a2d630f 100644 --- a/src/mlx_stack/core/stack_init.py +++ b/src/mlx_stack/core/stack_init.py @@ -8,7 +8,7 @@ from __future__ import annotations import socket -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -108,7 +108,7 @@ def allocate_ports( msg = ( f"Could not allocate {num_tiers} free ports starting " f"from {base_port}. All ports in range " - f"{base_port}–{max_port} are in use or reserved." + f"{base_port}-{max_port} are in use or reserved." ) raise InitError(msg) ports.append(port) @@ -223,7 +223,7 @@ def generate_stack_definition( raise InitError(msg) tiers: list[dict[str, Any]] = [] - for assignment, port in zip(recommendation.tiers, ports): + for assignment, port in zip(recommendation.tiers, ports, strict=False): tiers.append(_build_tier_entry(assignment, port, catalog)) stack: dict[str, Any] = { @@ -231,7 +231,7 @@ def generate_stack_definition( "name": stack_name, "hardware_profile": recommendation.hardware_profile.profile_id, "intent": recommendation.intent, - "created": datetime.now(timezone.utc).isoformat(), + "created": datetime.now(UTC).isoformat(), "tiers": tiers, } @@ -356,10 +356,7 @@ def run_init( litellm_path = get_data_home() / "litellm.yaml" if stack_path.exists() and not force: - msg = ( - f"Stack '{stack_name}' already exists at {stack_path}. " - f"Use --force to overwrite." - ) + msg = f"Stack '{stack_name}' already exists at {stack_path}. Use --force to overwrite." raise InitError(msg) # --- Load catalog --- @@ -427,7 +424,10 @@ def run_init( weights = INTENT_WEIGHTS.get(intent, INTENT_WEIGHTS["balanced"]) try: scored = score_model( - entry, profile, weights, recommendation.memory_budget_gb, + entry, + profile, + weights, + recommendation.memory_budget_gb, ) except ScoringError as exc: msg = f"Cannot add model '{model_id}': {exc}" @@ -450,11 +450,13 @@ def run_init( ) tier_name = f"added-{model_id}" - tiers.append(TierAssignment( - tier=tier_name, - model=scored, - quant="int4", - )) + tiers.append( + TierAssignment( + tier=tier_name, + model=scored, + quant="int4", + ) + ) if not tiers: msg = "No tiers remaining after customization. Cannot generate stack." @@ -481,8 +483,7 @@ def run_init( # --- Generate LiteLLM config --- tier_entries = [ - {"name": t["name"], "model": t["model"], "port": t["port"]} - for t in stack["tiers"] + {"name": t["name"], "model": t["model"], "port": t["port"]} for t in stack["tiers"] ] litellm_config = generate_litellm_config( tiers=tier_entries, diff --git a/src/mlx_stack/core/stack_status.py b/src/mlx_stack/core/stack_status.py index a2bd718..2bd42bb 100644 --- a/src/mlx_stack/core/stack_status.py +++ b/src/mlx_stack/core/stack_status.py @@ -178,16 +178,18 @@ def run_status(stack_name: str = "default") -> StatusResult: health_path=VLLM_HEALTH_PATH, ) - result.services.append(ServiceStatus( - tier=tier_name, - model=model, - port=port, - status=ServiceHealth(svc_status["status"]), - uptime=svc_status["uptime"], - uptime_display=format_uptime(svc_status["uptime"]), - response_time=svc_status["response_time"], - pid=svc_status["pid"], - )) + result.services.append( + ServiceStatus( + tier=tier_name, + model=model, + port=port, + status=ServiceHealth(svc_status["status"]), + uptime=svc_status["uptime"], + uptime_display=format_uptime(svc_status["uptime"]), + response_time=svc_status["response_time"], + pid=svc_status["pid"], + ) + ) # --- Check LiteLLM --- litellm_port = _get_litellm_port() @@ -197,16 +199,18 @@ def run_status(stack_name: str = "default") -> StatusResult: health_path=LITELLM_HEALTH_PATH, ) - result.services.append(ServiceStatus( - tier="litellm", - model="proxy", - port=litellm_port, - status=ServiceHealth(litellm_status["status"]), - uptime=litellm_status["uptime"], - uptime_display=format_uptime(litellm_status["uptime"]), - response_time=litellm_status["response_time"], - pid=litellm_status["pid"], - )) + result.services.append( + ServiceStatus( + tier="litellm", + model="proxy", + port=litellm_port, + status=ServiceHealth(litellm_status["status"]), + uptime=litellm_status["uptime"], + uptime_display=format_uptime(litellm_status["uptime"]), + response_time=litellm_status["response_time"], + pid=litellm_status["pid"], + ) + ) return result @@ -220,9 +224,8 @@ def status_to_dict(result: StatusResult) -> dict[str, Any]: Returns: A dict suitable for ``json.dumps``. """ - services_list: list[dict[str, Any]] = [] - for svc in result.services: - services_list.append({ + services_list: list[dict[str, Any]] = [ + { "tier": svc.tier, "model": svc.model, "port": svc.port, @@ -231,7 +234,9 @@ def status_to_dict(result: StatusResult) -> dict[str, Any]: "uptime_display": svc.uptime_display, "pid": svc.pid, "response_time": svc.response_time, - }) + } + for svc in result.services + ] return { "services": services_list, diff --git a/src/mlx_stack/core/stack_up.py b/src/mlx_stack/core/stack_up.py index c892ada..ab2ed11 100644 --- a/src/mlx_stack/core/stack_up.py +++ b/src/mlx_stack/core/stack_up.py @@ -192,7 +192,7 @@ def estimate_memory_usage( # Look for memory_gb in any benchmark entry memory_gb = 0.0 - for _hw_key, bench in entry.benchmarks.items(): + for bench in entry.benchmarks.values(): memory_gb = bench.memory_gb break # Take the first available benchmark's memory @@ -274,10 +274,7 @@ def check_local_model_exists(tier: dict[str, Any]) -> str | None: return None # Model not found — generate diagnostic message - return ( - f"Model '{model_id}' not found locally. " - f"Run 'mlx-stack pull {model_id}' to download it." - ) + return f"Model '{model_id}' not found locally. Run 'mlx-stack pull {model_id}' to download it." # --------------------------------------------------------------------------- # @@ -303,9 +300,12 @@ def build_vllm_command( cmd = [ vllm_binary, - "serve", model_source, - "--port", str(port), - "--host", "127.0.0.1", + "serve", + model_source, + "--port", + str(port), + "--host", + "127.0.0.1", ] # Add vllm_flags @@ -338,9 +338,12 @@ def build_litellm_command( """ return [ litellm_binary, - "--config", str(litellm_config_path), - "--port", str(litellm_port), - "--host", "127.0.0.1", + "--config", + str(litellm_config_path), + "--port", + str(litellm_port), + "--host", + "127.0.0.1", ] @@ -367,9 +370,7 @@ def format_dry_run_command( parts: list[str] = [] if env_vars: - for key in sorted(env_vars.keys()): - # Mask all env var values in dry-run - parts.append(f"{key}=***") + parts.extend(f"{key}=***" for key in sorted(env_vars.keys())) parts.extend(cmd) return " ".join(parts) @@ -466,10 +467,7 @@ def run_up( if tier_filter is not None: if tier_filter not in valid_tier_names: valid_list = ", ".join(sorted(valid_tier_names)) - msg = ( - f"Unknown tier '{tier_filter}'. " - f"Valid tiers: {valid_list}" - ) + msg = f"Unknown tier '{tier_filter}'. Valid tiers: {valid_list}" raise UpError(msg) tiers = [t for t in tiers if t["name"] == tier_filter] @@ -539,18 +537,22 @@ def _run_dry_run( cmd = build_vllm_command(tier, vllm_binary) cmd_str = format_dry_run_command(cmd) - result.dry_run_commands.append({ - "service": tier["name"], - "command": cmd_str, - "type": "vllm-mlx", - }) + result.dry_run_commands.append( + { + "service": tier["name"], + "command": cmd_str, + "type": "vllm-mlx", + } + ) - result.tiers.append(TierStatus( - name=tier["name"], - model=tier.get("model", ""), - port=tier["port"], - status="dry-run", - )) + result.tiers.append( + TierStatus( + name=tier["name"], + model=tier.get("model", ""), + port=tier["port"], + status="dry-run", + ) + ) # LiteLLM command litellm_cmd = build_litellm_command(litellm_binary, litellm_port, litellm_config_path) @@ -561,11 +563,13 @@ def _run_dry_run( litellm_cmd_str = format_dry_run_command(litellm_cmd, env_display) - result.dry_run_commands.append({ - "service": LITELLM_SERVICE_NAME, - "command": litellm_cmd_str, - "type": "litellm", - }) + result.dry_run_commands.append( + { + "service": LITELLM_SERVICE_NAME, + "command": litellm_cmd_str, + "type": "litellm", + } + ) result.litellm = TierStatus( name=LITELLM_SERVICE_NAME, @@ -638,17 +642,18 @@ def _run_startup( if pid is not None: if is_process_alive(pid): # Already running - result.tiers.append(TierStatus( - name=tier_name, - model=tier.get("model", ""), - port=tier["port"], - status="already-running", - )) + result.tiers.append( + TierStatus( + name=tier_name, + model=tier.get("model", ""), + port=tier["port"], + status="already-running", + ) + ) continue - else: - # Stale PID — clean up - cleanup_stale_pid(tier_name) - any_stale = True + # Stale PID — clean up + cleanup_stale_pid(tier_name) + any_stale = True else: pass # Tier needs to be started @@ -670,9 +675,7 @@ def _run_startup( any_stale = True # If all tiers + LiteLLM are already running, report and return - tiers_already_running = [ - t for t in result.tiers if t.status == "already-running" - ] + tiers_already_running = [t for t in result.tiers if t.status == "already-running"] if len(tiers_already_running) == len(tiers) and litellm_already_running: result.already_running = True result.litellm = TierStatus( @@ -695,10 +698,7 @@ def _run_startup( # --- Start vllm-mlx instances sequentially --- healthy_count = 0 - tiers_needing_start = [ - t for t in tiers - if t["name"] not in {ts.name for ts in result.tiers} - ] + tiers_needing_start = [t for t in tiers if t["name"] not in {ts.name for ts in result.tiers}] for tier in tiers_needing_start: tier_name = tier["name"] @@ -707,29 +707,30 @@ def _run_startup( # Preflight: check local model exists on disk missing_msg = check_local_model_exists(tier) if missing_msg is not None: - result.tiers.append(TierStatus( - name=tier_name, - model=tier.get("model", ""), - port=port, - status="skipped", - error=missing_msg, - )) + result.tiers.append( + TierStatus( + name=tier_name, + model=tier.get("model", ""), + port=port, + status="skipped", + error=missing_msg, + ) + ) continue # Check port conflict conflict = check_port_conflict(port) if conflict is not None: conflict_pid, conflict_name = conflict - result.tiers.append(TierStatus( - name=tier_name, - model=tier.get("model", ""), - port=port, - status="skipped", - error=( - f"Port {port} already in use by " - f"PID {conflict_pid} ({conflict_name})" - ), - )) + result.tiers.append( + TierStatus( + name=tier_name, + model=tier.get("model", ""), + port=port, + status="skipped", + error=(f"Port {port} already in use by PID {conflict_pid} ({conflict_name})"), + ) + ) continue # Start the vllm-mlx subprocess @@ -742,38 +743,42 @@ def _run_startup( port=port, ) except Exception as exc: - result.tiers.append(TierStatus( - name=tier_name, - model=tier.get("model", ""), - port=port, - status="failed", - error=str(exc), - )) + result.tiers.append( + TierStatus( + name=tier_name, + model=tier.get("model", ""), + port=port, + status="failed", + error=str(exc), + ) + ) continue # Health check with exponential backoff try: wait_for_healthy(port=port, path=VLLM_HEALTH_PATH) - result.tiers.append(TierStatus( - name=tier_name, - model=tier.get("model", ""), - port=port, - status="healthy", - )) + result.tiers.append( + TierStatus( + name=tier_name, + model=tier.get("model", ""), + port=port, + status="healthy", + ) + ) healthy_count += 1 except HealthCheckError as exc: - result.tiers.append(TierStatus( - name=tier_name, - model=tier.get("model", ""), - port=port, - status="failed", - error=str(exc), - )) + result.tiers.append( + TierStatus( + name=tier_name, + model=tier.get("model", ""), + port=port, + status="failed", + error=str(exc), + ) + ) # --- Count total healthy (including already-running) --- - total_healthy = sum( - 1 for t in result.tiers if t.status in ("healthy", "already-running") - ) + total_healthy = sum(1 for t in result.tiers if t.status in ("healthy", "already-running")) # --- Start LiteLLM if any healthy tiers and not already running --- if litellm_already_running: @@ -802,13 +807,14 @@ def _run_startup( port=litellm_port, status="skipped", error=( - f"Port {litellm_port} already in use by " - f"PID {conflict_pid} ({conflict_name})" + f"Port {litellm_port} already in use by PID {conflict_pid} ({conflict_name})" ), ) else: litellm_cmd = build_litellm_command( - litellm_binary, litellm_port, litellm_config_path, + litellm_binary, + litellm_port, + litellm_config_path, ) # Build env with OpenRouter key if configured diff --git a/src/mlx_stack/core/watchdog.py b/src/mlx_stack/core/watchdog.py index 3d0cc36..d23191d 100644 --- a/src/mlx_stack/core/watchdog.py +++ b/src/mlx_stack/core/watchdog.py @@ -187,9 +187,7 @@ def check_flapping( cutoff = now - window_seconds # Prune old timestamps outside the window - tracker.restart_timestamps = [ - ts for ts in tracker.restart_timestamps if ts > cutoff - ] + tracker.restart_timestamps = [ts for ts in tracker.restart_timestamps if ts > cutoff] if len(tracker.restart_timestamps) >= max_restarts: tracker.is_flapping = True @@ -303,8 +301,7 @@ def restart_service( if service_name == LITELLM_SERVICE_NAME: return _restart_litellm(service_name, stack, litellm_binary) - else: - return _restart_tier(service_name, stack, vllm_binary) + return _restart_tier(service_name, stack, vllm_binary) except LockError: logger.warning( "Could not acquire lock to restart '%s' — another operation is in progress.", @@ -518,6 +515,7 @@ def setup_signal_handlers(state: WatchdogState) -> None: Args: state: The watchdog state — sets shutdown_requested flag. """ + def handler(signum: int, frame: Any) -> None: state.shutdown_requested = True @@ -570,9 +568,7 @@ def poll_cycle( # Check each service for crashed state for svc in status_result.services: service_name = svc.tier - tracker = state.service_trackers.setdefault( - service_name, ServiceTracker() - ) + tracker = state.service_trackers.setdefault(service_name, ServiceTracker()) # Try to reset flap state if service has been stable reset_flap_state(tracker) @@ -593,8 +589,7 @@ def poll_cycle( if check_flapping(tracker, max_restarts): result.flapping_services.append(service_name) logger.warning( - "Service '%s' marked as flapping after %d restarts. " - "Stopping auto-restart.", + "Service '%s' marked as flapping after %d restarts. Stopping auto-restart.", service_name, max_restarts, ) @@ -728,7 +723,7 @@ def run_watchdog( status_callback(result, state) if restart_callback is not None and result.restarts_attempted > 0: - for record in state.restart_log[-result.restarts_attempted:]: + for record in state.restart_log[-result.restarts_attempted :]: restart_callback(record) # Sleep in small increments so we can check shutdown flag diff --git a/tests/conftest.py b/tests/conftest.py index 640211b..9879ede 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ import pytest -@pytest.fixture() +@pytest.fixture def mlx_stack_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: """Provide an isolated MLX_STACK_HOME directory for testing. @@ -27,7 +27,7 @@ def mlx_stack_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: return home -@pytest.fixture() +@pytest.fixture def clean_mlx_stack_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: """Provide a clean (non-existent) MLX_STACK_HOME for testing auto-creation. diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 10defad..f9a5781 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -11,6 +11,7 @@ from __future__ import annotations +import contextlib import os import shutil import signal @@ -19,7 +20,7 @@ import sys import time from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -65,7 +66,7 @@ def skip_insufficient_memory(min_gb: float) -> pytest.MarkDecorator: available < min_gb, reason=f"Requires {min_gb:.1f}GB free memory, have {available:.1f}GB", ) - except Exception: # noqa: BLE001 + except Exception: return pytest.mark.skip(reason="Could not determine available memory") @@ -75,7 +76,7 @@ def check_memory_or_skip(min_gb: float) -> None: available = psutil.virtual_memory().available / (1024**3) if available < min_gb: pytest.skip(f"Requires {min_gb:.1f}GB free memory, have {available:.1f}GB") - except Exception: # noqa: BLE001 + except Exception: pytest.skip("Could not determine available memory") @@ -126,10 +127,8 @@ def kill_processes_on_port(port: int) -> None: for pid_str in result.stdout.strip().split("\n"): pid_str = pid_str.strip() if pid_str.isdigit(): - try: + with contextlib.suppress(OSError): os.kill(int(pid_str), signal.SIGKILL) - except OSError: - pass except (subprocess.TimeoutExpired, FileNotFoundError, OSError): pass @@ -198,7 +197,7 @@ def catalog() -> list[CatalogEntry]: # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def integration_home( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, @@ -281,9 +280,13 @@ def start_vllm( """ flags = vllm_flags or {} cmd: list[str] = [ - "vllm-mlx", "serve", model_source, - "--port", str(port), - "--host", "127.0.0.1", + "vllm-mlx", + "serve", + model_source, + "--port", + str(port), + "--host", + "127.0.0.1", ] if flags.get("continuous_batching"): cmd.append("--continuous-batching") @@ -347,9 +350,12 @@ def start_litellm( """ cmd: list[str] = [ "litellm", - "--config", str(config_path), - "--port", str(port), - "--host", "127.0.0.1", + "--config", + str(config_path), + "--port", + str(port), + "--host", + "127.0.0.1", ] logs_dir = self._mlx_home / "logs" @@ -387,10 +393,8 @@ def stop_all(self) -> None: """Stop all managed services. Always safe to call multiple times.""" for svc in reversed(self._services): if svc.pid is not None: - try: + with contextlib.suppress(OSError): os.kill(svc.pid, signal.SIGTERM) - except OSError: - pass # Give services time to shut down gracefully time.sleep(2) @@ -398,10 +402,8 @@ def stop_all(self) -> None: # Force-kill anything still alive for svc in self._services: if svc.pid is not None: - try: + with contextlib.suppress(OSError): os.kill(svc.pid, signal.SIGKILL) - except OSError: - pass kill_processes_on_port(svc.port) # Wait for ports to be freed @@ -449,7 +451,7 @@ def build_test_stack_yaml( "name": "default", "hardware_profile": hardware_profile, "intent": intent, - "created": datetime.now(timezone.utc).isoformat(), + "created": datetime.now(UTC).isoformat(), "tiers": tiers, } @@ -467,16 +469,17 @@ def build_test_litellm_yaml( Returns: LiteLLM config dict ready for yaml.dump(). """ - model_list = [] - for tier in tiers: - model_list.append({ + model_list = [ + { "model_name": tier["name"], "litellm_params": { "model": f"openai/{tier['model']}", "api_base": f"http://localhost:{tier['port']}/v1", "api_key": "dummy", }, - }) + } + for tier in tiers + ] return { "model_list": model_list, diff --git a/tests/integration/report.py b/tests/integration/report.py index 00c7fe4..06489a7 100644 --- a/tests/integration/report.py +++ b/tests/integration/report.py @@ -17,7 +17,7 @@ import json import platform from dataclasses import asdict, dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path # --------------------------------------------------------------------------- # @@ -56,7 +56,7 @@ class CompatibilityMatrix: def __post_init__(self) -> None: if not self.timestamp: - self.timestamp = datetime.now(timezone.utc).isoformat() + self.timestamp = datetime.now(UTC).isoformat() if not self.platform: self.platform = f"{platform.system()} {platform.release()}" @@ -73,7 +73,7 @@ def write(self, output_dir: Path | None = None) -> Path: out_dir = output_dir or REPORT_DIR out_dir.mkdir(parents=True, exist_ok=True) - ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") path = out_dir / f"compatibility-matrix-{ts}.json" data = { diff --git a/tests/integration/test_catalog_validation.py b/tests/integration/test_catalog_validation.py index 1240b8f..985b114 100644 --- a/tests/integration/test_catalog_validation.py +++ b/tests/integration/test_catalog_validation.py @@ -38,7 +38,7 @@ @pytest.fixture(scope="module") -def hf_client() -> Generator[httpx.Client, None, None]: +def hf_client() -> Generator[httpx.Client]: """Shared HTTP client for HuggingFace API calls.""" client = httpx.Client( timeout=30.0, @@ -88,9 +88,7 @@ def test_source_fields(self, entry: CatalogEntry) -> None: assert "/" in source.hf_repo, ( f"{entry.id}/{quant}: hf_repo '{source.hf_repo}' must be org/repo format" ) - assert source.disk_size_gb > 0, ( - f"{entry.id}/{quant}: disk_size_gb must be > 0" - ) + assert source.disk_size_gb > 0, f"{entry.id}/{quant}: disk_size_gb must be > 0" @pytest.mark.parametrize("entry", _CATALOG, ids=_CATALOG_IDS) def test_capability_consistency(self, entry: CatalogEntry) -> None: @@ -129,9 +127,7 @@ def test_quality_scores_in_range(self, entry: CatalogEntry) -> None: ("reasoning", q.reasoning), ("instruction_following", q.instruction_following), ]: - assert 0 <= score <= 100, ( - f"{entry.id}: quality.{name}={score} is outside 0-100 range" - ) + assert 0 <= score <= 100, f"{entry.id}: quality.{name}={score} is outside 0-100 range" # --------------------------------------------------------------------------- # @@ -180,7 +176,9 @@ def test_hf_repos_exist(self, entry: CatalogEntry, hf_client: httpx.Client) -> N ids=[e.id for e in _CATALOG if not e.gated], ) def test_non_gated_repos_have_safetensors( - self, entry: CatalogEntry, hf_client: httpx.Client, + self, + entry: CatalogEntry, + hf_client: httpx.Client, ) -> None: """Non-gated, non-convert_from repos contain safetensors weight files. @@ -202,7 +200,7 @@ def test_non_gated_repos_have_safetensors( assert has_safetensors, ( f"{entry.id}/{quant}: repo '{source.hf_repo}' has no *.safetensors files. " - f"Files found: {[f for f in filenames[:10]]}... " + f"Files found: {list(filenames[:10])}... " f"This repo may not contain MLX-format weights." ) @@ -224,6 +222,5 @@ def test_gated_models_flagged(self, entry: CatalogEntry) -> None: """Models flagged as gated should have at least one non-mlx-community source.""" assert entry.gated, f"{entry.id}: expected gated=True" assert any( - not source.hf_repo.startswith("mlx-community/") - for source in entry.sources.values() + not source.hf_repo.startswith("mlx-community/") for source in entry.sources.values() ), f"{entry.id}: gated model should have at least one non-mlx-community source" diff --git a/tests/integration/test_harness_compatibility.py b/tests/integration/test_harness_compatibility.py index 726edf6..14d7ae6 100644 --- a/tests/integration/test_harness_compatibility.py +++ b/tests/integration/test_harness_compatibility.py @@ -99,7 +99,10 @@ def _start_stack( self._svc.__enter__() self._svc.start_vllm( - "fast", MODEL_SOURCE, self.vllm_port, timeout=120.0, + "fast", + MODEL_SOURCE, + self.vllm_port, + timeout=120.0, ) self._svc.start_litellm( self.litellm_port, @@ -125,9 +128,7 @@ def test_chat_completion(self) -> None: messages=[{"role": "user", "content": "Say hello"}], max_tokens=10, ) - assert response.choices[0].message.content, ( - "OpenAI client got empty response content" - ) + assert response.choices[0].message.content, "OpenAI client got empty response content" assert len(response.choices[0].message.content.strip()) > 0 def test_streaming(self) -> None: @@ -139,14 +140,11 @@ def test_streaming(self) -> None: stream=True, ) chunks = list(stream) - assert len(chunks) > 1, ( - f"Expected multiple SSE chunks, got {len(chunks)}" - ) + assert len(chunks) > 1, f"Expected multiple SSE chunks, got {len(chunks)}" # At least one chunk should have content has_content = any( - c.choices and c.choices[0].delta and c.choices[0].delta.content - for c in chunks + c.choices and c.choices[0].delta and c.choices[0].delta.content for c in chunks ) assert has_content, "No chunk contained content" @@ -154,9 +152,7 @@ def test_model_listing(self) -> None: """GET /v1/models returns expected tier names via OpenAI client.""" models = self.client.models.list() model_ids = [m.id for m in models.data] - assert "fast" in model_ids, ( - f"Tier 'fast' not found in model list: {model_ids}" - ) + assert "fast" in model_ids, f"Tier 'fast' not found in model list: {model_ids}" def test_tool_calling_via_client(self) -> None: """Function calling via OpenAI client returns structured tool_calls.""" diff --git a/tests/integration/test_inference_e2e.py b/tests/integration/test_inference_e2e.py index d082ae0..c26d998 100644 --- a/tests/integration/test_inference_e2e.py +++ b/tests/integration/test_inference_e2e.py @@ -35,13 +35,14 @@ from __future__ import annotations +import contextlib import os import signal import socket import subprocess import sys import time -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -136,10 +137,8 @@ def _kill_processes_on_port(port: int) -> None: for pid_str in pids: pid_str = pid_str.strip() if pid_str.isdigit(): - try: + with contextlib.suppress(OSError): os.kill(int(pid_str), signal.SIGKILL) - except OSError: - pass except (subprocess.TimeoutExpired, FileNotFoundError, OSError): pass @@ -165,7 +164,7 @@ def _build_minimal_stack_yaml() -> dict[str, Any]: "name": "default", "hardware_profile": "test", "intent": "balanced", - "created": datetime.now(timezone.utc).isoformat(), + "created": datetime.now(UTC).isoformat(), "tiers": [ { "name": TIER_NAME, @@ -286,17 +285,14 @@ def test_full_inference_lifecycle( model_id=MODEL_ID, quant=MODEL_QUANT, ) - assert pull_result.local_path.exists(), ( - f"Model not found at {pull_result.local_path}" - ) + assert pull_result.local_path.exists(), f"Model not found at {pull_result.local_path}" # ---- Step 3: Up — start real services ---- up_result = run_up() # Verify at least one tier is healthy healthy_tiers = [ - t for t in up_result.tiers - if t.status in ("healthy", "already-running") + t for t in up_result.tiers if t.status in ("healthy", "already-running") ] assert len(healthy_tiers) > 0, ( f"No healthy tiers after up. " @@ -306,8 +302,7 @@ def test_full_inference_lifecycle( # Verify LiteLLM is healthy assert up_result.litellm is not None, "LiteLLM result missing" assert up_result.litellm.status in ("healthy", "already-running"), ( - f"LiteLLM not healthy: {up_result.litellm.status} " - f"({up_result.litellm.error})" + f"LiteLLM not healthy: {up_result.litellm.status} ({up_result.litellm.error})" ) # ---- Step 4: Inference via LiteLLM proxy ---- @@ -381,10 +376,8 @@ def test_full_inference_lifecycle( finally: # ---- Cleanup: belt-and-suspenders ---- # Always attempt run_down() to clean up services - try: + with contextlib.suppress(Exception): run_down() - except Exception: # noqa: BLE001 — cleanup must not raise - pass # Kill any remaining processes on ports 8000 and 4000 _kill_processes_on_port(8000) diff --git a/tests/integration/test_launchd_e2e.py b/tests/integration/test_launchd_e2e.py index 5106e5b..f62cd9c 100644 --- a/tests/integration/test_launchd_e2e.py +++ b/tests/integration/test_launchd_e2e.py @@ -25,6 +25,7 @@ from __future__ import annotations +import contextlib import plistlib import shutil import subprocess @@ -110,16 +111,14 @@ def _force_cleanup() -> None: plist_path = get_plist_path() # Try to unload via launchctl bootout - try: + with contextlib.suppress(Exception): unload_agent(plist_path) - except Exception: # noqa: BLE001 — cleanup must not raise - pass # Remove plist file if it exists try: if plist_path.exists(): plist_path.unlink() - except Exception: # noqa: BLE001 + except Exception: pass @@ -144,9 +143,7 @@ class TestLaunchdE2ELifecycle: pytest -m integration tests/integration/test_launchd_e2e.py -v """ - def test_full_lifecycle( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_full_lifecycle(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test install → status → uninstall lifecycle with real launchctl. Steps: @@ -185,14 +182,12 @@ def test_full_lifecycle( try: # ---- Step 1: Install the agent ---- - plist_path, was_reinstall = install_agent(binary_path) + _plist_path, was_reinstall = install_agent(binary_path) assert not was_reinstall, "Expected fresh install, not reinstall" # ---- Step 2: Verify plist file exists ---- canonical_path = get_plist_path() - assert canonical_path.exists(), ( - f"Plist file not found at {canonical_path}" - ) + assert canonical_path.exists(), f"Plist file not found at {canonical_path}" # ---- Step 3: Verify plist content ---- with open(canonical_path, "rb") as f: @@ -201,32 +196,21 @@ def test_full_lifecycle( assert plist_data["Label"] == LAUNCHD_LABEL, ( f"Expected Label={LAUNCHD_LABEL!r}, got {plist_data.get('Label')!r}" ) - assert plist_data["RunAtLoad"] is True, ( - "Expected RunAtLoad=True" - ) - assert plist_data["KeepAlive"] is True, ( - "Expected KeepAlive=True" - ) + assert plist_data["RunAtLoad"] is True, "Expected RunAtLoad=True" + assert plist_data["KeepAlive"] is True, "Expected KeepAlive=True" prog_args = plist_data["ProgramArguments"] - assert isinstance(prog_args, list), ( - "ProgramArguments should be a list" - ) - assert len(prog_args) >= 2, ( - "ProgramArguments should have at least 2 elements" - ) + assert isinstance(prog_args, list), "ProgramArguments should be a list" + assert len(prog_args) >= 2, "ProgramArguments should have at least 2 elements" assert prog_args[0] == binary_path, ( f"Expected binary path {binary_path!r}, got {prog_args[0]!r}" ) - assert prog_args[1] == "watch", ( - f"Expected 'watch' subcommand, got {prog_args[1]!r}" - ) + assert prog_args[1] == "watch", f"Expected 'watch' subcommand, got {prog_args[1]!r}" # ---- Step 4: Verify file permissions ---- mode = canonical_path.stat().st_mode & 0o777 assert mode == PLIST_PERMISSIONS, ( - f"Expected permissions {oct(PLIST_PERMISSIONS)}, " - f"got {oct(mode)}" + f"Expected permissions {oct(PLIST_PERMISSIONS)}, got {oct(mode)}" ) # ---- Step 5: Verify agent status reports installed ---- diff --git a/tests/integration/test_model_smoke.py b/tests/integration/test_model_smoke.py index c549908..51c0f21 100644 --- a/tests/integration/test_model_smoke.py +++ b/tests/integration/test_model_smoke.py @@ -156,9 +156,7 @@ def test_basic_inference( try: # Pull model (cached across runs) pull_result = pull_model(model_id=entry.id, quant=SMOKE_QUANT) - assert pull_result.local_path.exists(), ( - f"Model not found at {pull_result.local_path}" - ) + assert pull_result.local_path.exists(), f"Model not found at {pull_result.local_path}" # Start vllm-mlx start_time = time.monotonic() @@ -205,8 +203,7 @@ def test_basic_inference( result.inference_time_s = time.monotonic() - inference_start assert response.status_code == 200, ( - f"{entry.id}: inference returned {response.status_code}: " - f"{response.text[:300]}" + f"{entry.id}: inference returned {response.status_code}: {response.text[:300]}" ) data = response.json() diff --git a/tests/integration/test_stack_integration.py b/tests/integration/test_stack_integration.py index 9230807..eb15604 100644 --- a/tests/integration/test_stack_integration.py +++ b/tests/integration/test_stack_integration.py @@ -116,7 +116,8 @@ def test_full_lifecycle( ) data = response.json() content = data["choices"][0]["message"]["content"] - assert content and len(content.strip()) > 0 + assert content + assert len(content.strip()) > 0 # Inference directly to vllm-mlx response = httpx.post( @@ -142,9 +143,7 @@ def test_full_lifecycle( # Verify no PID files remain pids_dir = integration_home / "pids" remaining = list(pids_dir.glob("*.pid")) - assert len(remaining) == 0, ( - f"PID files remain: {[p.name for p in remaining]}" - ) + assert len(remaining) == 0, f"PID files remain: {[p.name for p in remaining]}" def test_litellm_routing( self, @@ -249,9 +248,7 @@ def test_model_listing_via_litellm( assert response.status_code == 200 data = response.json() model_ids = [m["id"] for m in data.get("data", [])] - assert "fast" in model_ids, ( - f"Tier 'fast' not in model list: {model_ids}" - ) + assert "fast" in model_ids, f"Tier 'fast' not in model list: {model_ids}" def test_concurrent_requests( self, @@ -299,9 +296,7 @@ async def run_concurrent() -> list[int]: return await asyncio.gather(*tasks) results = asyncio.run(run_concurrent()) - assert all(s == 200 for s in results), ( - f"Some concurrent requests failed: {results}" - ) + assert all(s == 200 for s in results), f"Some concurrent requests failed: {results}" def test_clean_shutdown_no_orphans( self, @@ -336,11 +331,7 @@ def test_clean_shutdown_no_orphans( # Verify: no PID files pids_dir = integration_home / "pids" remaining = list(pids_dir.glob("*.pid")) - assert len(remaining) == 0, ( - f"PID files remain: {[p.name for p in remaining]}" - ) + assert len(remaining) == 0, f"PID files remain: {[p.name for p in remaining]}" # Verify: socket bind succeeds (port truly free) - assert not is_port_in_use(vllm_port), ( - f"Port {vllm_port} still in use after cleanup" - ) + assert not is_port_in_use(vllm_port), f"Port {vllm_port} still in use after cleanup" diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 6bf87b7..4262dd6 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -57,7 +57,7 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def sample_entry() -> CatalogEntry: """A sample catalog entry for testing.""" return CatalogEntry( @@ -80,18 +80,14 @@ def sample_entry() -> CatalogEntry: ), quality=QualityScores(overall=68, coding=65, reasoning=62, instruction_following=72), benchmarks={ - "m5-max-128": CatalogBenchmarkResult( - prompt_tps=155.0, gen_tps=85.0, memory_gb=5.5 - ), - "m4-pro-48": CatalogBenchmarkResult( - prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5 - ), + "m5-max-128": CatalogBenchmarkResult(prompt_tps=155.0, gen_tps=85.0, memory_gb=5.5), + "m4-pro-48": CatalogBenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), }, tags=["balanced", "agent-ready"], ) -@pytest.fixture() +@pytest.fixture def sample_entry_no_tool_calling() -> CatalogEntry: """A sample catalog entry without tool-calling.""" return CatalogEntry( @@ -113,15 +109,13 @@ def sample_entry_no_tool_calling() -> CatalogEntry: ), quality=QualityScores(overall=65, coding=60, reasoning=70, instruction_following=60), benchmarks={ - "m5-max-128": CatalogBenchmarkResult( - prompt_tps=150.0, gen_tps=80.0, memory_gb=5.0 - ), + "m5-max-128": CatalogBenchmarkResult(prompt_tps=150.0, gen_tps=80.0, memory_gb=5.0), }, tags=["reasoning"], ) -@pytest.fixture() +@pytest.fixture def sample_profile() -> HardwareProfile: """A sample hardware profile for testing.""" return HardwareProfile( @@ -287,9 +281,7 @@ def test_fail_when_far_below( for cls in classifications: assert cls.classification == CLASSIFICATION_FAIL - def test_no_matching_profile_returns_empty( - self, sample_entry: CatalogEntry - ) -> None: + def test_no_matching_profile_returns_empty(self, sample_entry: CatalogEntry) -> None: unknown_profile = HardwareProfile( chip="Apple M99", gpu_cores=100, @@ -580,8 +572,7 @@ def test_timeout_raises(self, mock_stream: MagicMock) -> None: def test_no_content_returns_zero_tps(self, mock_stream: MagicMock) -> None: """When no content chunks arrive, TPS should be zero.""" sse_lines = [ - 'data: {"choices":[{"delta":{}}],' - '"usage":{"prompt_tokens":0,"completion_tokens":0}}', + 'data: {"choices":[{"delta":{}}],"usage":{"prompt_tokens":0,"completion_tokens":0}}', "data: [DONE]", ] mock_response = MagicMock() @@ -787,9 +778,7 @@ def test_no_tool_calls_in_response(self, mock_post: MagicMock) -> None: mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json.return_value = { - "choices": [{"message": {"content": "It's sunny"}}] - } + mock_response.json.return_value = {"choices": [{"message": {"content": "It's sunny"}}]} mock_post.return_value = mock_response result = _run_tool_call_benchmark(port=8000, model_name="test") @@ -991,9 +980,7 @@ def test_no_reasoning_parser_for_non_thinking_models( architecture="transformer", min_mlx_lm_version="0.22.0", sources={ - "int4": QuantSource( - hf_repo="mlx-community/test-4bit", disk_size_gb=4.5 - ), + "int4": QuantSource(hf_repo="mlx-community/test-4bit", disk_size_gb=4.5), }, capabilities=Capabilities( tool_calling=False, @@ -1002,9 +989,7 @@ def test_no_reasoning_parser_for_non_thinking_models( reasoning_parser=None, vision=False, ), - quality=QualityScores( - overall=50, coding=50, reasoning=50, instruction_following=50 - ), + quality=QualityScores(overall=50, coding=50, reasoning=50, instruction_following=50), benchmarks={}, tags=[], ) @@ -1037,9 +1022,7 @@ def test_thinking_without_reasoning_parser_no_flag( architecture="transformer", min_mlx_lm_version="0.22.0", sources={ - "int4": QuantSource( - hf_repo="mlx-community/test-4bit", disk_size_gb=4.5 - ), + "int4": QuantSource(hf_repo="mlx-community/test-4bit", disk_size_gb=4.5), }, capabilities=Capabilities( tool_calling=False, @@ -1048,9 +1031,7 @@ def test_thinking_without_reasoning_parser_no_flag( reasoning_parser=None, # thinking=True but no parser vision=False, ), - quality=QualityScores( - overall=50, coding=50, reasoning=50, instruction_following=50 - ), + quality=QualityScores(overall=50, coding=50, reasoning=50, instruction_following=50), benchmarks={}, tags=[], ) @@ -1195,16 +1176,25 @@ def test_successful_benchmark_running_tier( ) mock_iterations.return_value = [ IterationResult( - prompt_tps=150.0, gen_tps=80.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=150.0, + gen_tps=80.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), IterationResult( - prompt_tps=160.0, gen_tps=85.0, - prompt_tokens=1000, completion_tokens=100, total_time=9.5, + prompt_tps=160.0, + gen_tps=85.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=9.5, ), IterationResult( - prompt_tps=155.0, gen_tps=82.0, - prompt_tokens=1000, completion_tokens=100, total_time=9.8, + prompt_tps=155.0, + gen_tps=82.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=9.8, ), ] mock_profile.return_value = sample_profile @@ -1244,8 +1234,11 @@ def test_benchmark_with_no_profile_detects_hardware( min_mlx_lm_version="0.22.0", sources={"int4": QuantSource(hf_repo="test/test", disk_size_gb=4.0)}, capabilities=Capabilities( - tool_calling=False, tool_call_parser=None, - thinking=False, reasoning_parser=None, vision=False, + tool_calling=False, + tool_call_parser=None, + thinking=False, + reasoning_parser=None, + vision=False, ), quality=QualityScores(overall=50, coding=50, reasoning=50, instruction_following=50), benchmarks={}, @@ -1261,8 +1254,11 @@ def test_benchmark_with_no_profile_detects_hardware( ) mock_iterations.return_value = [ IterationResult( - prompt_tps=100.0, gen_tps=50.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=100.0, + gen_tps=50.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ] mock_detect.return_value = sample_profile @@ -1329,8 +1325,11 @@ def test_save_flag_persists_results( min_mlx_lm_version="0.22.0", sources={"int4": QuantSource(hf_repo="test/test", disk_size_gb=4.0)}, capabilities=Capabilities( - tool_calling=False, tool_call_parser=None, - thinking=False, reasoning_parser=None, vision=False, + tool_calling=False, + tool_call_parser=None, + thinking=False, + reasoning_parser=None, + vision=False, ), quality=QualityScores(overall=50, coding=50, reasoning=50, instruction_following=50), benchmarks={}, @@ -1346,8 +1345,11 @@ def test_save_flag_persists_results( ) mock_iterations.return_value = [ IterationResult( - prompt_tps=100.0, gen_tps=50.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=100.0, + gen_tps=50.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ] mock_profile.return_value = sample_profile diff --git a/tests/unit/test_catalog.py b/tests/unit/test_catalog.py index aa173d4..df8c682 100644 --- a/tests/unit/test_catalog.py +++ b/tests/unit/test_catalog.py @@ -33,13 +33,13 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def catalog() -> list[CatalogEntry]: """Load the full shipped catalog.""" return load_catalog() -@pytest.fixture() +@pytest.fixture def sample_yaml_dir(tmp_path: Path) -> Path: """Create a temporary directory with valid catalog YAML files for testing.""" catalog_dir = tmp_path / "catalog" @@ -241,7 +241,12 @@ def test_six_distinct_family_values_in_yaml(self, catalog: list[CatalogEntry]) - """Catalog YAML files use 6 distinct family values.""" families = {e.family for e in catalog} expected_families = { - "Qwen 3.5", "Nemotron", "Gemma 3", "DeepSeek R1", "Qwen 3", "Llama 3.3", + "Qwen 3.5", + "Nemotron", + "Gemma 3", + "DeepSeek R1", + "Qwen 3", + "Llama 3.3", } assert families == expected_families @@ -306,7 +311,7 @@ def test_all_entries_have_benchmarks(self, catalog: list[CatalogEntry]) -> None: """Every entry has at least one benchmark entry.""" for entry in catalog: assert len(entry.benchmarks) > 0, f"{entry.id}: no benchmarks" - for hw_key, bench in entry.benchmarks.items(): + for bench in entry.benchmarks.values(): assert bench.prompt_tps > 0 assert bench.gen_tps > 0 assert bench.memory_gb > 0 @@ -439,9 +444,7 @@ def test_missing_required_field(self, tmp_path: Path) -> None: """Missing required field raises CatalogError.""" catalog_dir = tmp_path / "catalog" catalog_dir.mkdir() - (catalog_dir / "incomplete.yaml").write_text( - yaml.dump({"id": "test", "name": "Test"}) - ) + (catalog_dir / "incomplete.yaml").write_text(yaml.dump({"id": "test", "name": "Test"})) with pytest.raises(CatalogError, match="missing required field"): load_catalog_from_directory(str(catalog_dir)) @@ -512,7 +515,7 @@ def test_missing_benchmark_field(self, tmp_path: Path) -> None: data = _make_valid_entry() data["benchmarks"]["m4-pro-48"] = {"prompt_tps": 100.0} # missing gen_tps, memory_gb (catalog_dir / "no_gen_tps.yaml").write_text(yaml.dump(data)) - with pytest.raises(CatalogError, match="benchmark.*missing required field"): + with pytest.raises(CatalogError, match=r"benchmark.*missing required field"): load_catalog_from_directory(str(catalog_dir)) def test_non_string_tag(self, tmp_path: Path) -> None: @@ -530,7 +533,7 @@ def test_error_identifies_filename(self, tmp_path: Path) -> None: catalog_dir = tmp_path / "catalog" catalog_dir.mkdir() (catalog_dir / "specific-file.yaml").write_text(yaml.dump({"id": "x"})) - with pytest.raises(CatalogError, match="specific-file.yaml"): + with pytest.raises(CatalogError, match=r"specific-file\.yaml"): load_catalog_from_directory(str(catalog_dir)) def test_non_numeric_disk_size_gb(self, tmp_path: Path) -> None: @@ -540,7 +543,7 @@ def test_non_numeric_disk_size_gb(self, tmp_path: Path) -> None: data = _make_valid_entry() data["sources"]["int4"]["disk_size_gb"] = "abc" (catalog_dir / "bad_disk_size.yaml").write_text(yaml.dump(data)) - with pytest.raises(CatalogError, match="disk_size_gb.*must be numeric"): + with pytest.raises(CatalogError, match=r"disk_size_gb.*must be numeric"): load_catalog_from_directory(str(catalog_dir)) def test_non_numeric_quality_score(self, tmp_path: Path) -> None: @@ -550,7 +553,7 @@ def test_non_numeric_quality_score(self, tmp_path: Path) -> None: data = _make_valid_entry() data["quality"]["overall"] = "high" (catalog_dir / "bad_quality.yaml").write_text(yaml.dump(data)) - with pytest.raises(CatalogError, match="quality.*overall.*must be numeric"): + with pytest.raises(CatalogError, match=r"quality.*overall.*must be numeric"): load_catalog_from_directory(str(catalog_dir)) def test_non_numeric_benchmark_value(self, tmp_path: Path) -> None: @@ -560,7 +563,7 @@ def test_non_numeric_benchmark_value(self, tmp_path: Path) -> None: data = _make_valid_entry() data["benchmarks"]["m4-pro-48"]["gen_tps"] = "fast" (catalog_dir / "bad_bench.yaml").write_text(yaml.dump(data)) - with pytest.raises(CatalogError, match="benchmark.*gen_tps.*must be numeric"): + with pytest.raises(CatalogError, match=r"benchmark.*gen_tps.*must be numeric"): load_catalog_from_directory(str(catalog_dir)) def test_corrupted_disk_size_no_raw_valueerror(self, tmp_path: Path) -> None: @@ -830,9 +833,18 @@ def test_shipped_catalog_gated_models(self) -> None: gated = [e for e in catalog if e.gated] gated_ids = {e.id for e in gated} assert gated_ids == { - "deepseek-r1-32b", "gemma3-4b", "gemma3-12b", "gemma3-27b", - "llama3.3-8b", "nemotron-49b", "nemotron-8b", - "qwen3.5-3b", "qwen3.5-8b", "qwen3.5-14b", "qwen3.5-32b", "qwen3.5-72b", + "deepseek-r1-32b", + "gemma3-4b", + "gemma3-12b", + "gemma3-27b", + "llama3.3-8b", + "nemotron-49b", + "nemotron-8b", + "qwen3.5-3b", + "qwen3.5-8b", + "qwen3.5-14b", + "qwen3.5-32b", + "qwen3.5-72b", } def test_shipped_catalog_non_gated_models(self) -> None: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index fd240e3..de0274b 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -114,7 +114,8 @@ class TestDataHomeAutoCreation: """Tests for data-home gating: only state-writing commands create it.""" def test_bare_command_does_not_create_data_home( - self, clean_mlx_stack_home: Path, + self, + clean_mlx_stack_home: Path, ) -> None: """Running bare mlx-stack (help) does NOT create the data directory.""" assert not clean_mlx_stack_home.exists() @@ -124,7 +125,8 @@ def test_bare_command_does_not_create_data_home( assert not clean_mlx_stack_home.exists() def test_readonly_subcommand_does_not_create_data_home( - self, clean_mlx_stack_home: Path, + self, + clean_mlx_stack_home: Path, ) -> None: """Read-only subcommands do NOT create the data directory.""" assert not clean_mlx_stack_home.exists() @@ -133,7 +135,8 @@ def test_readonly_subcommand_does_not_create_data_home( assert not clean_mlx_stack_home.exists() def test_state_writing_subcommand_creates_data_home( - self, clean_mlx_stack_home: Path, + self, + clean_mlx_stack_home: Path, ) -> None: """State-writing subcommands (config set) auto-create the data directory.""" assert not clean_mlx_stack_home.exists() @@ -143,7 +146,8 @@ def test_state_writing_subcommand_creates_data_home( assert clean_mlx_stack_home.is_dir() def test_existing_data_home_not_affected( - self, mlx_stack_home: Path, + self, + mlx_stack_home: Path, ) -> None: """Running CLI when data home already exists doesn't cause errors.""" assert mlx_stack_home.exists() diff --git a/tests/unit/test_cli_bench.py b/tests/unit/test_cli_bench.py index 4a157cc..978256f 100644 --- a/tests/unit/test_cli_bench.py +++ b/tests/unit/test_cli_bench.py @@ -34,13 +34,13 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def runner() -> CliRunner: """Create a Click test runner.""" return CliRunner() -@pytest.fixture() +@pytest.fixture def sample_result() -> BenchmarkResult_: """A sample successful benchmark result.""" return BenchmarkResult_( @@ -48,16 +48,25 @@ def sample_result() -> BenchmarkResult_: quant="int4", iterations=[ IterationResult( - prompt_tps=150.0, gen_tps=80.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=150.0, + gen_tps=80.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), IterationResult( - prompt_tps=160.0, gen_tps=85.0, - prompt_tokens=1000, completion_tokens=100, total_time=9.5, + prompt_tps=160.0, + gen_tps=85.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=9.5, ), IterationResult( - prompt_tps=155.0, gen_tps=82.0, - prompt_tokens=1000, completion_tokens=100, total_time=9.8, + prompt_tps=155.0, + gen_tps=82.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=9.8, ), ], prompt_tps_mean=155.0, @@ -86,7 +95,7 @@ def sample_result() -> BenchmarkResult_: ) -@pytest.fixture() +@pytest.fixture def sample_result_no_catalog() -> BenchmarkResult_: """A sample result with no catalog data for comparison.""" return BenchmarkResult_( @@ -94,8 +103,11 @@ def sample_result_no_catalog() -> BenchmarkResult_: quant="int4", iterations=[ IterationResult( - prompt_tps=150.0, gen_tps=80.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=150.0, + gen_tps=80.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=150.0, @@ -402,8 +414,11 @@ def test_warn_classification_displayed( quant="int4", iterations=[ IterationResult( - prompt_tps=80.0, gen_tps=60.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=80.0, + gen_tps=60.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=80.0, @@ -438,8 +453,11 @@ def test_fail_classification_displayed( quant="int4", iterations=[ IterationResult( - prompt_tps=50.0, gen_tps=30.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=50.0, + gen_tps=30.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=50.0, @@ -490,8 +508,11 @@ def test_below_catalog_shows_positive_delta( quant="int4", iterations=[ IterationResult( - prompt_tps=80.0, gen_tps=60.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=80.0, + gen_tps=60.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=80.0, @@ -527,8 +548,11 @@ def test_above_catalog_shows_negative_delta( quant="int4", iterations=[ IterationResult( - prompt_tps=100.0, gen_tps=90.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=100.0, + gen_tps=90.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=100.0, @@ -564,8 +588,11 @@ def test_exact_match_shows_zero_delta( quant="int4", iterations=[ IterationResult( - prompt_tps=155.0, gen_tps=85.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=155.0, + gen_tps=85.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=155.0, @@ -609,8 +636,11 @@ def test_failed_tool_call_displayed( quant="int4", iterations=[ IterationResult( - prompt_tps=100.0, gen_tps=50.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=100.0, + gen_tps=50.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=100.0, @@ -641,8 +671,11 @@ def test_no_tool_calling_skip_message( quant="int4", iterations=[ IterationResult( - prompt_tps=100.0, gen_tps=50.0, - prompt_tokens=1000, completion_tokens=100, total_time=10.0, + prompt_tps=100.0, + gen_tps=50.0, + prompt_tokens=1000, + completion_tokens=100, + total_time=10.0, ), ], prompt_tps_mean=100.0, diff --git a/tests/unit/test_cli_down.py b/tests/unit/test_cli_down.py index 0bd30cb..e98cc02 100644 --- a/tests/unit/test_cli_down.py +++ b/tests/unit/test_cli_down.py @@ -295,7 +295,10 @@ def test_lockfile_acquired_during_shutdown(self, mlx_stack_home: Path) -> None: patch( "mlx_stack.core.stack_down._stop_single_service", return_value=ServiceStopResult( - name="litellm", pid=1001, status="stopped", graceful=True, + name="litellm", + pid=1001, + status="stopped", + graceful=True, ), ), ): @@ -310,12 +313,14 @@ def test_lockfile_error_propagated(self, mlx_stack_home: Path) -> None: """VAL-DOWN-003: Lock conflict raises LockError.""" _create_pid_file(mlx_stack_home, "fast", 1001) - with patch( - "mlx_stack.core.stack_down.acquire_lock", - side_effect=LockError("Another operation is running"), + with ( + patch( + "mlx_stack.core.stack_down.acquire_lock", + side_effect=LockError("Another operation is running"), + ), + pytest.raises(LockError, match="Another operation"), ): - with pytest.raises(LockError, match="Another operation"): - run_down() + run_down() def test_tier_filter_stops_only_specified_tier(self, mlx_stack_home: Path) -> None: """VAL-DOWN-004: --tier stops only the specified tier.""" @@ -443,14 +448,21 @@ def test_mixed_stale_and_running_services(self, mlx_stack_home: Path) -> None: results_map = { "litellm": ServiceStopResult( - name="litellm", pid=1003, status="stopped", graceful=True, + name="litellm", + pid=1003, + status="stopped", + graceful=True, ), "standard": ServiceStopResult( - name="standard", pid=1001, status="stale", + name="standard", + pid=1001, + status="stale", error="Process 1001 already dead; cleaned up stale PID file.", ), "fast": ServiceStopResult( - name="fast", pid=None, status="corrupt", + name="fast", + pid=None, + status="corrupt", error="PID file contained non-numeric content; removed.", ), } @@ -536,7 +548,10 @@ def test_no_stack_definition_falls_back(self, mlx_stack_home: Path) -> None: patch( "mlx_stack.core.stack_down._stop_single_service", return_value=ServiceStopResult( - name="fast", pid=12345, status="stopped", graceful=True, + name="fast", + pid=12345, + status="stopped", + graceful=True, ), ), ): @@ -579,13 +594,22 @@ def test_shutdown_displays_summary_table(self, mlx_stack_home: Path) -> None: mock_result = DownResult( services=[ ServiceStopResult( - name="litellm", pid=1003, status="stopped", graceful=True, + name="litellm", + pid=1003, + status="stopped", + graceful=True, ), ServiceStopResult( - name="fast", pid=1002, status="stopped", graceful=True, + name="fast", + pid=1002, + status="stopped", + graceful=True, ), ServiceStopResult( - name="standard", pid=1001, status="stopped", graceful=False, + name="standard", + pid=1001, + status="stopped", + graceful=False, ), ], ) @@ -605,10 +629,16 @@ def test_shutdown_shows_graceful_method(self, mlx_stack_home: Path) -> None: mock_result = DownResult( services=[ ServiceStopResult( - name="fast", pid=1002, status="stopped", graceful=True, + name="fast", + pid=1002, + status="stopped", + graceful=True, ), ServiceStopResult( - name="standard", pid=1001, status="stopped", graceful=False, + name="standard", + pid=1001, + status="stopped", + graceful=False, ), ], ) @@ -627,7 +657,10 @@ def test_forced_sigkill_explicit_in_output(self, mlx_stack_home: Path) -> None: mock_result = DownResult( services=[ ServiceStopResult( - name="standard", pid=1001, status="stopped", graceful=False, + name="standard", + pid=1001, + status="stopped", + graceful=False, ), ], ) @@ -647,7 +680,10 @@ def test_graceful_sigterm_explicit_in_output(self, mlx_stack_home: Path) -> None mock_result = DownResult( services=[ ServiceStopResult( - name="fast", pid=1002, status="stopped", graceful=True, + name="fast", + pid=1002, + status="stopped", + graceful=True, ), ], ) @@ -664,7 +700,10 @@ def test_tier_filter_option(self, mlx_stack_home: Path) -> None: mock_result = DownResult( services=[ ServiceStopResult( - name="fast", pid=1002, status="stopped", graceful=True, + name="fast", + pid=1002, + status="stopped", + graceful=True, ), ], ) @@ -801,11 +840,11 @@ def test_full_lifecycle_cleanup(self, mlx_stack_home: Path) -> None: proc_mock = MagicMock() proc_mock.status.side_effect = [ "running", # alive check for litellm - "zombie", # dead after SIGTERM for litellm + "zombie", # dead after SIGTERM for litellm "running", # alive check for tier - "zombie", # dead after SIGTERM for tier + "zombie", # dead after SIGTERM for tier "running", # alive check for tier - "zombie", # dead after SIGTERM for tier + "zombie", # dead after SIGTERM for tier ] mock_psutil.Process.return_value = proc_mock mock_psutil.STATUS_ZOMBIE = "zombie" @@ -901,7 +940,10 @@ def test_litellm_stopped_before_model_servers(self, mlx_stack_home: Path) -> Non def mock_stop(name: str) -> ServiceStopResult: order.append(name) return ServiceStopResult( - name=name, pid=1000, status="stopped", graceful=True, + name=name, + pid=1000, + status="stopped", + graceful=True, ) with ( @@ -956,7 +998,10 @@ def test_model_servers_reversed(self, mlx_stack_home: Path) -> None: def mock_stop(name: str) -> ServiceStopResult: order.append(name) return ServiceStopResult( - name=name, pid=1000, status="stopped", graceful=True, + name=name, + pid=1000, + status="stopped", + graceful=True, ) with ( diff --git a/tests/unit/test_cli_init.py b/tests/unit/test_cli_init.py index d442bbf..ce9b69f 100644 --- a/tests/unit/test_cli_init.py +++ b/tests/unit/test_cli_init.py @@ -324,8 +324,10 @@ def test_schema_version_is_1(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert result["stack"]["schema_version"] == 1 @@ -336,8 +338,10 @@ def test_hardware_profile_matches(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert result["stack"]["hardware_profile"] == profile.profile_id @@ -348,8 +352,10 @@ def test_intent_matches(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="agent-fleet", force=True) assert result["stack"]["intent"] == "agent-fleet" @@ -360,8 +366,10 @@ def test_name_is_default(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert result["stack"]["name"] == "default" @@ -372,8 +380,10 @@ def test_created_timestamp_is_iso8601(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) created = result["stack"]["created"] @@ -387,8 +397,10 @@ def test_tiers_have_required_fields(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) for tier in result["stack"]["tiers"]: @@ -405,8 +417,10 @@ def test_tier_ports_are_unique(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) ports = [t["port"] for t in result["stack"]["tiers"]] @@ -418,8 +432,10 @@ def test_tier_ports_dont_conflict_with_litellm(self, mlx_stack_home: Path) -> No catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) ports = {t["port"] for t in result["stack"]["tiers"]} @@ -440,8 +456,10 @@ def test_stack_yaml_written(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) stack_path = Path(result["stack_path"]) @@ -457,8 +475,10 @@ def test_litellm_yaml_written(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) litellm_path = Path(result["litellm_path"]) @@ -472,8 +492,10 @@ def test_directory_auto_created(self, clean_mlx_stack_home: Path) -> None: profile = _make_profile() catalog = _make_test_catalog() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert Path(result["stack_path"]).exists() @@ -494,8 +516,10 @@ def test_model_list_has_correct_count(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) num_tiers = len(result["stack"]["tiers"]) @@ -508,20 +532,21 @@ def test_api_base_matches_tier_port(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) tiers = result["stack"]["tiers"] model_list = result["litellm_config"]["model_list"] for tier in tiers: - matching = [ - m for m in model_list - if m["model_name"] == tier["name"] - ] + matching = [m for m in model_list if m["model_name"] == tier["name"]] assert len(matching) == 1 - assert matching[0]["litellm_params"]["api_base"] == f"http://localhost:{tier['port']}/v1" + assert ( + matching[0]["litellm_params"]["api_base"] == f"http://localhost:{tier['port']}/v1" + ) def test_model_uses_openai_prefix(self, mlx_stack_home: Path) -> None: """Model identifiers use openai/ prefix.""" @@ -529,8 +554,10 @@ def test_model_uses_openai_prefix(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) for entry in result["litellm_config"]["model_list"]: @@ -542,8 +569,10 @@ def test_api_key_is_dummy(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) for entry in result["litellm_config"]["model_list"]: @@ -555,8 +584,10 @@ def test_router_settings_present(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) rs = result["litellm_config"]["router_settings"] @@ -570,8 +601,10 @@ def test_fallback_chain_present(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) litellm = result["litellm_config"] @@ -598,9 +631,12 @@ def test_cloud_fallback_with_key(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile), \ - patch("mlx_stack.core.stack_init.get_value") as mock_get: + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + patch("mlx_stack.core.stack_init.get_value") as mock_get, + ): + def config_side_effect(key: str): if key == "openrouter-key": return "sk-or-test123" @@ -611,6 +647,7 @@ def config_side_effect(key: str): if key == "model-dir": return str(mlx_stack_home / "models") return "" + mock_get.side_effect = config_side_effect result = run_init(intent="balanced", force=True) @@ -620,8 +657,7 @@ def config_side_effect(key: str): # LiteLLM config should have premium entries premium = [ - e for e in result["litellm_config"]["model_list"] - if e["model_name"] == "premium" + e for e in result["litellm_config"]["model_list"] if e["model_name"] == "premium" ] assert len(premium) > 0 @@ -631,15 +667,16 @@ def test_no_cloud_without_key(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert "cloud_fallback" not in result["stack"] premium = [ - e for e in result["litellm_config"]["model_list"] - if e["model_name"] == "premium" + e for e in result["litellm_config"]["model_list"] if e["model_name"] == "premium" ] assert len(premium) == 0 @@ -659,13 +696,17 @@ def test_overwrite_blocked_without_force(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) # Create initial stack - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): run_init(intent="balanced", force=True) # Try again without force - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): with pytest.raises(InitError, match="already exists"): run_init(intent="balanced", force=False) @@ -675,13 +716,17 @@ def test_overwrite_allowed_with_force(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): run_init(intent="balanced", force=True) # Overwrite with force - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert result["stack"]["schema_version"] == 1 @@ -701,8 +746,10 @@ def test_remove_tier(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init( intent="balanced", remove_tiers=["fast"], @@ -718,8 +765,10 @@ def test_remove_invalid_tier_errors(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): with pytest.raises(InitError, match="Cannot remove tier"): run_init( intent="balanced", @@ -733,8 +782,10 @@ def test_add_model(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init( intent="balanced", add_models=["medium-model"], @@ -750,8 +801,10 @@ def test_add_unknown_model_errors(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): with pytest.raises(InitError, match="Unknown model"): run_init( intent="balanced", @@ -765,8 +818,10 @@ def test_invalid_intent_errors(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): with pytest.raises(InitError, match="Invalid intent"): run_init(intent="invalid", force=True) @@ -846,8 +901,10 @@ def test_accept_defaults_completes(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -859,8 +916,10 @@ def test_accept_defaults_with_intent(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults", "--intent", "agent-fleet"]) assert result.exit_code == 0 @@ -872,8 +931,10 @@ def test_overwrite_without_force_exits_error(self, mlx_stack_home: Path) -> None _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): # First init result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -890,8 +951,10 @@ def test_force_allows_overwrite(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -905,8 +968,10 @@ def test_output_shows_file_paths(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert "default.yaml" in result.output @@ -919,8 +984,10 @@ def test_output_shows_tier_assignments(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert "standard" in result.output or "fast" in result.output @@ -932,8 +999,10 @@ def test_output_shows_next_steps(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) # Should mention next steps @@ -946,8 +1015,10 @@ def test_missing_models_shows_pull_suggestion(self, mlx_stack_home: Path) -> Non _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) # Models are not downloaded, so should suggest pulling @@ -960,8 +1031,10 @@ def test_generated_stack_yaml_is_valid(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -982,8 +1055,10 @@ def test_generated_litellm_yaml_is_valid(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -1001,11 +1076,11 @@ def test_add_option_works(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): - result = runner.invoke( - cli, ["init", "--accept-defaults", "--add", "medium-model"] - ) + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): + result = runner.invoke(cli, ["init", "--accept-defaults", "--add", "medium-model"]) assert result.exit_code == 0 @@ -1016,11 +1091,11 @@ def test_remove_option_works(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): - result = runner.invoke( - cli, ["init", "--accept-defaults", "--remove", "fast"] - ) + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): + result = runner.invoke(cli, ["init", "--accept-defaults", "--remove", "fast"]) assert result.exit_code == 0 @@ -1032,8 +1107,10 @@ def test_different_intents_produce_different_stacks(self, mlx_stack_home: Path) results = {} for intent_name in ["balanced", "agent-fleet"]: - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent=intent_name, force=True) results[intent_name] = result["stack"] @@ -1047,8 +1124,10 @@ def test_vllm_flags_in_generated_stack(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) for tier in result["stack"]["tiers"]: @@ -1067,6 +1146,7 @@ class TestPortInUseDetection: def test_skips_port_in_use(self) -> None: """Ports detected as in-use are skipped, next available selected.""" + # Mock _is_port_available: 8000 is in use, 8001 is free def mock_available(port: int) -> bool: return port != 8000 @@ -1128,9 +1208,11 @@ def test_port_detection_in_full_init(self, mlx_stack_home: Path) -> None: def mock_available(port: int) -> bool: return port != 8000 - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile), \ - patch("mlx_stack.core.stack_init._is_port_available", side_effect=mock_available): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + patch("mlx_stack.core.stack_init._is_port_available", side_effect=mock_available), + ): result = run_init(intent="balanced", force=True) tier_ports = [t["port"] for t in result["stack"]["tiers"]] @@ -1153,8 +1235,10 @@ def test_total_memory_in_result(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) assert "total_memory_gb" in result @@ -1169,8 +1253,10 @@ def test_total_memory_displayed_in_summary(self, mlx_stack_home: Path) -> None: _write_profile(mlx_stack_home, profile) runner = CliRunner() - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = runner.invoke(cli, ["init", "--accept-defaults"]) assert result.exit_code == 0 @@ -1184,8 +1270,10 @@ def test_total_memory_sum_is_correct(self, mlx_stack_home: Path) -> None: catalog = _make_test_catalog() _write_profile(mlx_stack_home, profile) - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) # The total should be positive. Note: individual models fit within budget, @@ -1225,8 +1313,10 @@ def test_init_excludes_gated_models(self, mlx_stack_home: Path) -> None: ), ] - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init(intent="balanced", force=True) tier_model_ids = {t["model"] for t in result["stack"]["tiers"]} @@ -1243,8 +1333,10 @@ def test_add_gated_model_warns(self, mlx_stack_home: Path) -> None: _make_entry(model_id="gated-model", name="Gated Model", gated=True), ] - with patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), \ - patch("mlx_stack.core.stack_init.load_profile", return_value=profile): + with ( + patch("mlx_stack.core.stack_init.load_catalog", return_value=catalog), + patch("mlx_stack.core.stack_init.load_profile", return_value=profile), + ): result = run_init( intent="balanced", add_models=["gated-model"], diff --git a/tests/unit/test_cli_install.py b/tests/unit/test_cli_install.py index c3be7d9..5ff33b6 100644 --- a/tests/unit/test_cli_install.py +++ b/tests/unit/test_cli_install.py @@ -27,13 +27,13 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def runner() -> CliRunner: """Create a Click CliRunner.""" return CliRunner() -@pytest.fixture() +@pytest.fixture def stack_definition(mlx_stack_home: Path) -> dict: """Create a test stack definition.""" stacks_dir = mlx_stack_home / "stacks" @@ -115,9 +115,7 @@ def test_install_in_lifecycle_category(self, runner: CliRunner) -> None: class TestInstallStatus: """Tests for install --status flag.""" - def test_status_not_installed( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_status_not_installed(self, runner: CliRunner, mlx_stack_home: Path) -> None: status = AgentStatus(installed=False, running=False, pid=None) with patch("mlx_stack.cli.install.get_agent_status", return_value=status): result = runner.invoke(cli, ["install", "--status"]) @@ -125,9 +123,7 @@ def test_status_not_installed( assert result.exit_code == 0 assert "not installed" in result.output - def test_status_installed_and_running( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_status_installed_and_running(self, runner: CliRunner, mlx_stack_home: Path) -> None: status = AgentStatus(installed=True, running=True, pid=12345) with patch("mlx_stack.cli.install.get_agent_status", return_value=status): result = runner.invoke(cli, ["install", "--status"]) @@ -146,9 +142,7 @@ def test_status_installed_but_not_running( assert result.exit_code == 0 assert "installed but not running" in result.output - def test_status_platform_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_status_platform_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.check_platform", side_effect=PlatformError("only available on macOS"), @@ -225,9 +219,7 @@ def test_mentions_auto_start( class TestInstallErrors: """Tests for install command error handling.""" - def test_platform_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_platform_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.install_agent", side_effect=PlatformError("only available on macOS"), @@ -238,9 +230,7 @@ def test_platform_error( assert "Error" in result.output assert "macOS" in result.output - def test_prerequisite_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_prerequisite_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.install_agent", side_effect=PrerequisiteError( @@ -253,9 +243,7 @@ def test_prerequisite_error( assert "Error" in result.output assert "init" in result.output.lower() - def test_launchd_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_launchd_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.install_agent", side_effect=LaunchdError("launchctl bootstrap failed"), @@ -296,27 +284,21 @@ def test_no_python_traceback_on_prerequisite_error( class TestUninstallCommand: """Tests for the uninstall command.""" - def test_successful_uninstall( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_successful_uninstall(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch("mlx_stack.cli.install.uninstall_agent", return_value=True): result = runner.invoke(cli, ["uninstall"]) assert result.exit_code == 0 assert "uninstalled" in result.output.lower() - def test_not_installed_message( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_not_installed_message(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch("mlx_stack.cli.install.uninstall_agent", return_value=False): result = runner.invoke(cli, ["uninstall"]) assert result.exit_code == 0 assert "not installed" in result.output.lower() - def test_services_unaffected_message( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_services_unaffected_message(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch("mlx_stack.cli.install.uninstall_agent", return_value=True): result = runner.invoke(cli, ["uninstall"]) @@ -332,9 +314,7 @@ def test_services_unaffected_message( class TestUninstallErrors: """Tests for uninstall command error handling.""" - def test_platform_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_platform_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.uninstall_agent", side_effect=PlatformError("only available on macOS"), @@ -344,9 +324,7 @@ def test_platform_error( assert result.exit_code != 0 assert "Error" in result.output - def test_launchd_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_launchd_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.uninstall_agent", side_effect=LaunchdError("bootout failed"), @@ -356,9 +334,7 @@ def test_launchd_error( assert result.exit_code != 0 assert "Error" in result.output - def test_no_python_traceback( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_no_python_traceback(self, runner: CliRunner, mlx_stack_home: Path) -> None: with patch( "mlx_stack.cli.install.uninstall_agent", side_effect=LaunchdError("bootout failed"), diff --git a/tests/unit/test_cli_logs.py b/tests/unit/test_cli_logs.py index bf1cb41..c24bcb0 100644 --- a/tests/unit/test_cli_logs.py +++ b/tests/unit/test_cli_logs.py @@ -27,9 +27,7 @@ def _create_log(logs_dir: Path, service: str, content: str = "") -> Path: return log_path -def _create_archive( - logs_dir: Path, service: str, number: int, content: str -) -> Path: +def _create_archive(logs_dir: Path, service: str, number: int, content: str) -> Path: """Create a gzip archive for a service.""" logs_dir.mkdir(parents=True, exist_ok=True) archive_path = logs_dir / f"{service}.log.{number}.gz" @@ -377,9 +375,7 @@ def test_logs_in_main_help(self) -> None: class TestEdgeCases: """Edge case tests.""" - def test_service_argument_takes_precedence_over_flag( - self, mlx_stack_home: Path - ) -> None: + def test_service_argument_takes_precedence_over_flag(self, mlx_stack_home: Path) -> None: """Positional argument takes precedence over --service flag.""" logs_dir = mlx_stack_home / "logs" _create_log(logs_dir, "fast", "fast content\n") diff --git a/tests/unit/test_cli_models.py b/tests/unit/test_cli_models.py index a1bd142..42894c8 100644 --- a/tests/unit/test_cli_models.py +++ b/tests/unit/test_cli_models.py @@ -894,15 +894,21 @@ def test_filter_by_family(self, mlx_stack_home: Path) -> None: """--family filters catalog to matching family only.""" catalog = [ _make_entry( - model_id="qwen-a", name="Qwen A", family="Qwen 3.5", + model_id="qwen-a", + name="Qwen A", + family="Qwen 3.5", tags=["balanced"], ), _make_entry( - model_id="qwen-b", name="Qwen B", family="Qwen 3.5", + model_id="qwen-b", + name="Qwen B", + family="Qwen 3.5", tags=["balanced"], ), _make_entry( - model_id="gemma-a", name="Gemma A", family="Gemma 3", + model_id="gemma-a", + name="Gemma A", + family="Gemma 3", tags=["balanced"], ), ] @@ -939,11 +945,13 @@ def test_filter_by_tag(self, mlx_stack_home: Path) -> None: """--tag filters catalog to models with the specified tag.""" catalog = [ _make_entry( - model_id="agent-model", name="Agent Model", + model_id="agent-model", + name="Agent Model", tags=["agent-ready", "balanced"], ), _make_entry( - model_id="basic-model", name="Basic Model", + model_id="basic-model", + name="Basic Model", tags=["balanced"], ), ] @@ -963,11 +971,13 @@ def test_filter_by_tool_calling(self, mlx_stack_home: Path) -> None: """--tool-calling filters to tool-calling-capable models only.""" catalog = [ _make_entry( - model_id="with-tools", name="With Tools", + model_id="with-tools", + name="With Tools", tool_calling=True, ), _make_entry( - model_id="no-tools", name="No Tools", + model_id="no-tools", + name="No Tools", tool_calling=False, ), ] @@ -987,18 +997,24 @@ def test_combined_filters(self, mlx_stack_home: Path) -> None: """Multiple filters are applied together (AND logic).""" catalog = [ _make_entry( - model_id="match", name="Match Both", - family="Qwen 3.5", tool_calling=True, + model_id="match", + name="Match Both", + family="Qwen 3.5", + tool_calling=True, tags=["agent-ready"], ), _make_entry( - model_id="family-only", name="Family Only", - family="Qwen 3.5", tool_calling=False, + model_id="family-only", + name="Family Only", + family="Qwen 3.5", + tool_calling=False, tags=[], ), _make_entry( - model_id="tools-only", name="Tools Only", - family="Gemma 3", tool_calling=True, + model_id="tools-only", + name="Tools Only", + family="Gemma 3", + tool_calling=True, tags=["agent-ready"], ), ] @@ -1029,9 +1045,7 @@ def test_no_matches_message(self, mlx_stack_home: Path) -> None: patch("mlx_stack.cli.models.load_catalog", return_value=catalog), patch("mlx_stack.cli.models.load_profile", return_value=None), ): - result = runner.invoke( - cli, ["models", "--catalog", "--family", "nonexistent"] - ) + result = runner.invoke(cli, ["models", "--catalog", "--family", "nonexistent"]) assert result.exit_code == 0 assert "No models match" in result.output @@ -1057,9 +1071,7 @@ def test_real_catalog_family_filter(self, mlx_stack_home: Path) -> None: """VAL-CATALOG-002: Filter real catalog by family.""" runner = CliRunner() with patch("mlx_stack.cli.models.load_profile", return_value=None): - result = runner.invoke( - cli, ["models", "--catalog", "--family", "qwen 3.5"] - ) + result = runner.invoke(cli, ["models", "--catalog", "--family", "qwen 3.5"]) assert result.exit_code == 0 assert "Qwen 3.5" in result.output @@ -1072,9 +1084,7 @@ def test_real_catalog_tool_calling_filter(self, mlx_stack_home: Path) -> None: """VAL-CATALOG-002: Filter real catalog by tool-calling capability.""" runner = CliRunner() with patch("mlx_stack.cli.models.load_profile", return_value=None): - result = runner.invoke( - cli, ["models", "--catalog", "--tool-calling"] - ) + result = runner.invoke(cli, ["models", "--catalog", "--tool-calling"]) assert result.exit_code == 0 # Should have some models but not all 15 @@ -1332,12 +1342,8 @@ def test_wrong_quant_local_model_is_remote(self, mlx_stack_home: Path) -> None: # Local has the int8 variant, not int4 _create_model_dir(models_dir, "qwen3.5-8b-8bit", size_bytes=1000) - local_models = scan_local_models( - models_dir=models_dir, catalog=catalog, stack=stack - ) - remote = get_remote_stack_models( - local_models=local_models, stack=stack, catalog=catalog - ) + local_models = scan_local_models(models_dir=models_dir, catalog=catalog, stack=stack) + remote = get_remote_stack_models(local_models=local_models, stack=stack, catalog=catalog) # The int4 model should show as remote since only int8 is local assert len(remote) == 1 @@ -1363,12 +1369,8 @@ def test_correct_quant_local_model_not_remote(self, mlx_stack_home: Path) -> Non _create_model_dir(models_dir, "qwen3.5-8b-4bit", size_bytes=1000) - local_models = scan_local_models( - models_dir=models_dir, catalog=catalog, stack=stack - ) - remote = get_remote_stack_models( - local_models=local_models, stack=stack, catalog=catalog - ) + local_models = scan_local_models(models_dir=models_dir, catalog=catalog, stack=stack) + remote = get_remote_stack_models(local_models=local_models, stack=stack, catalog=catalog) assert len(remote) == 0 @@ -1399,12 +1401,8 @@ def test_multi_tier_mixed_availability(self, mlx_stack_home: Path) -> None: # Only qwen is downloaded locally _create_model_dir(models_dir, "qwen3.5-8b-4bit", size_bytes=1000) - local_models = scan_local_models( - models_dir=models_dir, catalog=catalog, stack=stack - ) - remote = get_remote_stack_models( - local_models=local_models, stack=stack, catalog=catalog - ) + local_models = scan_local_models(models_dir=models_dir, catalog=catalog, stack=stack) + remote = get_remote_stack_models(local_models=local_models, stack=stack, catalog=catalog) assert len(remote) == 1 assert remote[0]["model_id"] == "nemotron-8b" diff --git a/tests/unit/test_cli_profile.py b/tests/unit/test_cli_profile.py index 6ab3985..ad30492 100644 --- a/tests/unit/test_cli_profile.py +++ b/tests/unit/test_cli_profile.py @@ -43,63 +43,49 @@ class TestProfileKnownChip: """VAL-PROFILE-001: Known Apple Silicon chip detection and display.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_exits_zero( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_exits_zero(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert result.exit_code == 0 @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_chip_name( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_chip_name(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "Apple M4 Pro" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_gpu_cores( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_gpu_cores(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "20" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_memory( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_memory(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "64 GB" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_bandwidth( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_bandwidth(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "273.0 GB/s" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_profile_id( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_profile_id(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "m4-pro-64" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_no_warning_for_known_chip( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_no_warning_for_known_chip(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) @@ -112,36 +98,28 @@ class TestProfileUnknownChip: """VAL-PROFILE-002: Unknown chip estimation with bench suggestion.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_exits_zero( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_exits_zero(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_unknown_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert result.exit_code == 0 @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_estimate_label( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_estimate_label(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_unknown_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "estimate" in result.output.lower() @patch("mlx_stack.cli.profile.detect_hardware") - def test_shows_bench_suggestion( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_shows_bench_suggestion(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_unknown_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "bench --save" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_profile_still_written( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_profile_still_written(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_unknown_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) @@ -157,9 +135,7 @@ class TestProfileNonAppleSilicon: """VAL-PROFILE-003: Non-Apple-Silicon rejection.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_nonzero_exit( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_nonzero_exit(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "mlx-stack requires Apple Silicon (M1 or later)" ) @@ -168,9 +144,7 @@ def test_nonzero_exit( assert result.exit_code != 0 @patch("mlx_stack.cli.profile.detect_hardware") - def test_error_message( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_error_message(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "mlx-stack requires Apple Silicon (M1 or later)" ) @@ -179,9 +153,7 @@ def test_error_message( assert "requires Apple Silicon" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_no_traceback( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_no_traceback(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "mlx-stack requires Apple Silicon (M1 or later)" ) @@ -190,9 +162,7 @@ def test_no_traceback( assert "Traceback" not in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_no_profile_written( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_no_profile_written(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "mlx-stack requires Apple Silicon (M1 or later)" ) @@ -206,9 +176,7 @@ class TestProfileJsonFormat: """VAL-PROFILE-004: Profile JSON is valid, complete, and correctly located.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_valid_json( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_valid_json(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() runner.invoke(cli, ["profile"]) @@ -218,9 +186,7 @@ def test_valid_json( assert isinstance(data, dict) @patch("mlx_stack.cli.profile.detect_hardware") - def test_all_required_fields( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_all_required_fields(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() runner.invoke(cli, ["profile"]) @@ -235,9 +201,7 @@ def test_all_required_fields( assert "profile_id" in data @patch("mlx_stack.cli.profile.detect_hardware") - def test_field_types( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_field_types(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() runner.invoke(cli, ["profile"]) @@ -252,9 +216,7 @@ def test_field_types( assert isinstance(data["profile_id"], str) @patch("mlx_stack.cli.profile.detect_hardware") - def test_profile_id_pattern( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_profile_id_pattern(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() runner.invoke(cli, ["profile"]) @@ -265,9 +227,7 @@ def test_profile_id_pattern( assert data["profile_id"] == "m4-pro-64" @patch("mlx_stack.cli.profile.detect_hardware") - def test_all_values_non_null( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_all_values_non_null(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() runner.invoke(cli, ["profile"]) @@ -283,18 +243,14 @@ class TestProfileRichTable: """VAL-PROFILE-005: Output is a Rich-formatted table.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_table_header_present( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_table_header_present(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) assert "Hardware Profile" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_table_has_property_labels( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_table_has_property_labels(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() result = runner.invoke(cli, ["profile"]) @@ -305,9 +261,7 @@ def test_table_has_property_labels( assert "Profile ID" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_table_has_borders( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_table_has_borders(self, mock_detect: object, mlx_stack_home: Path) -> None: """Rich tables include box-drawing characters or similar formatting.""" mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() @@ -322,9 +276,7 @@ class TestProfileOverwrite: """VAL-PROFILE-006: Re-running profile overwrites existing data.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_overwrite( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_overwrite(self, mock_detect: object, mlx_stack_home: Path) -> None: # First run with M1 hw1 = HardwareProfile("Apple M1", 8, 16, 68.25, False) mock_detect.return_value = hw1 # type: ignore[attr-defined] @@ -348,9 +300,7 @@ class TestProfileErrorHandling: """VAL-PROFILE-007: System command failures handled gracefully.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_sysctl_error_no_traceback( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_sysctl_error_no_traceback(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "sysctl failed for key 'machdep.cpu.brand_string': Operation not permitted" ) @@ -361,9 +311,7 @@ def test_sysctl_error_no_traceback( assert "Error" in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_profiler_error_no_traceback( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_profiler_error_no_traceback(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "system_profiler command not found — are you running on macOS?" ) @@ -373,9 +321,7 @@ def test_profiler_error_no_traceback( assert "Traceback" not in result.output @patch("mlx_stack.cli.profile.detect_hardware") - def test_descriptive_error_message( - self, mock_detect: object, mlx_stack_home: Path - ) -> None: + def test_descriptive_error_message(self, mock_detect: object, mlx_stack_home: Path) -> None: mock_detect.side_effect = HardwareError( # type: ignore[attr-defined] "sysctl timed out reading key 'hw.memsize'" ) @@ -388,9 +334,7 @@ class TestProfileAutoCreatesDirectory: """VAL-SETUP-004: Profile auto-creates ~/.mlx-stack/ on first use.""" @patch("mlx_stack.cli.profile.detect_hardware") - def test_creates_data_dir( - self, mock_detect: object, clean_mlx_stack_home: Path - ) -> None: + def test_creates_data_dir(self, mock_detect: object, clean_mlx_stack_home: Path) -> None: assert not clean_mlx_stack_home.exists() mock_detect.return_value = _mock_known_hardware() # type: ignore[attr-defined] runner = CliRunner() diff --git a/tests/unit/test_cli_pull.py b/tests/unit/test_cli_pull.py index caa4e43..cb2c1cf 100644 --- a/tests/unit/test_cli_pull.py +++ b/tests/unit/test_cli_pull.py @@ -758,9 +758,7 @@ def test_calls_snapshot_download_with_correct_args( console = Console(file=StringIO()) _run_download("test/repo", local_dir, console) - mock_snapshot.assert_called_once_with( - repo_id="test/repo", local_dir=str(local_dir) - ) + mock_snapshot.assert_called_once_with(repo_id="test/repo", local_dir=str(local_dir)) @patch("mlx_stack.core.pull.snapshot_download", side_effect=Exception("Repo not found")) def test_wraps_exception_in_download_error( @@ -1341,9 +1339,7 @@ def test_gated_repo_error_caught( mock_response.status_code = 403 mock_response.headers = {} mock_response.url = "https://huggingface.co/test/repo" - mock_snapshot.side_effect = HfGatedRepoError( - "gated repo", response=mock_response - ) + mock_snapshot.side_effect = HfGatedRepoError("gated repo", response=mock_response) local_dir = tmp_path / "model" local_dir.mkdir() diff --git a/tests/unit/test_cli_recommend.py b/tests/unit/test_cli_recommend.py index 5c9b94e..2f4687a 100644 --- a/tests/unit/test_cli_recommend.py +++ b/tests/unit/test_cli_recommend.py @@ -443,12 +443,8 @@ def test_balanced_vs_agent_fleet_different_tiers( balanced_lines = result_balanced.output.split("\n") agent_lines = result_agent.output.split("\n") - balanced_standard = [ - line for line in balanced_lines if "standard" in line.lower() - ] - agent_standard = [ - line for line in agent_lines if "standard" in line.lower() - ] + balanced_standard = [line for line in balanced_lines if "standard" in line.lower()] + agent_standard = [line for line in agent_lines if "standard" in line.lower()] # Both should have a standard tier assert len(balanced_standard) > 0 @@ -506,7 +502,8 @@ def test_fast_is_highest_tps( assert result.exit_code == 0 output_lines = result.output.split("\n") fast_line = [ - line for line in output_lines + line + for line in output_lines if "fast" in line.lower() and "standard" not in line.lower() ] assert len(fast_line) > 0 @@ -578,9 +575,7 @@ def test_small_budget_fewer_tiers( name="Tiny 1B", quality_overall=30, benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=200.0, gen_tps=150.0, memory_gb=1.0 - ), + "m4-max-128": BenchmarkResult(prompt_tps=200.0, gen_tps=150.0, memory_gb=1.0), }, ), ] @@ -1048,9 +1043,7 @@ def test_saved_benchmarks_used( "memory_gb": 5.5, } } - (benchmarks_dir / f"{profile.profile_id}.json").write_text( - json.dumps(saved_data) - ) + (benchmarks_dir / f"{profile.profile_id}.json").write_text(json.dumps(saved_data)) runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) @@ -1221,9 +1214,7 @@ def test_malformed_benchmark_json_warning( "memory_gb": 5.5, } } - (benchmarks_dir / f"{profile.profile_id}.json").write_text( - json.dumps(saved_data) - ) + (benchmarks_dir / f"{profile.profile_id}.json").write_text(json.dumps(saved_data)) runner = CliRunner() result = runner.invoke(cli, ["recommend", "--show-all"]) diff --git a/tests/unit/test_cli_setup.py b/tests/unit/test_cli_setup.py index 61f75c7..be4e5fd 100644 --- a/tests/unit/test_cli_setup.py +++ b/tests/unit/test_cli_setup.py @@ -34,17 +34,26 @@ MOCK_UP_RESULT = SimpleNamespace( tiers=[ SimpleNamespace( - name="standard", model="Qwen3.5-9B", - port=8000, status="healthy", error=None, + name="standard", + model="Qwen3.5-9B", + port=8000, + status="healthy", + error=None, ), SimpleNamespace( - name="fast", model="SmallFast-4B", - port=8001, status="healthy", error=None, + name="fast", + model="SmallFast-4B", + port=8001, + status="healthy", + error=None, ), ], litellm=SimpleNamespace( - name="litellm", model="proxy", - port=4000, status="healthy", error=None, + name="litellm", + model="proxy", + port=4000, + status="healthy", + error=None, ), dry_run=False, warnings=[], @@ -54,19 +63,29 @@ MOCK_BENCHMARK_DATA = { "models": { "mlx-community/Qwen3.5-9B-4bit": { - "params_b": 9.0, "thinking": True, "tool_calling": True, - "benchmarks": {"m4-pro-64": { - "generation_tps": 62.0, "prompt_tps": 337.0, - "peak_memory_gib": 5.2, - }}, + "params_b": 9.0, + "thinking": True, + "tool_calling": True, + "benchmarks": { + "m4-pro-64": { + "generation_tps": 62.0, + "prompt_tps": 337.0, + "peak_memory_gib": 5.2, + } + }, "quality": {"overall_pass_rate": 0.98}, }, "mlx-community/SmallFast-4B-4bit": { - "params_b": 4.0, "thinking": False, "tool_calling": False, - "benchmarks": {"m4-pro-64": { - "generation_tps": 95.0, "prompt_tps": 500.0, - "peak_memory_gib": 2.4, - }}, + "params_b": 4.0, + "thinking": False, + "tool_calling": False, + "benchmarks": { + "m4-pro-64": { + "generation_tps": 95.0, + "prompt_tps": 500.0, + "peak_memory_gib": 2.4, + } + }, "quality": {"overall_pass_rate": 0.91}, }, }, @@ -81,17 +100,18 @@ def _run_setup(args: list[str], mlx_stack_home: Path) -> Any: patch("mlx_stack.core.onboarding.detect_hardware", return_value=MOCK_PROFILE), patch("mlx_stack.core.onboarding.save_profile"), patch("mlx_stack.core.discovery.query_hf_models", return_value=[]), - patch("mlx_stack.core.discovery.load_benchmark_data", - return_value=MOCK_BENCHMARK_DATA), - patch("mlx_stack.cli.setup.generate_config", - return_value=(mlx_stack_home / "stacks" / "default.yaml", - mlx_stack_home / "litellm.yaml")), + patch("mlx_stack.core.discovery.load_benchmark_data", return_value=MOCK_BENCHMARK_DATA), + patch( + "mlx_stack.cli.setup.generate_config", + return_value=( + mlx_stack_home / "stacks" / "default.yaml", + mlx_stack_home / "litellm.yaml", + ), + ), patch("mlx_stack.cli.setup.pull_setup_models", return_value=[]), patch("mlx_stack.cli.setup.start_stack", return_value=MOCK_UP_RESULT), ): - result = runner.invoke(setup, args) - - return result + return runner.invoke(setup, args) # --------------------------------------------------------------------------- # @@ -105,9 +125,7 @@ class TestSetupAcceptDefaults: def test_completes_successfully(self, mlx_stack_home: Path) -> None: """Setup with --accept-defaults exits 0.""" result = _run_setup(["--accept-defaults"], mlx_stack_home) - assert result.exit_code == 0, ( - f"Exit {result.exit_code}:\n{result.output}" - ) + assert result.exit_code == 0, f"Exit {result.exit_code}:\n{result.output}" def test_shows_hardware_info(self, mlx_stack_home: Path) -> None: """Output includes detected hardware details.""" @@ -144,7 +162,8 @@ class TestSetupIntentFlag: def test_intent_flag_accepted(self, mlx_stack_home: Path) -> None: """Providing --intent with --accept-defaults works.""" result = _run_setup( - ["--accept-defaults", "--intent", "agent-fleet"], mlx_stack_home, + ["--accept-defaults", "--intent", "agent-fleet"], + mlx_stack_home, ) assert result.exit_code == 0 @@ -157,8 +176,7 @@ def test_hardware_detection_failure(self, mlx_stack_home: Path) -> None: runner = CliRunner() with ( - patch("mlx_stack.core.onboarding.detect_hardware", - side_effect=RuntimeError("no chip")), + patch("mlx_stack.core.onboarding.detect_hardware", side_effect=RuntimeError("no chip")), patch("mlx_stack.core.onboarding.save_profile"), ): result = runner.invoke(setup, ["--accept-defaults"]) @@ -171,13 +189,10 @@ def test_no_models_found(self, mlx_stack_home: Path) -> None: runner = CliRunner() with ( - patch("mlx_stack.core.onboarding.detect_hardware", - return_value=MOCK_PROFILE), + patch("mlx_stack.core.onboarding.detect_hardware", return_value=MOCK_PROFILE), patch("mlx_stack.core.onboarding.save_profile"), - patch("mlx_stack.core.discovery.query_hf_models", - return_value=[]), - patch("mlx_stack.core.discovery.load_benchmark_data", - return_value={"models": {}}), + patch("mlx_stack.core.discovery.query_hf_models", return_value=[]), + patch("mlx_stack.core.discovery.load_benchmark_data", return_value={"models": {}}), ): result = runner.invoke(setup, ["--accept-defaults"]) diff --git a/tests/unit/test_cli_status.py b/tests/unit/test_cli_status.py index 9c4cd05..5ec82a6 100644 --- a/tests/unit/test_cli_status.py +++ b/tests/unit/test_cli_status.py @@ -302,12 +302,15 @@ def test_five_distinct_states( } def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str, Any]: - return status_map.get(service_name, { - "status": "stopped", - "pid": None, - "uptime": None, - "response_time": None, - }) + return status_map.get( + service_name, + { + "status": "stopped", + "pid": None, + "uptime": None, + "response_time": None, + }, + ) mock_status.side_effect = side_effect @@ -1244,8 +1247,14 @@ def test_json_output_all_five_states( ) -> None: """VAL-STATUS-001/004: All five states present in JSON.""" tiers = [ - {"name": f"t{i}", "model": f"m{i}", "quant": "int4", - "source": f"s{i}", "port": 8000 + i, "vllm_flags": {}} + { + "name": f"t{i}", + "model": f"m{i}", + "quant": "int4", + "source": f"s{i}", + "port": 8000 + i, + "vllm_flags": {}, + } for i in range(4) ] stack = _make_stack_yaml(tiers=tiers) @@ -1258,11 +1267,9 @@ def side_effect(service_name: str, port: int, health_path: str = "") -> dict[str st = status_list[idx] return { "status": st, - "pid": 1000 + idx if st not in ("stopped",) else None, + "pid": 1000 + idx if st != "stopped" else None, "uptime": 100.0 if st in ("healthy", "degraded", "down") else None, - "response_time": 0.1 if st == "healthy" else ( - 3.0 if st == "degraded" else None - ), + "response_time": 0.1 if st == "healthy" else (3.0 if st == "degraded" else None), } mock_status.side_effect = side_effect diff --git a/tests/unit/test_cli_up.py b/tests/unit/test_cli_up.py index db51029..71e3992 100644 --- a/tests/unit/test_cli_up.py +++ b/tests/unit/test_cli_up.py @@ -703,11 +703,16 @@ def test_successful_startup( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) mock_start_service.return_value = ServiceInfo( - name="test", pid=12345, port=8000, log_path=Path("/tmp/test.log"), + name="test", + pid=12345, + port=8000, + log_path=Path("/tmp/test.log"), pid_path=Path("/tmp/test.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) result = run_up() @@ -752,11 +757,16 @@ def test_tier_filter_starts_only_one( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) mock_start_service.return_value = ServiceInfo( - name="fast", pid=12345, port=8001, log_path=Path("/tmp/fast.log"), + name="fast", + pid=12345, + port=8001, + log_path=Path("/tmp/fast.log"), pid_path=Path("/tmp/fast.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) result = run_up(tier_filter="fast") @@ -809,11 +819,16 @@ def port_conflict_side_effect(port: int) -> tuple[int, str] | None: mock_port_conflict.side_effect = port_conflict_side_effect mock_start_service.return_value = ServiceInfo( - name="fast", pid=12345, port=8001, log_path=Path("/tmp/fast.log"), + name="fast", + pid=12345, + port=8001, + log_path=Path("/tmp/fast.log"), pid_path=Path("/tmp/fast.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) result = run_up() @@ -870,11 +885,16 @@ def port_conflict_side_effect(port: int) -> tuple[int, str] | None: mock_port_conflict.side_effect = port_conflict_side_effect mock_start_service.return_value = ServiceInfo( - name="fast", pid=12345, port=8001, log_path=Path("/tmp/fast.log"), + name="fast", + pid=12345, + port=8001, + log_path=Path("/tmp/fast.log"), pid_path=Path("/tmp/fast.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) result = run_up() @@ -925,15 +945,18 @@ def test_port_conflict_unknown_owner( mock_lock.return_value.__exit__ = MagicMock(return_value=False) # Port occupied but owner unknown (e.g., macOS permission issue) - mock_port_conflict.side_effect = lambda port: ( - (0, "") if port == 8000 else None - ) + mock_port_conflict.side_effect = lambda port: (0, "") if port == 8000 else None mock_start_service.return_value = ServiceInfo( - name="fast", pid=12345, port=8001, log_path=Path("/tmp/fast.log"), + name="fast", + pid=12345, + port=8001, + log_path=Path("/tmp/fast.log"), pid_path=Path("/tmp/fast.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) result = run_up() @@ -978,7 +1001,10 @@ def test_health_check_timeout_continues( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) mock_start_service.return_value = ServiceInfo( - name="test", pid=12345, port=8000, log_path=Path("/tmp/test.log"), + name="test", + pid=12345, + port=8000, + log_path=Path("/tmp/test.log"), pid_path=Path("/tmp/test.pid"), ) @@ -1115,11 +1141,16 @@ def test_stale_pid_cleanup_and_restart( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) mock_start_service.return_value = ServiceInfo( - name="test", pid=99999, port=8000, log_path=Path("/tmp/test.log"), + name="test", + pid=99999, + port=8000, + log_path=Path("/tmp/test.log"), pid_path=Path("/tmp/test.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) result = run_up() @@ -1213,18 +1244,24 @@ def test_api_key_passed_via_env( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) mock_start_service.return_value = ServiceInfo( - name="test", pid=12345, port=8000, log_path=Path("/tmp/test.log"), + name="test", + pid=12345, + port=8000, + log_path=Path("/tmp/test.log"), pid_path=Path("/tmp/test.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) run_up() # Check that start_service was called for litellm with env dict litellm_calls = [ - c for c in mock_start_service.call_args_list + c + for c in mock_start_service.call_args_list if c.kwargs.get("service_name") == LITELLM_SERVICE_NAME or (c.args and c.args[0] == LITELLM_SERVICE_NAME) ] @@ -1274,11 +1311,16 @@ def test_memory_warning_displayed( mock_lock.return_value.__enter__ = MagicMock(return_value=None) mock_lock.return_value.__exit__ = MagicMock(return_value=False) mock_start_service.return_value = ServiceInfo( - name="test", pid=12345, port=8000, log_path=Path("/tmp/test.log"), + name="test", + pid=12345, + port=8000, + log_path=Path("/tmp/test.log"), pid_path=Path("/tmp/test.pid"), ) mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.5, status_code=200, + healthy=True, + response_time=0.5, + status_code=200, ) with patch("mlx_stack.core.stack_up.psutil.virtual_memory") as mock_vmem: @@ -1310,7 +1352,10 @@ def test_summary_table_displayed( TierStatus(name="fast", model="fast-model", port=8001, status="healthy"), ], litellm=TierStatus( - name="litellm", model="proxy", port=4000, status="healthy", + name="litellm", + model="proxy", + port=4000, + status="healthy", ), ) @@ -1333,12 +1378,17 @@ def test_already_running_message( mock_run_up.return_value = UpResult( tiers=[ TierStatus( - name="standard", model="big-model", port=8000, + name="standard", + model="big-model", + port=8000, status="already-running", ), ], litellm=TierStatus( - name="litellm", model="proxy", port=4000, status="already-running", + name="litellm", + model="proxy", + port=4000, + status="already-running", ), already_running=True, ) @@ -1358,16 +1408,24 @@ def test_partial_failure_summary( mock_run_up.return_value = UpResult( tiers=[ TierStatus( - name="standard", model="big-model", port=8000, - status="failed", error="Health check timeout", + name="standard", + model="big-model", + port=8000, + status="failed", + error="Health check timeout", ), TierStatus( - name="fast", model="fast-model", port=8001, + name="fast", + model="fast-model", + port=8001, status="healthy", ), ], litellm=TierStatus( - name="litellm", model="proxy", port=4000, status="healthy", + name="litellm", + model="proxy", + port=4000, + status="healthy", ), ) @@ -1387,13 +1445,19 @@ def test_all_failed_exit_code( mock_run_up.return_value = UpResult( tiers=[ TierStatus( - name="standard", model="big-model", port=8000, - status="failed", error="Port conflict", + name="standard", + model="big-model", + port=8000, + status="failed", + error="Port conflict", ), ], litellm=TierStatus( - name="litellm", model="proxy", port=4000, - status="skipped", error="All model servers failed", + name="litellm", + model="proxy", + port=4000, + status="skipped", + error="All model servers failed", ), ) @@ -1427,7 +1491,10 @@ def test_warning_displayed( TierStatus(name="standard", model="big-model", port=8000, status="healthy"), ], litellm=TierStatus( - name="litellm", model="proxy", port=4000, status="healthy", + name="litellm", + model="proxy", + port=4000, + status="healthy", ), warnings=["Estimated memory usage (50.0 GB) exceeds available (10.0 GB)"], ) @@ -1447,17 +1514,24 @@ def test_port_conflict_in_summary( mock_run_up.return_value = UpResult( tiers=[ TierStatus( - name="standard", model="big-model", port=8000, + name="standard", + model="big-model", + port=8000, status="skipped", error="Port 8000 already in use by PID 54321 (node)", ), TierStatus( - name="fast", model="fast-model", port=8001, + name="fast", + model="fast-model", + port=8001, status="healthy", ), ], litellm=TierStatus( - name="litellm", model="proxy", port=4000, status="healthy", + name="litellm", + model="proxy", + port=4000, + status="healthy", ), ) @@ -1503,9 +1577,7 @@ def test_custom_litellm_port( assert result.litellm.port == 5001 # Dry-run commands should reference port 5001 - litellm_cmds = [ - c for c in result.dry_run_commands if c["service"] == "litellm" - ] + litellm_cmds = [c for c in result.dry_run_commands if c["service"] == "litellm"] assert len(litellm_cmds) == 1 assert "5001" in litellm_cmds[0]["command"] diff --git a/tests/unit/test_cli_watch.py b/tests/unit/test_cli_watch.py index 90777d1..3716d56 100644 --- a/tests/unit/test_cli_watch.py +++ b/tests/unit/test_cli_watch.py @@ -21,13 +21,13 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def runner() -> CliRunner: """Create a Click CliRunner.""" return CliRunner() -@pytest.fixture() +@pytest.fixture def stack_definition(mlx_stack_home: Path) -> dict: """Create a test stack definition.""" stacks_dir = mlx_stack_home / "stacks" @@ -93,34 +93,24 @@ def test_watch_help_shows_defaults(self, runner: CliRunner) -> None: class TestWatchParameterValidation: """Tests for watch command parameter validation.""" - def test_invalid_interval_zero( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_invalid_interval_zero(self, runner: CliRunner, mlx_stack_home: Path) -> None: result = runner.invoke(cli, ["watch", "--interval", "0"]) assert result.exit_code != 0 assert "positive integer" in result.output.lower() or "Invalid" in result.output - def test_invalid_interval_negative( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_invalid_interval_negative(self, runner: CliRunner, mlx_stack_home: Path) -> None: result = runner.invoke(cli, ["watch", "--interval", "-5"]) assert result.exit_code != 0 - def test_invalid_max_restarts_zero( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_invalid_max_restarts_zero(self, runner: CliRunner, mlx_stack_home: Path) -> None: result = runner.invoke(cli, ["watch", "--max-restarts", "0"]) assert result.exit_code != 0 - def test_invalid_restart_delay_negative( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_invalid_restart_delay_negative(self, runner: CliRunner, mlx_stack_home: Path) -> None: result = runner.invoke(cli, ["watch", "--restart-delay", "-1"]) assert result.exit_code != 0 - def test_invalid_interval_non_integer( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_invalid_interval_non_integer(self, runner: CliRunner, mlx_stack_home: Path) -> None: result = runner.invoke(cli, ["watch", "--interval", "abc"]) assert result.exit_code != 0 @@ -133,9 +123,7 @@ def test_invalid_interval_non_integer( class TestWatchNoStack: """Tests for watch command when no stack is configured.""" - def test_no_stack_exits_with_error( - self, runner: CliRunner, mlx_stack_home: Path - ) -> None: + def test_no_stack_exits_with_error(self, runner: CliRunner, mlx_stack_home: Path) -> None: result = runner.invoke(cli, ["watch"]) assert result.exit_code != 0 assert "init" in result.output.lower() or "stack" in result.output.lower() @@ -269,9 +257,12 @@ def test_all_options_combined( cli, [ "watch", - "--interval", "45", - "--max-restarts", "10", - "--restart-delay", "15", + "--interval", + "45", + "--max-restarts", + "10", + "--restart-delay", + "15", "--daemon", ], ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 60c7d7a..28f20d4 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -6,6 +6,7 @@ from __future__ import annotations +import contextlib from pathlib import Path import pytest @@ -335,20 +336,16 @@ def test_set_invalid_value_rejected(self, mlx_stack_home: Path) -> None: def test_invalid_key_not_written_to_file(self, mlx_stack_home: Path) -> None: config_path = get_config_path() - try: + with contextlib.suppress(ConfigError): set_value("bad-key", "value") - except ConfigError: - pass if config_path.exists(): content = config_path.read_text() assert "bad-key" not in content def test_invalid_value_not_written_to_file(self, mlx_stack_home: Path) -> None: config_path = get_config_path() - try: + with contextlib.suppress(ConfigValidationError): set_value("default-quant", "int6") - except ConfigValidationError: - pass if config_path.exists(): content = config_path.read_text() assert "int6" not in content @@ -461,7 +458,7 @@ def test_api_key_masked_in_entries(self, mlx_stack_home: Path) -> None: set_value("openrouter-key", "sk-secret-key-12345") entries = get_all_config() key_entry = next(e for e in entries if e["name"] == "openrouter-key") - assert "sk-secret-key-12345" != key_entry["masked_value"] + assert key_entry["masked_value"] != "sk-secret-key-12345" assert "****" in key_entry["masked_value"] def test_corrupt_file_raises(self, mlx_stack_home: Path) -> None: diff --git a/tests/unit/test_cross_area.py b/tests/unit/test_cross_area.py index 0ebd894..40eaf2e 100644 --- a/tests/unit/test_cross_area.py +++ b/tests/unit/test_cross_area.py @@ -87,10 +87,12 @@ def _make_entry( min_mlx_lm_version="0.22.0", sources={ "int4": QuantSource( - hf_repo=f"mlx-community/{model_id}-4bit", disk_size_gb=disk_size_gb, + hf_repo=f"mlx-community/{model_id}-4bit", + disk_size_gb=disk_size_gb, ), "int8": QuantSource( - hf_repo=f"mlx-community/{model_id}-8bit", disk_size_gb=disk_size_gb * 2, + hf_repo=f"mlx-community/{model_id}-8bit", + disk_size_gb=disk_size_gb * 2, ), }, capabilities=Capabilities( @@ -211,9 +213,7 @@ def _write_saved_benchmarks( """Write saved benchmark data for the given profile.""" benchmarks_dir = home / "benchmarks" benchmarks_dir.mkdir(parents=True, exist_ok=True) - (benchmarks_dir / f"{profile_id}.json").write_text( - json.dumps(benchmarks, indent=2) - ) + (benchmarks_dir / f"{profile_id}.json").write_text(json.dumps(benchmarks, indent=2)) def _write_inventory(home: Path, entries: list[dict[str, Any]]) -> None: @@ -413,13 +413,18 @@ def fake_start_service( logs_dir.mkdir(parents=True, exist_ok=True) log_file = logs_dir / f"{service_name}.log" return ServiceInfo( - name=service_name, pid=pid, port=port, - log_path=log_file, pid_path=pid_file, + name=service_name, + pid=pid, + port=port, + log_path=log_file, + pid_path=pid_file, ) mock_start_service.side_effect = fake_start_service mock_wait_healthy.return_value = HealthCheckResult( - healthy=True, response_time=0.1, status_code=200, + healthy=True, + response_time=0.1, + status_code=200, ) runner = CliRunner() @@ -443,16 +448,18 @@ def fake_start_service( local_path.mkdir(parents=True, exist_ok=True) (local_path / "config.json").write_text("{}") - inventory_entries.append({ - "model_id": model_id, - "name": model_id, - "quant": quant, - "source_type": "mlx_community", - "hf_repo": source, - "local_path": str(local_path), - "disk_size_gb": 4.5, - "downloaded_at": "2026-03-24T00:00:00+00:00", - }) + inventory_entries.append( + { + "model_id": model_id, + "name": model_id, + "quant": quant, + "source_type": "mlx_community", + "hf_repo": source, + "local_path": str(local_path), + "disk_size_gb": 4.5, + "downloaded_at": "2026-03-24T00:00:00+00:00", + } + ) _write_inventory(mlx_stack_home, inventory_entries) @@ -475,9 +482,7 @@ def fake_start_service( tier_names = {t["name"] for t in stack["tiers"]} pid_file_names = {p.stem for p in pid_files} for tier_name in tier_names: - assert tier_name in pid_file_names, ( - f"No PID file for tier '{tier_name}'" - ) + assert tier_name in pid_file_names, f"No PID file for tier '{tier_name}'" assert "litellm" in pid_file_names, "No PID file for litellm" # ---- Step 4: Mock GET /v1/models returning 200 ---- @@ -485,9 +490,7 @@ def fake_start_service( expected_models = [t["model"] for t in stack["tiers"]] mock_response = { "object": "list", - "data": [ - {"id": f"openai/{m}", "object": "model"} for m in expected_models - ], + "data": [{"id": f"openai/{m}", "object": "model"} for m in expected_models], } import httpx @@ -516,6 +519,7 @@ def fake_start_service( patch("mlx_stack.core.stack_down.read_pid_file") as mock_read_pid, patch("mlx_stack.core.stack_down.remove_pid_file") as mock_remove_pid, ): + def read_pid_side_effect(name: str) -> int | None: pid_file = pids_dir / f"{name}.pid" if pid_file.exists(): @@ -585,17 +589,14 @@ def test_litellm_port_5000_in_generated_litellm_yaml( # But the tier ports in the stack should NOT be 5000 either stack = _read_stack_yaml(mlx_stack_home) tier_ports = {t["port"] for t in stack["tiers"]} - assert 5000 not in tier_ports, ( - "LiteLLM port 5000 should not be used as a vllm tier port" - ) + assert 5000 not in tier_ports, "LiteLLM port 5000 should not be used as a vllm tier port" # Verify the port 5000 is reflected in the stack or litellm config # (the actual litellm.yaml doesn't store the port since it's a # CLI flag, but the init output should mention it, and the # dry-run should use it) with ( - patch("mlx_stack.core.stack_up.load_catalog", - return_value=_make_test_catalog()), + patch("mlx_stack.core.stack_up.load_catalog", return_value=_make_test_catalog()), patch("mlx_stack.core.stack_up.get_value") as mock_get_val, ): mock_get_val.side_effect = lambda key: { @@ -729,9 +730,7 @@ def config_side_effect(key: str) -> Any: assert mock_download.called, "download_model was never called" call_args = mock_download.call_args hf_repo = call_args[0][0] if call_args[0] else call_args[1].get("hf_repo", "") - assert "8bit" in hf_repo, ( - f"Expected int8 HF repo (containing '8bit'), got: {hf_repo}" - ) + assert "8bit" in hf_repo, f"Expected int8 HF repo (containing '8bit'), got: {hf_repo}" # Verify add_to_inventory was called with quant=int8 assert mock_add_inv.called, "add_to_inventory was never called" @@ -780,9 +779,7 @@ def test_litellm_port_propagates_to_up_dry_run( result = runner.invoke(cli, ["up", "--dry-run"]) assert result.exit_code == 0 # The litellm command should use port 5001 - assert "5001" in result.output, ( - f"Port 5001 not found in dry-run output:\n{result.output}" - ) + assert "5001" in result.output, f"Port 5001 not found in dry-run output:\n{result.output}" @patch("mlx_stack.core.stack_init.load_catalog") @patch("mlx_stack.core.stack_init.detect_hardware") @@ -934,9 +931,7 @@ def test_saved_benchmarks_remove_estimated_label( second_output = result.output # The saved gen_tps value (85.0) should appear - assert "85.0" in second_output, ( - f"Expected saved gen_tps '85.0' in output:\n{second_output}" - ) + assert "85.0" in second_output, f"Expected saved gen_tps '85.0' in output:\n{second_output}" # Parse the output line-by-line to find the medium-8b row # and verify it does NOT have 'est.' marker @@ -946,9 +941,7 @@ def test_saved_benchmarks_remove_estimated_label( f"'Medium 8B' not found in recommend output:\n{second_output}" ) for line in medium_8b_lines: - assert "(est.)" not in line, ( - f"Medium 8B still shows 'est.' after bench --save: {line}" - ) + assert "(est.)" not in line, f"Medium 8B still shows 'est.' after bench --save: {line}" @patch("mlx_stack.cli.recommend.load_catalog") @patch("mlx_stack.cli.recommend.load_profile") @@ -995,8 +988,7 @@ def test_saved_benchmarks_affect_scoring_order( # Verify that the output text is actually different, proving # the saved benchmarks affected scoring. assert result_before.output != result_after.output, ( - "Recommend output should differ after bench --save with " - "dramatically different gen_tps" + "Recommend output should differ after bench --save with dramatically different gen_tps" ) @@ -1221,9 +1213,7 @@ def test_vllm_flags_in_dry_run_output( ) # Verify we actually tested something - assert has_tool_calling_tier, ( - "No tool-calling tiers found in stack — test is vacuous" - ) + assert has_tool_calling_tier, "No tool-calling tiers found in stack — test is vacuous" @patch("mlx_stack.core.stack_init.load_catalog") @patch("mlx_stack.core.stack_init.detect_hardware") diff --git a/tests/unit/test_deps.py b/tests/unit/test_deps.py index 0dfcf57..ec483c7 100644 --- a/tests/unit/test_deps.py +++ b/tests/unit/test_deps.py @@ -40,21 +40,21 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def mock_which(): """Patch shutil.which in the deps module.""" with patch("mlx_stack.core.deps.shutil.which") as m: yield m -@pytest.fixture() +@pytest.fixture def mock_subprocess_run(): """Patch subprocess.run in the deps module.""" with patch("mlx_stack.core.deps.subprocess.run") as m: yield m -@pytest.fixture() +@pytest.fixture def mock_console(): """Patch the Rich console in the deps module.""" with patch("mlx_stack.core.deps._console") as m: @@ -209,12 +209,7 @@ def test_handles_realistic_uv_tool_list_output( # Real uv tool list output includes indented binary entries below each tool mock_subprocess_run.return_value = MagicMock( returncode=0, - stdout=( - "vllm-mlx v0.2.6\n" - "- vllm-mlx\n" - "litellm v1.83.0\n" - "- litellm\n" - ), + stdout=("vllm-mlx v0.2.6\n- vllm-mlx\nlitellm v1.83.0\n- litellm\n"), ) assert _get_installed_version("vllm-mlx") == "0.2.6" assert _get_installed_version("litellm") == "1.83.0" @@ -430,9 +425,7 @@ def test_unknown_tool_raises_value_error(self) -> None: with pytest.raises(ValueError, match="Unknown dependency"): check_dependency("nonexistent-tool") - def test_check_litellm( - self, mock_which: MagicMock, mock_subprocess_run: MagicMock - ) -> None: + def test_check_litellm(self, mock_which: MagicMock, mock_subprocess_run: MagicMock) -> None: mock_which.side_effect = ["/usr/local/bin/litellm", "/usr/local/bin/uv"] mock_subprocess_run.return_value = MagicMock( returncode=0, @@ -483,11 +476,11 @@ def test_missing_tool_triggers_auto_install( # 4. check_dependency (re-check) -> _find_binary: found # 5. check_dependency (re-check) -> _get_installed_version -> uv found mock_which.side_effect = [ - None, # _find_binary: tool not found - "/usr/local/bin/uv", # _install_tool: uv found + None, # _find_binary: tool not found + "/usr/local/bin/uv", # _install_tool: uv found "/usr/local/bin/vllm-mlx", # _verify_post_install: tool found "/usr/local/bin/vllm-mlx", # re-check _find_binary - "/usr/local/bin/uv", # re-check _get_installed_version + "/usr/local/bin/uv", # re-check _get_installed_version ] # First run = install, second run = uv tool list for version check mock_subprocess_run.side_effect = [ @@ -507,7 +500,7 @@ def test_install_failure_raises_error( ) -> None: """When install fails, DependencyInstallError is raised.""" mock_which.side_effect = [ - None, # _find_binary: not found + None, # _find_binary: not found "/usr/local/bin/uv", # _install_tool: uv found ] mock_subprocess_run.return_value = MagicMock( @@ -525,13 +518,11 @@ def test_post_install_not_found_raises_error( ) -> None: """When tool not found after install, error with PATH instructions.""" mock_which.side_effect = [ - None, # _find_binary: not found + None, # _find_binary: not found "/usr/local/bin/uv", # _install_tool: uv found - None, # _verify_post_install: still not found + None, # _verify_post_install: still not found ] - mock_subprocess_run.return_value = MagicMock( - returncode=0, stdout="", stderr="" - ) + mock_subprocess_run.return_value = MagicMock(returncode=0, stdout="", stderr="") with pytest.raises(DependencyInstallError, match="not found on PATH"): ensure_dependency("vllm-mlx") @@ -672,8 +663,8 @@ def test_install_failure_no_traceback( ) -> None: """Error messages should not contain Python tracebacks.""" mock_which.side_effect = [ - None, # _find_binary - "/usr/local/bin/uv", # _install_tool + None, # _find_binary + "/usr/local/bin/uv", # _install_tool ] mock_subprocess_run.return_value = MagicMock( returncode=1, @@ -745,12 +736,7 @@ def test_uv_tool_list_with_multiple_tools_finds_vllm( mock_which.return_value = "/usr/local/bin/uv" mock_subprocess_run.return_value = MagicMock( returncode=0, - stdout=( - "ruff v0.8.0\n" - "vllm-mlx v0.2.6\n" - "litellm v1.83.0\n" - "mypy v1.13.0\n" - ), + stdout=("ruff v0.8.0\nvllm-mlx v0.2.6\nlitellm v1.83.0\nmypy v1.13.0\n"), ) assert _get_installed_version("vllm-mlx") == "0.2.6" @@ -761,12 +747,7 @@ def test_uv_tool_list_with_multiple_tools_finds_litellm( mock_which.return_value = "/usr/local/bin/uv" mock_subprocess_run.return_value = MagicMock( returncode=0, - stdout=( - "ruff v0.8.0\n" - "vllm-mlx v0.2.6\n" - "litellm v1.83.0\n" - "mypy v1.13.0\n" - ), + stdout=("ruff v0.8.0\nvllm-mlx v0.2.6\nlitellm v1.83.0\nmypy v1.13.0\n"), ) assert _get_installed_version("litellm") == "1.83.0" diff --git a/tests/unit/test_discovery.py b/tests/unit/test_discovery.py index a666946..f0d09a7 100644 --- a/tests/unit/test_discovery.py +++ b/tests/unit/test_discovery.py @@ -218,8 +218,10 @@ def test_benchmarked_models_have_performance_data(self, mock_hf, mock_bench) -> qwen9b = next(m for m in models if "9B" in m.display_name and "Qwen3.5" in m.display_name) assert qwen9b.has_benchmark is True - assert qwen9b.gen_tps is not None and qwen9b.gen_tps > 0 - assert qwen9b.memory_gb is not None and qwen9b.memory_gb > 0 + assert qwen9b.gen_tps is not None + assert qwen9b.gen_tps > 0 + assert qwen9b.memory_gb is not None + assert qwen9b.memory_gb > 0 assert qwen9b.quality_overall is not None @patch("mlx_stack.core.discovery.load_benchmark_data") diff --git a/tests/unit/test_hardware.py b/tests/unit/test_hardware.py index 3a173e9..f1c7b3d 100644 --- a/tests/unit/test_hardware.py +++ b/tests/unit/test_hardware.py @@ -32,6 +32,7 @@ # Helper: mock subprocess.run # --------------------------------------------------------------------------- # + def _make_completed( stdout: str = "", returncode: int = 0, stderr: str = "" ) -> subprocess.CompletedProcess[str]: @@ -175,7 +176,7 @@ def test_8gb(self, mock_sysctl: object) -> None: @patch("mlx_stack.core.hardware._run_sysctl") def test_invalid_value(self, mock_sysctl: object) -> None: mock_sysctl.return_value = "not-a-number" # type: ignore[attr-defined] - with pytest.raises(HardwareError, match="Unexpected hw.memsize"): + with pytest.raises(HardwareError, match=r"Unexpected hw\.memsize"): detect_memory_gb() @@ -229,7 +230,7 @@ def test_all_17_known_chips_in_table(self) -> None: assert len(CHIP_SPECS) == 17 @pytest.mark.parametrize( - "chip,expected_bw", + ("chip", "expected_bw"), [ ("Apple M1", 68.25), ("Apple M1 Pro", 200.0), @@ -292,9 +293,7 @@ class TestDetectHardware: @patch("mlx_stack.core.hardware._run_system_profiler") @patch("mlx_stack.core.hardware._run_sysctl") - def test_known_chip_full_detection( - self, mock_sysctl: object, mock_profiler: object - ) -> None: + def test_known_chip_full_detection(self, mock_sysctl: object, mock_profiler: object) -> None: def sysctl_side_effect(key: str) -> str: if key == "machdep.cpu.brand_string": return "Apple M4 Pro" @@ -316,9 +315,7 @@ def sysctl_side_effect(key: str) -> str: @patch("mlx_stack.core.hardware._run_system_profiler") @patch("mlx_stack.core.hardware._run_sysctl") - def test_unknown_chip_uses_estimate( - self, mock_sysctl: object, mock_profiler: object - ) -> None: + def test_unknown_chip_uses_estimate(self, mock_sysctl: object, mock_profiler: object) -> None: def sysctl_side_effect(key: str) -> str: if key == "machdep.cpu.brand_string": return "Apple M6" diff --git a/tests/unit/test_launchd.py b/tests/unit/test_launchd.py index 2cb0b83..0000b6a 100644 --- a/tests/unit/test_launchd.py +++ b/tests/unit/test_launchd.py @@ -46,7 +46,7 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: """Create a minimal stack definition (prerequisite for install).""" stacks_dir = mlx_stack_home / "stacks" @@ -75,7 +75,7 @@ def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: return stack -@pytest.fixture() +@pytest.fixture def plist_dir(tmp_path: Path) -> Path: """Provide a temporary LaunchAgents directory.""" d = tmp_path / "Library" / "LaunchAgents" @@ -83,7 +83,7 @@ def plist_dir(tmp_path: Path) -> Path: return d -@pytest.fixture() +@pytest.fixture def plist_path(plist_dir: Path) -> Path: """Provide a plist file path.""" return plist_dir / PLIST_FILENAME @@ -237,10 +237,8 @@ def test_fallback_to_sys_executable_dir(self) -> None: # Chain: Path(sys.executable).parent mock_path_instance = MagicMock() mock_path_instance.parent = mock_parent - mock_path_cls.side_effect = ( - lambda x: mock_path_instance - if x == "/opt/venv/bin/python" - else MagicMock() + mock_path_cls.side_effect = lambda x: ( + mock_path_instance if x == "/opt/venv/bin/python" else MagicMock() ) result = _resolve_mlx_stack_binary() @@ -296,16 +294,12 @@ def test_no_duplicate_binary_dir_in_standard_paths(self) -> None: count = components.count("/usr/local/bin") assert count == 1 - def test_mlx_stack_home_not_set_by_default( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_mlx_stack_home_not_set_by_default(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("MLX_STACK_HOME", raising=False) env = _build_environment_variables("/usr/bin/mlx-stack") assert "MLX_STACK_HOME" not in env - def test_mlx_stack_home_included_when_set( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_mlx_stack_home_included_when_set(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("MLX_STACK_HOME", "/custom/home") env = _build_environment_variables("/usr/bin/mlx-stack") assert env["MLX_STACK_HOME"] == "/custom/home" @@ -487,16 +481,12 @@ def test_success(self, plist_path: Path) -> None: def test_already_unloaded_is_nonfatal(self, plist_path: Path) -> None: """bootout returning 'No such process' should not raise.""" with patch("subprocess.run") as mock_run: - mock_run.return_value = MagicMock( - returncode=3, stderr="3: No such process" - ) + mock_run.return_value = MagicMock(returncode=3, stderr="3: No such process") unload_agent(plist_path) # Should not raise def test_other_failure_raises(self, plist_path: Path) -> None: with patch("subprocess.run") as mock_run: - mock_run.return_value = MagicMock( - returncode=1, stderr="Permission denied" - ) + mock_run.return_value = MagicMock(returncode=1, stderr="Permission denied") with pytest.raises(LaunchdError, match="launchctl bootout failed"): unload_agent(plist_path) @@ -612,9 +602,7 @@ def test_installed_but_not_running(self) -> None: class TestInstallAgent: """Tests for install_agent.""" - def test_fresh_install( - self, mlx_stack_home: Path, stack_definition: dict[str, Any] - ) -> None: + def test_fresh_install(self, mlx_stack_home: Path, stack_definition: dict[str, Any]) -> None: with ( patch("mlx_stack.core.launchd.check_platform"), patch("mlx_stack.core.launchd.get_plist_path") as mock_get_path, @@ -637,9 +625,7 @@ def test_fresh_install( assert plist_data["ProgramArguments"][0] == "/usr/bin/mlx-stack" assert "watch" in plist_data["ProgramArguments"] - def test_reinstall( - self, mlx_stack_home: Path, stack_definition: dict[str, Any] - ) -> None: + def test_reinstall(self, mlx_stack_home: Path, stack_definition: dict[str, Any]) -> None: with ( patch("mlx_stack.core.launchd.check_platform"), patch("mlx_stack.core.launchd.get_plist_path") as mock_get_path, @@ -652,13 +638,11 @@ def test_reinstall( mock_get_path.return_value = mock_plist_path mock_write.return_value = mock_plist_path - path, was_reinstall = install_agent("/usr/bin/mlx-stack") + _path, was_reinstall = install_agent("/usr/bin/mlx-stack") assert was_reinstall is True - def test_creates_logs_dir( - self, mlx_stack_home: Path, stack_definition: dict[str, Any] - ) -> None: + def test_creates_logs_dir(self, mlx_stack_home: Path, stack_definition: dict[str, Any]) -> None: with ( patch("mlx_stack.core.launchd.check_platform"), patch("mlx_stack.core.launchd.get_plist_path") as mock_get_path, @@ -705,7 +689,7 @@ def test_reinstall_best_effort_unload( mock_write.return_value = mock_plist_path # Should not raise despite unload failure - path, was_reinstall = install_agent("/usr/bin/mlx-stack") + _path, was_reinstall = install_agent("/usr/bin/mlx-stack") assert was_reinstall is True diff --git a/tests/unit/test_lifecycle_fixes.py b/tests/unit/test_lifecycle_fixes.py index 54f2e7e..05e223a 100644 --- a/tests/unit/test_lifecycle_fixes.py +++ b/tests/unit/test_lifecycle_fixes.py @@ -125,9 +125,7 @@ def _make_entry( architecture="transformer", min_mlx_lm_version="0.22.0", sources={ - "int4": QuantSource( - hf_repo=f"mlx-community/{model_id}-4bit", disk_size_gb=4.5 - ), + "int4": QuantSource(hf_repo=f"mlx-community/{model_id}-4bit", disk_size_gb=4.5), }, capabilities=Capabilities( tool_calling=True, @@ -136,13 +134,9 @@ def _make_entry( reasoning_parser=None, vision=False, ), - quality=QualityScores( - overall=70, coding=65, reasoning=60, instruction_following=72 - ), + quality=QualityScores(overall=70, coding=65, reasoning=60, instruction_following=72), benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=100.0, gen_tps=50.0, memory_gb=memory_gb - ), + "m4-max-128": BenchmarkResult(prompt_tps=100.0, gen_tps=50.0, memory_gb=memory_gb), }, tags=[], ) @@ -333,9 +327,7 @@ def test_partial_models_present( statuses = {t.name: t.status for t in result.tiers} # big-model (standard) should be skipped — not on disk assert statuses["standard"] == "skipped" - message = next( - t.error for t in result.tiers if t.name == "standard" - ) or "" + message = next(t.error for t in result.tiers if t.name == "standard") or "" assert "not found locally" in message # fast-model (fast) should be healthy — on disk assert statuses["fast"] == "healthy" diff --git a/tests/unit/test_litellm_gen.py b/tests/unit/test_litellm_gen.py index 20294ae..6d90e70 100644 --- a/tests/unit/test_litellm_gen.py +++ b/tests/unit/test_litellm_gen.py @@ -121,10 +121,7 @@ def test_cloud_entries_use_premium_name(self) -> None: tiers=_make_tiers(1), openrouter_key="sk-or-test123", ) - cloud_entries = [ - e for e in config["model_list"] - if e["model_name"] == "premium" - ] + cloud_entries = [e for e in config["model_list"] if e["model_name"] == "premium"] assert len(cloud_entries) == 2 def test_cloud_entries_use_openrouter_prefix(self) -> None: @@ -133,10 +130,7 @@ def test_cloud_entries_use_openrouter_prefix(self) -> None: tiers=_make_tiers(1), openrouter_key="sk-or-test123", ) - cloud_entries = [ - e for e in config["model_list"] - if e["model_name"] == "premium" - ] + cloud_entries = [e for e in config["model_list"] if e["model_name"] == "premium"] for entry in cloud_entries: assert entry["litellm_params"]["model"].startswith("openrouter/") @@ -146,10 +140,7 @@ def test_cloud_api_key_uses_env_reference(self) -> None: tiers=_make_tiers(1), openrouter_key="sk-or-test123", ) - cloud_entries = [ - e for e in config["model_list"] - if e["model_name"] == "premium" - ] + cloud_entries = [e for e in config["model_list"] if e["model_name"] == "premium"] for entry in cloud_entries: assert entry["litellm_params"]["api_key"] == "os.environ/OPENROUTER_API_KEY" diff --git a/tests/unit/test_log_rotation.py b/tests/unit/test_log_rotation.py index 312eb1c..cbeb1a7 100644 --- a/tests/unit/test_log_rotation.py +++ b/tests/unit/test_log_rotation.py @@ -12,7 +12,7 @@ import pytest -from mlx_stack.core.log_rotation import LogRotationError, rotate_log # noqa: I001 +from mlx_stack.core.log_rotation import LogRotationError, rotate_log # --------------------------------------------------------------------------- # # Helpers diff --git a/tests/unit/test_log_viewer.py b/tests/unit/test_log_viewer.py index 80f069c..df16189 100644 --- a/tests/unit/test_log_viewer.py +++ b/tests/unit/test_log_viewer.py @@ -6,9 +6,11 @@ from __future__ import annotations +import contextlib import gzip import threading import time +from datetime import UTC from pathlib import Path from mlx_stack.core.log_viewer import ( @@ -39,9 +41,7 @@ def _create_log(logs_dir: Path, service: str, content: str = "") -> Path: return log_path -def _create_archive( - logs_dir: Path, service: str, number: int, content: str -) -> Path: +def _create_archive(logs_dir: Path, service: str, number: int, content: str) -> Path: """Create a gzip archive for a service.""" logs_dir.mkdir(parents=True, exist_ok=True) archive_path = logs_dir / f"{service}.log.{number}.gz" @@ -60,57 +60,57 @@ class TestLogFileInfo: def test_size_display_bytes(self) -> None: """Sizes under 1KB show in bytes.""" - from datetime import datetime, timezone + from datetime import datetime info = LogFileInfo( name="test.log", service="test", size_bytes=500, - modified=datetime.now(tz=timezone.utc), + modified=datetime.now(tz=UTC), ) assert info.size_display == "500 B" def test_size_display_kilobytes(self) -> None: - """Sizes 1KB–1MB show in KB.""" - from datetime import datetime, timezone + """Sizes 1KB-1MB show in KB.""" + from datetime import datetime info = LogFileInfo( name="test.log", service="test", size_bytes=2048, - modified=datetime.now(tz=timezone.utc), + modified=datetime.now(tz=UTC), ) assert info.size_display == "2.0 KB" def test_size_display_megabytes(self) -> None: - """Sizes 1MB–1GB show in MB.""" - from datetime import datetime, timezone + """Sizes 1MB-1GB show in MB.""" + from datetime import datetime info = LogFileInfo( name="test.log", service="test", size_bytes=5 * 1024 * 1024, - modified=datetime.now(tz=timezone.utc), + modified=datetime.now(tz=UTC), ) assert info.size_display == "5.0 MB" def test_size_display_gigabytes(self) -> None: """Sizes >= 1GB show in GB.""" - from datetime import datetime, timezone + from datetime import datetime info = LogFileInfo( name="test.log", service="test", size_bytes=2 * 1024 * 1024 * 1024, - modified=datetime.now(tz=timezone.utc), + modified=datetime.now(tz=UTC), ) assert info.size_display == "2.0 GB" def test_modified_display(self) -> None: """Modified time is formatted as expected.""" - from datetime import datetime, timezone + from datetime import datetime - dt = datetime(2025, 3, 15, 14, 30, 45, tzinfo=timezone.utc) + dt = datetime(2025, 3, 15, 14, 30, 45, tzinfo=UTC) info = LogFileInfo( name="test.log", service="test", @@ -303,10 +303,8 @@ def writer(text: str) -> None: stop_event.set() def run_follow() -> None: - try: + with contextlib.suppress(KeyboardInterrupt): follow_log(log, num_lines=5, output_callback=writer) - except KeyboardInterrupt: - pass thread = threading.Thread(target=run_follow, daemon=True) thread.start() @@ -343,10 +341,8 @@ def writer(text: str) -> None: new_content_event.set() def run_follow() -> None: - try: + with contextlib.suppress(KeyboardInterrupt): follow_log(log, num_lines=1, output_callback=writer) - except KeyboardInterrupt: - pass thread = threading.Thread(target=run_follow, daemon=True) thread.start() @@ -385,10 +381,8 @@ def writer(text: str) -> None: post_truncation_event.set() def run_follow() -> None: - try: + with contextlib.suppress(KeyboardInterrupt): follow_log(log, num_lines=2, output_callback=writer) - except KeyboardInterrupt: - pass thread = threading.Thread(target=run_follow, daemon=True) thread.start() diff --git a/tests/unit/test_onboarding.py b/tests/unit/test_onboarding.py index 404eda1..b95ff5a 100644 --- a/tests/unit/test_onboarding.py +++ b/tests/unit/test_onboarding.py @@ -56,16 +56,49 @@ def _model( # A realistic set of models for testing scoring and selection -SMALL_FAST = _model("SmallFast-2B", "mlx-community/SmallFast-2B-4bit", params_b=2.0, - gen_tps=120.0, memory_gb=1.5, quality_overall=0.60, tool_calling=True) -MEDIUM_QUALITY = _model("MedQuality-9B", "mlx-community/MedQuality-9B-4bit", params_b=9.0, - gen_tps=50.0, memory_gb=5.5, quality_overall=0.95, tool_calling=True) -LARGE_SLOW = _model("LargeSlow-27B", "mlx-community/LargeSlow-27B-4bit", params_b=27.0, - gen_tps=20.0, memory_gb=16.0, quality_overall=0.98) -TINY_MODEL = _model("Tiny-0.5B", "mlx-community/Tiny-0.5B-4bit", params_b=0.5, - gen_tps=180.0, memory_gb=0.5, quality_overall=0.40) -NO_BENCHMARK = _model("Unknown-7B", "mlx-community/Unknown-7B-4bit", params_b=7.0, - gen_tps=None, memory_gb=4.2, quality_overall=None, has_benchmark=False) +SMALL_FAST = _model( + "SmallFast-2B", + "mlx-community/SmallFast-2B-4bit", + params_b=2.0, + gen_tps=120.0, + memory_gb=1.5, + quality_overall=0.60, + tool_calling=True, +) +MEDIUM_QUALITY = _model( + "MedQuality-9B", + "mlx-community/MedQuality-9B-4bit", + params_b=9.0, + gen_tps=50.0, + memory_gb=5.5, + quality_overall=0.95, + tool_calling=True, +) +LARGE_SLOW = _model( + "LargeSlow-27B", + "mlx-community/LargeSlow-27B-4bit", + params_b=27.0, + gen_tps=20.0, + memory_gb=16.0, + quality_overall=0.98, +) +TINY_MODEL = _model( + "Tiny-0.5B", + "mlx-community/Tiny-0.5B-4bit", + params_b=0.5, + gen_tps=180.0, + memory_gb=0.5, + quality_overall=0.40, +) +NO_BENCHMARK = _model( + "Unknown-7B", + "mlx-community/Unknown-7B-4bit", + params_b=7.0, + gen_tps=None, + memory_gb=4.2, + quality_overall=None, + has_benchmark=False, +) ALL_MODELS = [SMALL_FAST, MEDIUM_QUALITY, LARGE_SLOW, TINY_MODEL, NO_BENCHMARK] @@ -106,7 +139,9 @@ def test_results_sorted_by_composite_score_descending(self) -> None: def test_high_quality_model_ranks_higher_in_balanced(self) -> None: """In balanced intent, a high-quality model outranks a fast-but-low-quality one.""" scored = score_and_filter( - [SMALL_FAST, MEDIUM_QUALITY], "balanced", budget_gb=20.0, + [SMALL_FAST, MEDIUM_QUALITY], + "balanced", + budget_gb=20.0, ) # MEDIUM_QUALITY has 0.95 quality vs SMALL_FAST's 0.60 @@ -115,12 +150,22 @@ def test_high_quality_model_ranks_higher_in_balanced(self) -> None: def test_tool_calling_model_ranks_higher_in_agent_fleet(self) -> None: """In agent-fleet intent, tool calling capability boosts ranking.""" - no_tools = _model("NoTools-8B", "mlx-community/NoTools-8B-4bit", - gen_tps=50.0, memory_gb=5.0, quality_overall=0.90, - tool_calling=False) - with_tools = _model("WithTools-8B", "mlx-community/WithTools-8B-4bit", - gen_tps=50.0, memory_gb=5.0, quality_overall=0.90, - tool_calling=True) + no_tools = _model( + "NoTools-8B", + "mlx-community/NoTools-8B-4bit", + gen_tps=50.0, + memory_gb=5.0, + quality_overall=0.90, + tool_calling=False, + ) + with_tools = _model( + "WithTools-8B", + "mlx-community/WithTools-8B-4bit", + gen_tps=50.0, + memory_gb=5.0, + quality_overall=0.90, + tool_calling=True, + ) scored = score_and_filter([no_tools, with_tools], "agent-fleet", budget_gb=20.0) @@ -149,7 +194,9 @@ class TestSelectDefaults: def test_selects_up_to_two_models_within_budget(self) -> None: """Standard + fast are selected if both fit.""" scored = score_and_filter( - [SMALL_FAST, MEDIUM_QUALITY, TINY_MODEL], "balanced", budget_gb=10.0, + [SMALL_FAST, MEDIUM_QUALITY, TINY_MODEL], + "balanced", + budget_gb=10.0, ) result = select_defaults(scored, budget_gb=10.0) @@ -170,7 +217,9 @@ def test_total_memory_within_budget(self) -> None: def test_single_model_when_budget_tight(self) -> None: """Only one model selected if budget only fits one.""" scored = score_and_filter( - [MEDIUM_QUALITY, SMALL_FAST], "balanced", budget_gb=6.0, + [MEDIUM_QUALITY, SMALL_FAST], + "balanced", + budget_gb=6.0, ) result = select_defaults(scored, budget_gb=6.0) @@ -182,7 +231,9 @@ def test_single_model_when_budget_tight(self) -> None: def test_highest_quality_is_recommended(self) -> None: """The top composite_score model is always recommended.""" scored = score_and_filter( - [SMALL_FAST, MEDIUM_QUALITY], "balanced", budget_gb=20.0, + [SMALL_FAST, MEDIUM_QUALITY], + "balanced", + budget_gb=20.0, ) result = select_defaults(scored, budget_gb=20.0) @@ -193,7 +244,9 @@ def test_highest_quality_is_recommended(self) -> None: def test_fastest_model_also_recommended_when_different(self) -> None: """If highest quality != fastest, both are recommended.""" scored = score_and_filter( - [SMALL_FAST, MEDIUM_QUALITY], "balanced", budget_gb=20.0, + [SMALL_FAST, MEDIUM_QUALITY], + "balanced", + budget_gb=20.0, ) result = select_defaults(scored, budget_gb=20.0) @@ -237,10 +290,12 @@ def test_single_model_gets_standard_tier(self) -> None: def test_two_models_get_standard_and_fast(self) -> None: """Two models → 'standard' (highest composite) and 'fast' (highest speed).""" - tiers = assign_tiers([ - self._scored(MEDIUM_QUALITY, score=0.9), - self._scored(SMALL_FAST, score=0.7), - ]) + tiers = assign_tiers( + [ + self._scored(MEDIUM_QUALITY, score=0.9), + self._scored(SMALL_FAST, score=0.7), + ] + ) assert len(tiers) == 2 tier_names = {t.tier_name for t in tiers} @@ -251,11 +306,13 @@ def test_two_models_get_standard_and_fast(self) -> None: def test_three_models_third_gets_added_tier(self) -> None: """Third model gets 'added-N' tier name.""" - tiers = assign_tiers([ - self._scored(MEDIUM_QUALITY, score=0.9), - self._scored(SMALL_FAST, score=0.7), - self._scored(TINY_MODEL, score=0.5), - ]) + tiers = assign_tiers( + [ + self._scored(MEDIUM_QUALITY, score=0.9), + self._scored(SMALL_FAST, score=0.7), + self._scored(TINY_MODEL, score=0.5), + ] + ) assert len(tiers) == 3 tier_names = [t.tier_name for t in tiers] @@ -275,7 +332,7 @@ def test_empty_input_returns_empty(self) -> None: class TestGenerateConfig: """generate_config writes valid stack and LiteLLM YAML files.""" - @pytest.fixture() + @pytest.fixture def _mock_home(self, mlx_stack_home: Path, monkeypatch: pytest.MonkeyPatch) -> Path: """Use isolated mlx_stack_home and mock config reads.""" monkeypatch.setattr( @@ -302,8 +359,11 @@ def test_generates_valid_stack_yaml(self, _mock_home: Path) -> None: TierMapping(tier_name="fast", model=SMALL_FAST), ] - stack_path, litellm_path = generate_config( - self._profile(), "balanced", tiers, budget_gb=25.0, + stack_path, _litellm_path = generate_config( + self._profile(), + "balanced", + tiers, + budget_gb=25.0, ) assert stack_path.exists() @@ -325,7 +385,10 @@ def test_generates_valid_litellm_yaml(self, _mock_home: Path) -> None: ] _, litellm_path = generate_config( - self._profile(), "balanced", tiers, budget_gb=25.0, + self._profile(), + "balanced", + tiers, + budget_gb=25.0, ) assert litellm_path.exists() @@ -340,7 +403,10 @@ def test_tool_calling_model_gets_auto_tool_choice(self, _mock_home: Path) -> Non tiers = [TierMapping(tier_name="standard", model=SMALL_FAST)] # has tool_calling=True stack_path, _ = generate_config( - self._profile(), "balanced", tiers, budget_gb=25.0, + self._profile(), + "balanced", + tiers, + budget_gb=25.0, ) stack = yaml.safe_load(stack_path.read_text()) @@ -353,7 +419,10 @@ def test_non_tool_calling_model_no_auto_tool_choice(self, _mock_home: Path) -> N tiers = [TierMapping(tier_name="standard", model=no_tools)] stack_path, _ = generate_config( - self._profile(), "balanced", tiers, budget_gb=25.0, + self._profile(), + "balanced", + tiers, + budget_gb=25.0, ) stack = yaml.safe_load(stack_path.read_text()) @@ -366,7 +435,10 @@ def test_ports_skip_litellm_port(self, _mock_home: Path) -> None: tiers = [TierMapping(tier_name=f"tier-{i}", model=TINY_MODEL) for i in range(5)] stack_path, _ = generate_config( - self._profile(), "balanced", tiers, budget_gb=25.0, + self._profile(), + "balanced", + tiers, + budget_gb=25.0, ) stack = yaml.safe_load(stack_path.read_text()) diff --git a/tests/unit/test_ops_cross_area.py b/tests/unit/test_ops_cross_area.py index 4d53e2d..4a8c8a6 100644 --- a/tests/unit/test_ops_cross_area.py +++ b/tests/unit/test_ops_cross_area.py @@ -17,6 +17,7 @@ from __future__ import annotations +import contextlib import gzip import os import plistlib @@ -44,7 +45,7 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: """Create a test stack definition and return it.""" stacks_dir = mlx_stack_home / "stacks" @@ -88,7 +89,7 @@ def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: return stack -@pytest.fixture() +@pytest.fixture def pids_dir(mlx_stack_home: Path) -> Path: """Create and return the pids directory.""" d = mlx_stack_home / "pids" @@ -96,7 +97,7 @@ def pids_dir(mlx_stack_home: Path) -> Path: return d -@pytest.fixture() +@pytest.fixture def logs_dir(mlx_stack_home: Path) -> Path: """Create and return the logs directory.""" d = mlx_stack_home / "logs" @@ -104,7 +105,7 @@ def logs_dir(mlx_stack_home: Path) -> Path: return d -@pytest.fixture() +@pytest.fixture def config_file(mlx_stack_home: Path) -> Path: """Create a config file with rotation settings.""" config_path = mlx_stack_home / "config.yaml" @@ -347,14 +348,12 @@ def test_follow_detects_truncation_and_continues( captured: list[str] = [] def follow_thread() -> None: - try: + with contextlib.suppress(Exception): follow_log( log, num_lines=0, output_callback=lambda text: captured.append(text), ) - except Exception: - pass thread = threading.Thread(target=follow_thread, daemon=True) thread.start() @@ -388,14 +387,13 @@ def follow_thread() -> None: # Stop the follow thread by sending KeyboardInterrupt # (follow_log catches KeyboardInterrupt for clean exit) import ctypes + assert thread.ident is not None - try: + with contextlib.suppress(Exception): ctypes.pythonapi.PyThreadState_SetAsyncExc( ctypes.c_ulong(thread.ident), ctypes.py_object(KeyboardInterrupt), ) - except Exception: - pass thread.join(timeout=3) # The follow output should contain the pre-rotation and @@ -414,14 +412,12 @@ def test_follow_handles_empty_file_after_truncation( captured: list[str] = [] def follow_thread() -> None: - try: + with contextlib.suppress(Exception): follow_log( log, num_lines=0, output_callback=lambda text: captured.append(text), ) - except Exception: - pass thread = threading.Thread(target=follow_thread, daemon=True) thread.start() @@ -439,14 +435,13 @@ def follow_thread() -> None: time.sleep(1.0) import ctypes + assert thread.ident is not None - try: + with contextlib.suppress(Exception): ctypes.pythonapi.PyThreadState_SetAsyncExc( ctypes.c_ulong(thread.ident), ctypes.py_object(KeyboardInterrupt), ) - except Exception: - pass thread.join(timeout=3) combined = "\n".join(captured) @@ -468,14 +463,12 @@ def test_follow_continues_after_multiple_rotations( captured: list[str] = [] def follow_thread() -> None: - try: + with contextlib.suppress(Exception): follow_log( log, num_lines=0, output_callback=lambda text: captured.append(text), ) - except Exception: - pass def wait_for_content(marker: str, timeout: float = 5.0) -> bool: """Wait until marker appears in captured output.""" @@ -516,14 +509,13 @@ def wait_for_content(marker: str, timeout: float = 5.0) -> bool: assert wait_for_content("round-3-content") import ctypes + assert thread.ident is not None - try: + with contextlib.suppress(Exception): ctypes.pythonapi.PyThreadState_SetAsyncExc( ctypes.c_ulong(thread.ident), ctypes.py_object(KeyboardInterrupt), ) - except Exception: - pass thread.join(timeout=3) combined = "\n".join(captured) @@ -1009,14 +1001,24 @@ def test_down_then_watchdog_then_up_then_watchdog( mock_result.no_stack = False mock_result.services = [ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.STOPPED, uptime=None, uptime_display="-", - response_time=None, pid=None, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.STOPPED, + uptime=None, + uptime_display="-", + response_time=None, + pid=None, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.STOPPED, uptime=None, uptime_display="-", - response_time=None, pid=None, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.STOPPED, + uptime=None, + uptime_display="-", + response_time=None, + pid=None, ), ] mock_status.return_value = mock_result @@ -1027,8 +1029,11 @@ def test_down_then_watchdog_then_up_then_watchdog( }.get(key, "") result = poll_cycle( - state=state, stack=stack_definition, - interval=30, max_restarts=5, restart_delay=10, + state=state, + stack=stack_definition, + interval=30, + max_restarts=5, + restart_delay=10, ) assert result.restarts_attempted == 0 @@ -1043,14 +1048,24 @@ def test_down_then_watchdog_then_up_then_watchdog( mock_result.no_stack = False mock_result.services = [ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.HEALTHY, uptime=10.0, uptime_display="10s", - response_time=0.05, pid=20001, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.HEALTHY, + uptime=10.0, + uptime_display="10s", + response_time=0.05, + pid=20001, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.HEALTHY, uptime=8.0, uptime_display="8s", - response_time=0.08, pid=20002, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.HEALTHY, + uptime=8.0, + uptime_display="8s", + response_time=0.08, + pid=20002, ), ] mock_status.return_value = mock_result @@ -1061,8 +1076,11 @@ def test_down_then_watchdog_then_up_then_watchdog( }.get(key, "") result = poll_cycle( - state=state, stack=stack_definition, - interval=30, max_restarts=5, restart_delay=10, + state=state, + stack=stack_definition, + interval=30, + max_restarts=5, + restart_delay=10, ) assert result.restarts_attempted == 0 @@ -1091,9 +1109,14 @@ def test_watchdog_lock_only_during_restart( mock_result.no_stack = False mock_result.services = [ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.HEALTHY, uptime=60.0, uptime_display="1m", - response_time=0.05, pid=10001, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.HEALTHY, + uptime=60.0, + uptime_display="1m", + response_time=0.05, + pid=10001, ), ] mock_status.return_value = mock_result @@ -1104,8 +1127,11 @@ def test_watchdog_lock_only_during_restart( }.get(key, "") poll_cycle( - state=state, stack=stack_definition, - interval=30, max_restarts=5, restart_delay=10, + state=state, + stack=stack_definition, + interval=30, + max_restarts=5, + restart_delay=10, ) # acquire_lock should NOT have been called during a healthy poll @@ -1189,14 +1215,24 @@ def test_full_lifecycle_sequence( mock_result.no_stack = False mock_result.services = [ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.HEALTHY, uptime=300.0, uptime_display="5m", - response_time=0.05, pid=30001, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.HEALTHY, + uptime=300.0, + uptime_display="5m", + response_time=0.05, + pid=30001, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.HEALTHY, uptime=300.0, uptime_display="5m", - response_time=0.08, pid=30002, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.HEALTHY, + uptime=300.0, + uptime_display="5m", + response_time=0.08, + pid=30002, ), ] mock_status.return_value = mock_result @@ -1225,14 +1261,12 @@ def test_full_lifecycle_sequence( follow_captured: list[str] = [] def follow_thread_fn() -> None: - try: + with contextlib.suppress(Exception): follow_log( fast_log, num_lines=0, output_callback=lambda text: follow_captured.append(text), ) - except Exception: - pass follow_thread = threading.Thread(target=follow_thread_fn, daemon=True) follow_thread.start() @@ -1253,14 +1287,13 @@ def follow_thread_fn() -> None: # Stop the follow thread import ctypes + assert follow_thread.ident is not None - try: + with contextlib.suppress(Exception): ctypes.pythonapi.PyThreadState_SetAsyncExc( ctypes.c_ulong(follow_thread.ident), ctypes.py_object(KeyboardInterrupt), ) - except Exception: - pass follow_thread.join(timeout=3) # Verify follow detected the new line @@ -1283,14 +1316,24 @@ def follow_thread_fn() -> None: mock_result.no_stack = False mock_result.services = [ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.CRASHED, uptime=None, uptime_display="-", - response_time=None, pid=30001, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.CRASHED, + uptime=None, + uptime_display="-", + response_time=None, + pid=30001, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.HEALTHY, uptime=600.0, uptime_display="10m", - response_time=0.08, pid=30002, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.HEALTHY, + uptime=600.0, + uptime_display="10m", + response_time=0.08, + pid=30002, ), ] mock_status.return_value = mock_result @@ -1333,14 +1376,24 @@ def follow_thread_fn() -> None: mock_result.no_stack = False mock_result.services = [ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.STOPPED, uptime=None, uptime_display="-", - response_time=None, pid=None, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.STOPPED, + uptime=None, + uptime_display="-", + response_time=None, + pid=None, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.STOPPED, uptime=None, uptime_display="-", - response_time=None, pid=None, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.STOPPED, + uptime=None, + uptime_display="-", + response_time=None, + pid=None, ), ] mock_status.return_value = mock_result diff --git a/tests/unit/test_paths.py b/tests/unit/test_paths.py index 5e1ac74..0ed88d2 100644 --- a/tests/unit/test_paths.py +++ b/tests/unit/test_paths.py @@ -27,9 +27,7 @@ def test_uses_env_var(self, mlx_stack_home: Path) -> None: result = get_data_home() assert result == mlx_stack_home - def test_default_is_home_dot_mlx_stack( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_default_is_home_dot_mlx_stack(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("MLX_STACK_HOME", raising=False) result = get_data_home() assert result == Path.home() / ".mlx-stack" diff --git a/tests/unit/test_process.py b/tests/unit/test_process.py index 4f4129e..046687c 100644 --- a/tests/unit/test_process.py +++ b/tests/unit/test_process.py @@ -251,10 +251,9 @@ def test_acquires_and_releases(self, mlx_stack_home: Path) -> None: pass def test_concurrent_lock_raises(self, mlx_stack_home: Path) -> None: - with acquire_lock(): - with pytest.raises(LockError, match="Another mlx-stack operation"): - with acquire_lock(): - pass # pragma: no cover + with acquire_lock(), pytest.raises(LockError, match="Another mlx-stack operation"): + with acquire_lock(): + pass # pragma: no cover def test_creates_data_home(self, clean_mlx_stack_home: Path) -> None: with acquire_lock(): @@ -332,9 +331,7 @@ def test_immediately_healthy(self, mock_check: MagicMock) -> None: @patch("mlx_stack.core.process.time.sleep") @patch("mlx_stack.core.process.http_health_check") - def test_healthy_after_retries( - self, mock_check: MagicMock, mock_sleep: MagicMock - ) -> None: + def test_healthy_after_retries(self, mock_check: MagicMock, mock_sleep: MagicMock) -> None: # First 2 calls fail, third succeeds mock_check.side_effect = [ HealthCheckResult(healthy=False, response_time=None, status_code=None), @@ -363,14 +360,14 @@ def test_timeout_raises( ) # Simulate time passing past the deadline mock_monotonic.side_effect = [ - 0.0, # deadline = 0 + 2 = 2 - 0.5, # first iteration check - 0.5, # per_request_timeout check - 1.0, # after first check - 1.5, # sleep calculation - 1.8, # second iteration check - 1.8, # per_request_timeout check - 2.5, # after second check - past deadline + 0.0, # deadline = 0 + 2 = 2 + 0.5, # first iteration check + 0.5, # per_request_timeout check + 1.0, # after first check + 1.5, # sleep calculation + 1.8, # second iteration check + 1.8, # per_request_timeout check + 2.5, # after second check - past deadline ] with pytest.raises(HealthCheckError, match="timed out"): @@ -455,7 +452,9 @@ class TestCheckPortConflict: @patch("mlx_stack.core.process._find_pid_on_port", return_value=(42, "python")) @patch("mlx_stack.core.process._socket_bind_check", return_value=True) def test_port_in_use_with_owner( - self, mock_bind: MagicMock, mock_find: MagicMock, + self, + mock_bind: MagicMock, + mock_find: MagicMock, ) -> None: """Port occupied, owner identified via psutil.""" result = check_port_conflict(8000) @@ -471,7 +470,9 @@ def test_port_available(self, mock_bind: MagicMock) -> None: @patch("mlx_stack.core.process._find_pid_on_port", return_value=None) @patch("mlx_stack.core.process._socket_bind_check", return_value=True) def test_port_in_use_unknown_owner( - self, mock_bind: MagicMock, mock_find: MagicMock, + self, + mock_bind: MagicMock, + mock_find: MagicMock, ) -> None: """Port occupied but owner can't be identified (e.g., macOS permission).""" result = check_port_conflict(8000) @@ -481,7 +482,9 @@ def test_port_in_use_unknown_owner( @patch("mlx_stack.core.process._find_pid_on_port", return_value=(42, "")) @patch("mlx_stack.core.process._socket_bind_check", return_value=True) def test_process_vanished( - self, mock_bind: MagicMock, mock_find: MagicMock, + self, + mock_bind: MagicMock, + mock_find: MagicMock, ) -> None: """Process vanished between detection and lookup.""" result = check_port_conflict(8000) @@ -704,8 +707,8 @@ def test_escalate_to_sigkill( # Process stays alive through grace period, then dies after SIGKILL mock_alive.side_effect = [True, True, True, False] mock_monotonic.side_effect = [ - 0.0, # deadline = 10.0 - 5.0, # still within grace + 0.0, # deadline = 10.0 + 5.0, # still within grace 11.0, # past grace → SIGKILL ] diff --git a/tests/unit/test_robustness_fixes.py b/tests/unit/test_robustness_fixes.py index aafb760..be6ddcb 100644 --- a/tests/unit/test_robustness_fixes.py +++ b/tests/unit/test_robustness_fixes.py @@ -310,8 +310,8 @@ def test_sigkill_confirmed_dead_returns_confirmed( # Process stays alive through grace, dies after SIGKILL mock_alive.side_effect = [True, True, True, False] mock_monotonic.side_effect = [ - 0.0, # deadline = 10.0 - 5.0, # still within grace + 0.0, # deadline = 10.0 + 5.0, # still within grace 11.0, # past grace → SIGKILL ] @@ -334,7 +334,7 @@ def test_sigkill_process_still_alive_returns_not_confirmed( # Process never dies (survives SIGKILL — e.g. zombie or kernel hold) mock_alive.return_value = True mock_monotonic.side_effect = [ - 0.0, # deadline = 10.0 + 0.0, # deadline = 10.0 11.0, # past grace → SIGKILL ] diff --git a/tests/unit/test_scoring.py b/tests/unit/test_scoring.py index 2c20439..55cc359 100644 --- a/tests/unit/test_scoring.py +++ b/tests/unit/test_scoring.py @@ -125,13 +125,13 @@ def _make_profile( ) -@pytest.fixture() +@pytest.fixture def m4_max_128_profile() -> HardwareProfile: """M4 Max 128 GB profile — matches catalog benchmark key 'm4-max-128'.""" return _make_profile() -@pytest.fixture() +@pytest.fixture def m4_pro_48_profile() -> HardwareProfile: """M4 Pro 48 GB profile — matches catalog benchmark key 'm4-pro-48'.""" return _make_profile( @@ -142,7 +142,7 @@ def m4_pro_48_profile() -> HardwareProfile: ) -@pytest.fixture() +@pytest.fixture def unknown_profile() -> HardwareProfile: """Unknown hardware profile — no catalog benchmark match.""" return _make_profile( @@ -154,7 +154,7 @@ def unknown_profile() -> HardwareProfile: ) -@pytest.fixture() +@pytest.fixture def small_memory_profile() -> HardwareProfile: """Small memory profile (32 GB) for tier count tests.""" return _make_profile( @@ -165,13 +165,13 @@ def small_memory_profile() -> HardwareProfile: ) -@pytest.fixture() +@pytest.fixture def basic_entry() -> CatalogEntry: """A basic catalog entry with standard benchmarks.""" return _make_entry() -@pytest.fixture() +@pytest.fixture def sample_catalog() -> list[CatalogEntry]: """A representative catalog for testing scoring and tier assignment.""" return [ @@ -433,14 +433,14 @@ def test_different_intents_have_different_weights(self) -> None: assert balanced.tool_calling != fleet.tool_calling def test_valid_intents(self) -> None: - assert VALID_INTENTS == {"balanced", "agent-fleet"} + assert {"balanced", "agent-fleet"} == VALID_INTENTS def test_all_valid_intents_have_weights(self) -> None: for intent in VALID_INTENTS: assert intent in INTENT_WEIGHTS def test_weights_must_sum_to_one(self) -> None: - with pytest.raises(ValueError, match="must sum to 1.0"): + with pytest.raises(ValueError, match=r"must sum to 1\.0"): IntentWeights(speed=0.5, quality=0.5, tool_calling=0.5, memory_efficiency=0.5) @@ -455,9 +455,7 @@ class TestBenchmarkResolution: def test_direct_match( self, basic_entry: CatalogEntry, m4_max_128_profile: HardwareProfile ) -> None: - gen_tps, memory_gb, is_estimated = _resolve_benchmark( - basic_entry, m4_max_128_profile - ) + gen_tps, memory_gb, is_estimated = _resolve_benchmark(basic_entry, m4_max_128_profile) assert gen_tps == 77.0 assert memory_gb == 5.5 assert is_estimated is False @@ -465,9 +463,7 @@ def test_direct_match( def test_direct_match_different_profile( self, basic_entry: CatalogEntry, m4_pro_48_profile: HardwareProfile ) -> None: - gen_tps, memory_gb, is_estimated = _resolve_benchmark( - basic_entry, m4_pro_48_profile - ) + gen_tps, memory_gb, is_estimated = _resolve_benchmark(basic_entry, m4_pro_48_profile) assert gen_tps == 52.0 assert memory_gb == 5.5 assert is_estimated is False @@ -488,9 +484,7 @@ def test_saved_benchmarks_override( def test_bandwidth_ratio_estimation( self, basic_entry: CatalogEntry, unknown_profile: HardwareProfile ) -> None: - gen_tps, memory_gb, is_estimated = _resolve_benchmark( - basic_entry, unknown_profile - ) + gen_tps, memory_gb, is_estimated = _resolve_benchmark(basic_entry, unknown_profile) # Unknown M6 with 1000 GB/s, reference m4-pro-48 has 273 GB/s, gen_tps=52.0 # Expected: 52.0 * (1000 / 273) ≈ 190.5 assert is_estimated is True @@ -612,9 +606,7 @@ def test_high_bandwidth_hardware_speed_score_clamped(self) -> None: entry = _make_entry( model_id="fast-model", benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=480.0, gen_tps=185.0, memory_gb=0.8 - ), + "m4-max-128": BenchmarkResult(prompt_tps=480.0, gen_tps=185.0, memory_gb=0.8), }, ) # Very high bandwidth hardware — 4x the reference m4-max-128 (546 GB/s) @@ -669,9 +661,7 @@ def test_filters_by_memory_budget( ) -> None: """Models exceeding budget should be excluded.""" budget_gb = 25.0 # Should exclude premium-72b (42GB) and huge-100b (60GB) - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", budget_gb - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", budget_gb) model_ids = {m.entry.id for m in scored} assert "premium-72b" not in model_ids assert "huge-100b" not in model_ids @@ -685,9 +675,7 @@ def test_sorted_by_composite_score_descending( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) for i in range(len(scored) - 1): assert scored[i].composite_score >= scored[i + 1].composite_score @@ -708,9 +696,7 @@ def test_all_models_exceed_budget( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 0.1 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 0.1) assert scored == [] def test_different_intents_produce_different_scores( @@ -718,12 +704,8 @@ def test_different_intents_produce_different_scores( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - balanced = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) - fleet = score_and_filter( - sample_catalog, m4_max_128_profile, "agent-fleet", 51.2 - ) + balanced = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) + fleet = score_and_filter(sample_catalog, m4_max_128_profile, "agent-fleet", 51.2) # Get scores for the same model under different intents balanced_scores = {m.entry.id: m.composite_score for m in balanced} @@ -731,10 +713,7 @@ def test_different_intents_produce_different_scores( # At least one model should have a different score common_ids = set(balanced_scores) & set(fleet_scores) - assert any( - not math.isclose(balanced_scores[mid], fleet_scores[mid]) - for mid in common_ids - ) + assert any(not math.isclose(balanced_scores[mid], fleet_scores[mid]) for mid in common_ids) def test_deterministic_scoring( self, @@ -742,14 +721,10 @@ def test_deterministic_scoring( m4_max_128_profile: HardwareProfile, ) -> None: """Same inputs must always produce same outputs.""" - result1 = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) - result2 = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + result1 = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) + result2 = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) assert len(result1) == len(result2) - for m1, m2 in zip(result1, result2): + for m1, m2 in zip(result1, result2, strict=False): assert m1.entry.id == m2.entry.id assert m1.composite_score == m2.composite_score @@ -760,7 +735,10 @@ def test_saved_benchmarks_used( entry = _make_entry() saved = {"test-model": {"gen_tps": 200.0, "memory_gb": 5.5}} scored = score_and_filter( - [entry], m4_max_128_profile, "balanced", 51.2, + [entry], + m4_max_128_profile, + "balanced", + 51.2, saved_benchmarks=saved, ) assert len(scored) == 1 @@ -786,9 +764,7 @@ def test_standard_is_highest_composite_score( The composite score is intent-weighted, so standard tier reflects the intent-specific best model rather than the raw quality leader. """ - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) standard = next(t for t in tiers if t.tier == TIER_STANDARD) # Standard tier should be the model with the highest composite score @@ -800,9 +776,7 @@ def test_fast_is_highest_gen_tps( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) fast = next(t for t in tiers if t.tier == TIER_FAST) # fast-0.8b should be the fast tier (185 gen_tps on m4-max-128) @@ -813,9 +787,7 @@ def test_longctx_is_architecturally_diverse( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) longctx_tiers = [t for t in tiers if t.tier == TIER_LONGCTX] assert len(longctx_tiers) == 1 @@ -826,16 +798,12 @@ def test_tiers_use_distinct_models( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) model_ids = [t.model.entry.id for t in tiers] assert len(model_ids) == len(set(model_ids)) - def test_small_memory_fewer_tiers( - self, small_memory_profile: HardwareProfile - ) -> None: + def test_small_memory_fewer_tiers(self, small_memory_profile: HardwareProfile) -> None: """Small memory systems (budget < 16 GB) should get 1-2 tiers.""" # 32 GB * 40% = 12.8 GB budget — below 16 GB threshold catalog = [ @@ -867,9 +835,7 @@ def test_small_memory_fewer_tiers( ), ] budget_gb = compute_memory_budget(32) # 12.8 GB - scored = score_and_filter( - catalog, small_memory_profile, "balanced", budget_gb - ) + scored = score_and_filter(catalog, small_memory_profile, "balanced", budget_gb) tiers = assign_tiers(scored, budget_gb) # Should have at most 2 tiers (no longctx for budget < 16) tier_names = [t.tier for t in tiers] @@ -883,9 +849,7 @@ def test_large_memory_up_to_3_tiers( ) -> None: """Large memory systems should get up to 3 tiers.""" budget_gb = compute_memory_budget(128) # 51.2 GB - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", budget_gb - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", budget_gb) tiers = assign_tiers(scored, budget_gb) assert len(tiers) == 3 tier_names = [t.tier for t in tiers] @@ -917,15 +881,11 @@ def test_no_longctx_architecture_available( model_id="fast-model", quality_overall=42, benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=480.0, gen_tps=185.0, memory_gb=0.8 - ), + "m4-max-128": BenchmarkResult(prompt_tps=480.0, gen_tps=185.0, memory_gb=0.8), }, ), ] - scored = score_and_filter( - catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) tier_names = [t.tier for t in tiers] assert TIER_LONGCTX not in tier_names @@ -937,9 +897,7 @@ def test_tier_order( m4_max_128_profile: HardwareProfile, ) -> None: """Tiers should be ordered: standard, fast, longctx.""" - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) tier_names = [t.tier for t in tiers] assert tier_names[0] == TIER_STANDARD @@ -953,9 +911,7 @@ def test_quant_is_int4( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - scored = score_and_filter( - sample_catalog, m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter(sample_catalog, m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) for tier in tiers: assert tier.quant == "int4" @@ -995,9 +951,7 @@ def test_budget_gb_override( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - result = recommend( - sample_catalog, m4_max_128_profile, budget_gb_override=30.0 - ) + result = recommend(sample_catalog, m4_max_128_profile, budget_gb_override=30.0) assert result.memory_budget_gb == 30.0 for scored in result.all_scored: assert scored.memory_gb <= 30.0 @@ -1007,9 +961,7 @@ def test_budget_pct( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - result = recommend( - sample_catalog, m4_max_128_profile, budget_pct=60 - ) + result = recommend(sample_catalog, m4_max_128_profile, budget_pct=60) assert result.memory_budget_gb == pytest.approx(76.8) def test_agent_fleet_intent( @@ -1017,9 +969,7 @@ def test_agent_fleet_intent( sample_catalog: list[CatalogEntry], m4_max_128_profile: HardwareProfile, ) -> None: - result = recommend( - sample_catalog, m4_max_128_profile, intent="agent-fleet" - ) + result = recommend(sample_catalog, m4_max_128_profile, intent="agent-fleet") assert result.intent == "agent-fleet" def test_different_intents_produce_different_results( @@ -1036,10 +986,7 @@ def test_different_intents_produce_different_results( # At least one model should have different scores common_ids = set(balanced_scores) & set(fleet_scores) - assert any( - not math.isclose(balanced_scores[mid], fleet_scores[mid]) - for mid in common_ids - ) + assert any(not math.isclose(balanced_scores[mid], fleet_scores[mid]) for mid in common_ids) def test_invalid_intent_raises( self, @@ -1073,9 +1020,7 @@ def test_saved_benchmarks_override_catalog( ) -> None: entry = _make_entry() saved = {"test-model": {"gen_tps": 200.0, "memory_gb": 5.5}} - result = recommend( - [entry], m4_max_128_profile, saved_benchmarks=saved - ) + result = recommend([entry], m4_max_128_profile, saved_benchmarks=saved) assert len(result.all_scored) == 1 assert result.all_scored[0].gen_tps == 200.0 assert result.all_scored[0].is_estimated is False @@ -1090,37 +1035,31 @@ def test_deterministic( r2 = recommend(sample_catalog, m4_max_128_profile) assert len(r1.tiers) == len(r2.tiers) - for t1, t2 in zip(r1.tiers, r2.tiers): + for t1, t2 in zip(r1.tiers, r2.tiers, strict=False): assert t1.tier == t2.tier assert t1.model.entry.id == t2.model.entry.id assert t1.model.composite_score == t2.model.composite_score assert len(r1.all_scored) == len(r2.all_scored) - for m1, m2 in zip(r1.all_scored, r2.all_scored): + for m1, m2 in zip(r1.all_scored, r2.all_scored, strict=False): assert m1.entry.id == m2.entry.id assert m1.composite_score == m2.composite_score - def test_small_budget_gives_fewer_tiers( - self, small_memory_profile: HardwareProfile - ) -> None: + def test_small_budget_gives_fewer_tiers(self, small_memory_profile: HardwareProfile) -> None: """On small memory, recommendation produces fewer tiers.""" catalog = [ _make_entry( model_id="small-model", quality_overall=65, benchmarks={ - "m4-pro-48": BenchmarkResult( - prompt_tps=180.0, gen_tps=88.0, memory_gb=2.5 - ), + "m4-pro-48": BenchmarkResult(prompt_tps=180.0, gen_tps=88.0, memory_gb=2.5), }, ), _make_entry( model_id="mid-model", quality_overall=70, benchmarks={ - "m4-pro-48": BenchmarkResult( - prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5 - ), + "m4-pro-48": BenchmarkResult(prompt_tps=95.0, gen_tps=52.0, memory_gb=5.5), }, ), ] @@ -1141,9 +1080,7 @@ def test_zero_models_fit( m4_max_128_profile: HardwareProfile, ) -> None: """When no models fit the budget, result has empty tiers.""" - result = recommend( - sample_catalog, m4_max_128_profile, budget_gb_override=0.1 - ) + result = recommend(sample_catalog, m4_max_128_profile, budget_gb_override=0.1) assert result.tiers == [] assert result.all_scored == [] @@ -1189,9 +1126,7 @@ def test_real_catalog_balanced_vs_agent_fleet(self) -> None: common = set(balanced_scores) & set(fleet_scores) differences = [ - mid - for mid in common - if not math.isclose(balanced_scores[mid], fleet_scores[mid]) + mid for mid in common if not math.isclose(balanced_scores[mid], fleet_scores[mid]) ] assert len(differences) > 0, "Balanced and agent-fleet should differ" @@ -1246,7 +1181,7 @@ def test_real_catalog_deterministic(self) -> None: r2 = recommend(catalog, profile) assert len(r1.tiers) == len(r2.tiers) - for t1, t2 in zip(r1.tiers, r2.tiers): + for t1, t2 in zip(r1.tiers, r2.tiers, strict=False): assert t1.tier == t2.tier assert t1.model.entry.id == t2.model.entry.id @@ -1273,8 +1208,7 @@ def test_real_catalog_intents_produce_different_tier_assignments(self) -> None: # The two intents must produce at least one different tier assignment common_tiers = set(balanced_tier_models) & set(fleet_tier_models) has_difference = any( - balanced_tier_models[tier] != fleet_tier_models[tier] - for tier in common_tiers + balanced_tier_models[tier] != fleet_tier_models[tier] for tier in common_tiers ) assert has_difference, ( f"balanced and agent-fleet should differ in at least one tier, " @@ -1323,9 +1257,7 @@ def test_different_intents_different_standard_tier( tool_calling=False, tool_call_parser=None, benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=40.0, gen_tps=25.0, memory_gb=20.0 - ), + "m4-max-128": BenchmarkResult(prompt_tps=40.0, gen_tps=25.0, memory_gb=20.0), }, ) model_b = _make_entry( @@ -1337,20 +1269,14 @@ def test_different_intents_different_standard_tier( quality_instruction=52, tool_calling=True, benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=42.0, gen_tps=26.0, memory_gb=20.0 - ), + "m4-max-128": BenchmarkResult(prompt_tps=42.0, gen_tps=26.0, memory_gb=20.0), }, ) catalog = [model_a, model_b] budget_gb = 51.2 - balanced_scored = score_and_filter( - catalog, m4_max_128_profile, "balanced", budget_gb - ) - fleet_scored = score_and_filter( - catalog, m4_max_128_profile, "agent-fleet", budget_gb - ) + balanced_scored = score_and_filter(catalog, m4_max_128_profile, "balanced", budget_gb) + fleet_scored = score_and_filter(catalog, m4_max_128_profile, "agent-fleet", budget_gb) balanced_tiers = assign_tiers(balanced_scored, budget_gb) fleet_tiers = assign_tiers(fleet_scored, budget_gb) @@ -1378,9 +1304,7 @@ def test_standard_tier_uses_composite_not_raw_quality( model_id="slow-quality", quality_overall=95, benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=5.0, gen_tps=3.0, memory_gb=48.0 - ), + "m4-max-128": BenchmarkResult(prompt_tps=5.0, gen_tps=3.0, memory_gb=48.0), }, ) # Model with moderate quality but good speed and efficiency @@ -1388,15 +1312,11 @@ def test_standard_tier_uses_composite_not_raw_quality( model_id="balanced-model", quality_overall=70, benchmarks={ - "m4-max-128": BenchmarkResult( - prompt_tps=120.0, gen_tps=70.0, memory_gb=5.0 - ), + "m4-max-128": BenchmarkResult(prompt_tps=120.0, gen_tps=70.0, memory_gb=5.0), }, ) - scored = score_and_filter( - [high_q, balanced_model], m4_max_128_profile, "balanced", 51.2 - ) + scored = score_and_filter([high_q, balanced_model], m4_max_128_profile, "balanced", 51.2) tiers = assign_tiers(scored, 51.2) standard = next(t for t in tiers if t.tier == TIER_STANDARD) @@ -1419,7 +1339,10 @@ def test_exclude_gated_filters_gated_models(self, m4_max_128_profile: HardwarePr gated_model = _make_entry(model_id="gated-model", name="Gated Model", gated=True) scored = score_and_filter( - [open_model, gated_model], m4_max_128_profile, "balanced", 51.2, + [open_model, gated_model], + m4_max_128_profile, + "balanced", + 51.2, exclude_gated=True, ) scored_ids = {m.entry.id for m in scored} @@ -1432,7 +1355,10 @@ def test_exclude_gated_false_includes_all(self, m4_max_128_profile: HardwareProf gated_model = _make_entry(model_id="gated-model", name="Gated Model", gated=True) scored = score_and_filter( - [open_model, gated_model], m4_max_128_profile, "balanced", 51.2, + [open_model, gated_model], + m4_max_128_profile, + "balanced", + 51.2, exclude_gated=False, ) scored_ids = {m.entry.id for m in scored} @@ -1443,12 +1369,15 @@ def test_recommend_exclude_gated(self, m4_max_128_profile: HardwareProfile) -> N """Gated models excluded from tier assignments via recommend().""" open_model = _make_entry(model_id="open-model", name="Open Model") gated_model = _make_entry( - model_id="gated-model", name="Gated Model", - quality_overall=99, gated=True, + model_id="gated-model", + name="Gated Model", + quality_overall=99, + gated=True, ) result = recommend( - [open_model, gated_model], m4_max_128_profile, + [open_model, gated_model], + m4_max_128_profile, exclude_gated=True, ) tier_ids = {t.model.entry.id for t in result.tiers} diff --git a/tests/unit/test_watchdog.py b/tests/unit/test_watchdog.py index f022691..7b551c6 100644 --- a/tests/unit/test_watchdog.py +++ b/tests/unit/test_watchdog.py @@ -43,7 +43,7 @@ # --------------------------------------------------------------------------- # -@pytest.fixture() +@pytest.fixture def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: """Create a test stack definition and return it.""" stacks_dir = mlx_stack_home / "stacks" @@ -87,7 +87,7 @@ def stack_definition(mlx_stack_home: Path) -> dict[str, Any]: return stack -@pytest.fixture() +@pytest.fixture def pids_dir(mlx_stack_home: Path) -> Path: """Create and return the pids directory.""" d = mlx_stack_home / "pids" @@ -95,7 +95,7 @@ def pids_dir(mlx_stack_home: Path) -> Path: return d -@pytest.fixture() +@pytest.fixture def logs_dir(mlx_stack_home: Path) -> Path: """Create and return the logs directory.""" d = mlx_stack_home / "logs" @@ -382,9 +382,7 @@ def test_no_logs_directory(self, mlx_stack_home: Path) -> None: count = rotate_service_logs() assert count == 0 - def test_rotates_eligible_files( - self, mlx_stack_home: Path, logs_dir: Path - ) -> None: + def test_rotates_eligible_files(self, mlx_stack_home: Path, logs_dir: Path) -> None: # Create a log file exceeding threshold log_file = logs_dir / "fast.log" log_file.write_bytes(b"x" * (1 * 1024 * 1024)) # 1 MB @@ -402,9 +400,7 @@ def test_rotates_eligible_files( # Since max_size_mb=0 means threshold=0 bytes, any non-empty file rotates assert count >= 0 # The actual rotation depends on implementation - def test_skips_non_log_files( - self, mlx_stack_home: Path, logs_dir: Path - ) -> None: + def test_skips_non_log_files(self, mlx_stack_home: Path, logs_dir: Path) -> None: # Create a non-.log file other_file = logs_dir / "fast.log.1.gz" other_file.write_bytes(b"compressed data") @@ -438,14 +434,24 @@ def test_basic_poll_no_crashes( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.HEALTHY, uptime=100.0, uptime_display="1m 40s", - response_time=0.1, pid=1234, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.HEALTHY, + uptime=100.0, + uptime_display="1m 40s", + response_time=0.1, + pid=1234, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.STOPPED, uptime=None, uptime_display="-", - response_time=None, pid=None, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.STOPPED, + uptime=None, + uptime_display="-", + response_time=None, + pid=None, ), ] ) @@ -476,14 +482,24 @@ def test_poll_with_crashed_service_triggers_restart( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.CRASHED, uptime=None, uptime_display="-", - response_time=None, pid=1234, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.CRASHED, + uptime=None, + uptime_display="-", + response_time=None, + pid=1234, ), ServiceStatus( - tier="standard", model="qwen3.5-8b", port=8001, - status=ServiceHealth.HEALTHY, uptime=200.0, uptime_display="3m 20s", - response_time=0.05, pid=5678, + tier="standard", + model="qwen3.5-8b", + port=8001, + status=ServiceHealth.HEALTHY, + uptime=200.0, + uptime_display="3m 20s", + response_time=0.05, + pid=5678, ), ] ) @@ -518,9 +534,14 @@ def test_poll_does_not_restart_stopped_service( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.STOPPED, uptime=None, uptime_display="-", - response_time=None, pid=None, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.STOPPED, + uptime=None, + uptime_display="-", + response_time=None, + pid=None, ), ] ) @@ -556,9 +577,14 @@ def test_poll_flapping_service_not_restarted( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.CRASHED, uptime=None, uptime_display="-", - response_time=None, pid=1234, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.CRASHED, + uptime=None, + uptime_display="-", + response_time=None, + pid=1234, ), ] ) @@ -595,9 +621,14 @@ def test_poll_respects_restart_delay( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.CRASHED, uptime=None, uptime_display="-", - response_time=None, pid=1234, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.CRASHED, + uptime=None, + uptime_display="-", + response_time=None, + pid=1234, ), ] ) @@ -628,9 +659,14 @@ def test_poll_with_failed_restart( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.CRASHED, uptime=None, uptime_display="-", - response_time=None, pid=1234, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.CRASHED, + uptime=None, + uptime_display="-", + response_time=None, + pid=1234, ), ] ) @@ -716,9 +752,14 @@ def test_poll_healthy_resets_consecutive_failures( mock_status = StatusResult( services=[ ServiceStatus( - tier="fast", model="qwen3.5-3b", port=8000, - status=ServiceHealth.HEALTHY, uptime=100.0, uptime_display="1m 40s", - response_time=0.1, pid=1234, + tier="fast", + model="qwen3.5-3b", + port=8000, + status=ServiceHealth.HEALTHY, + uptime=100.0, + uptime_display="1m 40s", + response_time=0.1, + pid=1234, ), ] ) @@ -865,9 +906,7 @@ def test_daemonize_first_fork_failure(self, mlx_stack_home: Path) -> None: class TestRemoveWatchdogPid: """Tests for remove_watchdog_pid.""" - def test_removes_existing_pid_file( - self, mlx_stack_home: Path, pids_dir: Path - ) -> None: + def test_removes_existing_pid_file(self, mlx_stack_home: Path, pids_dir: Path) -> None: pid_path = pids_dir / "watchdog.pid" pid_path.write_text("12345")