Skip to content

feat(mlx): Add disk-based caching for converted MLX weights to prevent repeated CPU-heavy loading delay#50

Open
felk-dev wants to merge 1 commit into
z-lab:mainfrom
felk-dev:main
Open

feat(mlx): Add disk-based caching for converted MLX weights to prevent repeated CPU-heavy loading delay#50
felk-dev wants to merge 1 commit into
z-lab:mainfrom
felk-dev:main

Conversation

@felk-dev

@felk-dev felk-dev commented Jun 1, 2026

Copy link
Copy Markdown

📌 PR Description

🔍 Background & Problem

Currently, when loading a ParoQuant model on Apple Silicon using the MLX backend (load.py), the model weights (originally in AutoAWQ or PARO-native format) must be unpacked and repacked into MLX's native sequential uint32 layout.

Since this layout conversion is executed on the CPU using Python/NumPy, it incurs a significant time penalty (taking several minutes for large models). Under the current implementation, this conversion runs on every single load because the converted weights are only kept in memory, causing massive startup delays for MLX users.

💡 Solution

This PR introduces a robust, user-friendly disk-based caching mechanism for the converted MLX weights:

  1. Dual-Path Cache Check:
    • 1st Priority: Checks if a pre-converted weight file (converted_mlx.safetensors) is placed directly inside the local model directory. If present, it loads it immediately and skips the conversion entirely.
    • 2nd Priority: Checks the global user cache directory (~/.cache/paroquant/mlx_cache/<model_hash>/converted_mlx.safetensors).
  2. Automatic Cache Invalidation:
    • The <model_hash> is uniquely generated by hashing the absolute path, filenames, sizes, and modification times (mtime) of the original .safetensors files. This prevents stale cache loads when a model is updated or modified.
  3. Double Storage / Space Saving Guidance:
    • If a user loads from the global cache while the original weights are still present, a helpful suggestion is printed in the terminal recommending they move the cached file to the model directory and delete the original weights to avoid double storage.
    • A similar terminal guide is shown upon successful initial conversion.

⚙️ Key Changes

  • Modified paroquant/inference/backends/mlx/load.py:
    • Implemented cache hash calculation using hashlib.sha256.
    • Added cache loading logic before checking raw weights.
    • Added caching logic using mlx.core.save_safetensors after conversion.
    • Added friendly terminal instructions for disk space management.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant