Skip to content

Commit 816ff0b

Browse files
authored
Merge pull request #2 from weklund/fix/pull-use-python-api
fix: use huggingface_hub Python API for model pull
2 parents d6747d8 + a209296 commit 816ff0b

3 files changed

Lines changed: 44 additions & 635 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ dependencies = [
2525
"httpx>=0.27",
2626
"psutil>=5.9",
2727
"pyyaml>=6.0",
28-
"huggingface-hub[cli]>=1.8.0",
28+
"huggingface-hub>=1.8.0",
2929
]
3030

3131
[project.urls]

src/mlx_stack/core/pull.py

Lines changed: 10 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -18,36 +18,13 @@
1818
from pathlib import Path
1919
from typing import Any
2020

21+
from huggingface_hub import snapshot_download
2122
from rich.console import Console
2223

2324
from mlx_stack.core.catalog import CatalogEntry, QuantSource, get_entry_by_id, load_catalog
2425
from mlx_stack.core.config import ConfigCorruptError, get_value
2526
from mlx_stack.core.paths import ensure_data_home, get_data_home
2627

27-
# --------------------------------------------------------------------------- #
28-
# HuggingFace CLI binary resolution
29-
# --------------------------------------------------------------------------- #
30-
31-
32-
def _resolve_hf_cli() -> str:
33-
"""Resolve the HuggingFace CLI binary name.
34-
35-
Modern huggingface_hub versions install the CLI as ``hf`` rather than
36-
``huggingface-cli``. We try ``hf`` first (via :func:`shutil.which`)
37-
and fall back to ``huggingface-cli`` for older installations.
38-
39-
Returns:
40-
The binary name that is available on ``PATH``, preferring ``hf``.
41-
"""
42-
if shutil.which("hf"):
43-
return "hf"
44-
if shutil.which("huggingface-cli"):
45-
return "huggingface-cli"
46-
# Neither found — return "hf" (the modern default) so the caller
47-
# raises a helpful FileNotFoundError.
48-
return "hf"
49-
50-
5128
# --------------------------------------------------------------------------- #
5229
# Exceptions
5330
# --------------------------------------------------------------------------- #
@@ -321,71 +298,18 @@ def is_model_downloaded(model_path: Path) -> bool:
321298
# --------------------------------------------------------------------------- #
322299

323300

