feat(mlx): Add disk-based caching for converted MLX weights to prevent repeated CPU-heavy loading delay #50
Open
felk-dev wants to merge 1 commit into
📌 PR Description
🔍 Background & Problem
Currently, when loading a ParoQuant model on Apple Silicon using the MLX backend (`load.py`), the model weights (originally in AutoAWQ or PARO-native format) must be unpacked and repacked into MLX's native sequential `uint32` layout.

Since this layout conversion is executed on the CPU using Python/NumPy, it is slow, taking several minutes for large models. Under the current implementation, the conversion runs on every load because the converted weights are kept only in memory, causing long startup delays for MLX users.
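To make the cost concrete, here is a minimal sketch of what "unpack and repack" means for a single 32-bit word holding eight 4-bit values. The interleave order `ASSUMED_ORDER` and the helper names are illustrative assumptions, not code taken from `load.py`; the real conversion operates on whole NumPy arrays rather than single Python integers.

```python
# Assumption: slot i of a packed word holds logical element ASSUMED_ORDER[i].
# This mirrors an AWQ-style interleave, but the exact order is illustrative only.
ASSUMED_ORDER = (0, 2, 4, 6, 1, 3, 5, 7)


def unpack_interleaved(word, order=ASSUMED_ORDER, bits=4):
    """Extract the logical values from one interleaved packed word."""
    mask = (1 << bits) - 1
    values = [0] * len(order)
    for slot, logical_index in enumerate(order):
        values[logical_index] = (word >> (slot * bits)) & mask
    return values


def pack_sequential(values, bits=4):
    """Pack values so that element i occupies bits [i*bits, (i+1)*bits)."""
    mask = (1 << bits) - 1
    word = 0
    for i, value in enumerate(values):
        word |= (value & mask) << (i * bits)
    return word


def repack_word(word):
    """Convert one interleaved word into the sequential layout."""
    return pack_sequential(unpack_interleaved(word))
```

Doing this for every packed word of every quantized layer is the work the cache is meant to avoid repeating.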
💡 Solution
This PR introduces a robust, user-friendly disk-based caching mechanism for the converted MLX weights:
- If a converted weights file (`converted_mlx.safetensors`) is placed directly inside the local model directory, it is loaded immediately and the conversion is skipped entirely.
- Otherwise, the converted weights are cached under the user's cache directory (`~/.cache/paroquant/mlx_cache/<model_hash>/converted_mlx.safetensors`).
- `<model_hash>` is generated by hashing the absolute path, filenames, sizes, and modification times (mtime) of the original `.safetensors` files. This prevents stale cache loads when a model is updated or modified.

⚙️ Key Changes
- The model hash is computed with `hashlib.sha256`.
- The converted weights are written with `mlx.core.save_safetensors` after conversion.
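The lookup and hashing scheme described above could be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's actual code: the function names (`model_hash`, `resolve_cache_path`, `load_or_convert`) are hypothetical, the exact fields fed into the hash may differ, and the `convert`, `load`, and `save` callables are injected so the sketch runs without MLX installed.

```python
import hashlib
from pathlib import Path

CACHE_ROOT = Path.home() / ".cache" / "paroquant" / "mlx_cache"
CACHE_NAME = "converted_mlx.safetensors"


def model_hash(model_dir):
    """Hash the absolute path plus name, size, and mtime of each source file."""
    model_dir = Path(model_dir).resolve()
    digest = hashlib.sha256(str(model_dir).encode("utf-8"))
    for f in sorted(model_dir.glob("*.safetensors")):
        if f.name == CACHE_NAME:
            continue  # assumption: a local converted file is not part of the key
        st = f.stat()
        digest.update(f"{f.name}:{st.st_size}:{st.st_mtime_ns}".encode("utf-8"))
    return digest.hexdigest()


def resolve_cache_path(model_dir, cache_root=CACHE_ROOT):
    """Prefer a converted file inside the model directory, else the hashed path."""
    local = Path(model_dir) / CACHE_NAME
    if local.is_file():
        return local
    return Path(cache_root) / model_hash(model_dir) / CACHE_NAME


def load_or_convert(model_dir, convert, load, save, cache_root=CACHE_ROOT):
    """Return cached weights if present; otherwise convert once and save."""
    path = resolve_cache_path(model_dir, cache_root)
    if path.is_file():
        return load(str(path))
    weights = convert(model_dir)
    path.parent.mkdir(parents=True, exist_ok=True)
    save(str(path), weights)
    return weights
```

With MLX available, `load` and `save` could plausibly be `mlx.core.load` and `mlx.core.save_safetensors`. Because the key covers sizes and mtimes, editing or replacing a source `.safetensors` file yields a different `<model_hash>` directory, so the old converted file is simply never looked up again.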