324-
def _filter_traceback(output: str) -> str:
325-
"""Filter Python traceback lines from output, returning clean error message.
326-
327-
Extracts the meaningful error message from output that may contain
328-
a full Python traceback. Removes traceback header, frame lines, and
329-
code context lines, keeping only pre-traceback content and the final
330-
exception line.
331-
332-
Args:
333-
output: Raw output that may contain traceback lines.
334-
335-
Returns:
336-
The filtered, human-readable error message.
337-
"""
338-
lines = output.strip().splitlines()
339-
if not lines:
340-
return output
341-
342-
# Check if the output contains a traceback
343-
has_traceback = any(
344-
line.strip().startswith("Traceback (most recent call last)")
345-
for line in lines
346-
)
347-
348-
if not has_traceback:
349-
return output.strip()
350-
351-
# Walk through lines:
352-
# - Keep lines before the traceback
353-
# - Skip the traceback header and all indented frame/code lines
354-
# - Keep the final exception line (first non-indented line after frames)
355-
meaningful_lines: list[str] = []
356-
in_traceback = False
357-
for line in lines:
358-
stripped = line.strip()
359-
if stripped.startswith("Traceback (most recent call last)"):
360-
in_traceback = True
361-
continue
362-
if in_traceback:
363-
# Inside traceback: skip lines that start with whitespace
364-
# (frame references like ' File "..."' and code context lines)
365-
if line.startswith((" ", "\t")) or stripped == "":
366-
continue
367-
# First non-indented, non-empty line is the exception message
368-
meaningful_lines.append(stripped)
369-
in_traceback = False
370-
continue
371-
if stripped:
372-
meaningful_lines.append(stripped)
373-
374-
return "\n".join(meaningful_lines) if meaningful_lines else output.strip()
375-
376-
377301
def _run_download(
378302
hf_repo: str,
379303
local_dir: Path,
380304
console: Console,
381305
) -> None:
382-
"""Run the HuggingFace CLI download command with real-time output.
306+
"""Download a model snapshot using the huggingface_hub Python API.
383307
384-
Resolves the CLI binary via :func:`_resolve_hf_cli` (prefers ``hf``,
385-
falls back to ``huggingface-cli``). Uses subprocess.Popen with
386-
stderr=subprocess.STDOUT so that HF CLI tqdm progress bars (written
387-
to stderr) are merged into stdout and streamed to the user in
388-
real-time. Captures output lines for error extraction on failure.
308+
Uses :func:`huggingface_hub.snapshot_download` directly instead of
309+
shelling out to the ``hf`` / ``huggingface-cli`` binaries. This
310+
avoids PATH resolution issues when mlx-stack is installed via
311+
``uv tool install`` or ``pipx``, where dependency entry-points are
312+
not exposed on the user's PATH.
389313
390314
Args:
391315
hf_repo: The HuggingFace repo to download.
@@ -395,82 +319,12 @@ def _run_download(
395319
Raises:
396320
DownloadError: If the download fails.
397321
"""
398-
# Resolve the HF CLI binary: prefer "hf" (modern), fall back to
399-
# "huggingface-cli" (legacy).
400-
hf_binary = _resolve_hf_cli()
401-
cmd = [
402-
hf_binary,
403-
"download",
404-
hf_repo,
405-
"--local-dir",
406-
str(local_dir),
407-
]
408-
409322
try:
410-
proc = subprocess.Popen(
411-
cmd,
412-
stdout=subprocess.PIPE,
413-
stderr=subprocess.STDOUT,
414-
text=True,
415-
)
416-
except FileNotFoundError:
417-
msg = (
418-
"HuggingFace CLI not found (tried 'hf' and 'huggingface-cli').\n"
419-
"Install huggingface_hub:\n"
420-
" pip install 'huggingface_hub[cli]'\n"
421-
"Or: uv pip install 'huggingface_hub[cli]'"
422-
)
423-
raise DownloadError(msg) from None
424-
except OSError as exc:
425-
msg = f"Failed to start download: {exc}"
426-
raise DownloadError(msg) from None
427-
428-
# Stream stdout (merged with stderr) line-by-line to show download
429-
# progress bars in real-time. Capture lines for error extraction.
430-
# Filter traceback blocks DURING streaming — suppress them from
431-
# console output but still capture them for the error handler.
432-
assert proc.stdout is not None
433-
captured_lines: list[str] = []
434-
in_traceback = False
435-
try:
436-
for line in proc.stdout:
437-
stripped = line.rstrip("\n")
438-
if not stripped:
439-
continue
440-
441-
captured_lines.append(stripped)
442-
443-
# Detect start of a traceback block
444-
if stripped.strip().startswith("Traceback (most recent call last)"):
445-
in_traceback = True
446-
continue
447-
448-
if in_traceback:
449-
# Inside traceback: suppress indented frame/code lines
450-
if stripped.startswith((" ", "\t")):
451-
continue
452-
# First non-indented line after frames is the exception
453-
# message — suppress it too (it's the error summary)
454-
in_traceback = False
455-
continue
456-
457-
# Normal line — show to user
458-
console.print(f" {stripped}")
459-
460-
# Wait for process to complete
461-
proc.wait(timeout=3600)
462-
except subprocess.TimeoutExpired:
463-
proc.kill()
464-
proc.wait()
465-
msg = "Download timed out after 1 hour."
323+
snapshot_download(repo_id=hf_repo, local_dir=str(local_dir))
324+
except Exception as exc:
325+
msg = f"Download failed for {hf_repo}: {exc}"
466326
raise DownloadError(msg) from None
467327

468-
if proc.returncode != 0:
469-
raw_output = "\n".join(captured_lines)
470-
clean_error = _filter_traceback(raw_output)
471-
msg = f"Download failed for {hf_repo}:\n{clean_error}"
472-
raise DownloadError(msg)
473-
474328

475329
def download_model(
476330
hf_repo: str,

0 commit comments

Comments
 (0)