MLX docs #18845

4 changes: 2 additions & 2 deletions backends/mlx/pte_inspector.py
@@ -788,9 +788,9 @@ def main():  # noqa: C901
parser.add_argument(
"--delegate-index",
type=int,
default=None,
default=0,
metavar="N",
help="Index of delegate to extract (0-based). If not specified, extracts first matching delegate.",
help="Index of delegate to extract (0-based, default: 0).",
)
parser.add_argument(
"--parse-mlx",
2 changes: 2 additions & 0 deletions docs/source/backends-overview.md
@@ -23,6 +23,7 @@ Backends are the bridge between your exported model and the hardware it runs on.
| [XNNPACK](backends/xnnpack/xnnpack-overview.md) | All | CPU | General-purpose, fallback |
| [CUDA](/backends/cuda/cuda-overview.md) | Linux/Windows | GPU | NVIDIA GPU acceleration |
| [Core ML](/backends/coreml/coreml-overview.md) | iOS, macOS | NPU/GPU/CPU | Apple devices, high performance |
| [MLX](/backends/mlx/mlx-overview.md) | macOS | GPU | Apple Silicon GPU (MLX) |
| [Metal Performance Shaders](/backends/mps/mps-overview.md) | iOS, macOS | GPU | Apple GPU acceleration |
| [Vulkan](/backends/vulkan/vulkan-overview.md) | Android | GPU | Android GPU acceleration |
| [Qualcomm](backends-qualcomm) | Android | NPU | Qualcomm SoCs |
@@ -55,6 +56,7 @@
backends/xnnpack/xnnpack-overview
backends/cuda/cuda-overview
backends/coreml/coreml-overview
backends/mlx/mlx-overview
backends/mps/mps-overview
backends/vulkan/vulkan-overview
backends-qualcomm
10 changes: 10 additions & 0 deletions docs/source/backends/mlx/mlx-op-support.md
@@ -0,0 +1,10 @@
# Op Support

The MLX backend supports ~90 ATen operators plus multi-node fused patterns and custom ops. The partitioner automatically determines which ops in your model can be delegated to MLX. Unsupported ops fall back to ExecuTorch's portable CPU runtime.
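The delegation decision is, at its core, a membership check per node. The toy sketch below illustrates the idea in plain Python; the op names and the `classify` helper are hypothetical illustrations, not the MLX partitioner's actual tables or API.

```python
# Toy sketch of a partitioner-style support check: split a model's op
# list into delegated (MLX) and fallback (CPU) sets. SUPPORTED_OPS and
# the op names are invented for illustration.

SUPPORTED_OPS = {"aten.linear.default", "aten.relu.default", "aten.add.Tensor"}

def classify(ops):
    """Return (delegated, fallback) lists for a sequence of op names."""
    delegated = [op for op in ops if op in SUPPORTED_OPS]
    fallback = [op for op in ops if op not in SUPPORTED_OPS]
    return delegated, fallback

delegated, fallback = classify(
    ["aten.linear.default", "aten.relu.default", "aten.some_op.default"]
)
print(delegated)  # ['aten.linear.default', 'aten.relu.default']
print(fallback)   # ['aten.some_op.default']
```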

For the current list of supported operators and fused patterns, see the source:

- **[ops.py](https://github.com/pytorch/executorch/blob/main/backends/mlx/ops.py)** — Single-op handlers (ATen op → MLX IR node)
- **[patterns.py](https://github.com/pytorch/executorch/blob/main/backends/mlx/patterns.py)** — Multi-node fused patterns (quantized linear, SDPA, KV cache, etc.)

During lowering, the MLX partitioner prints a summary of supported and unsupported ops so you can see which ones are delegated and which fall back to CPU.
144 changes: 144 additions & 0 deletions docs/source/backends/mlx/mlx-overview.md
@@ -0,0 +1,144 @@
# MLX Backend

The MLX delegate is the ExecuTorch backend for Apple Silicon GPUs via the [MLX](https://github.com/ml-explore/mlx) framework. It compiles PyTorch models into a custom FlatBuffer bytecode format at export time and executes them using MLX GPU primitives at runtime.

::::{note}
The MLX delegate is experimental and under active development.
::::

## Features

- GPU acceleration on Apple Silicon (M1 and later) via MLX.
- INT2/INT4/INT8 weight quantization via [TorchAO](https://github.com/pytorch/ao).
- Dynamic shape support.
- Mutable buffers for persistent state across inference calls (e.g., KV cache).
- Zero-copy constant loading on unified memory.

## Target Requirements

- Apple Silicon Mac (M1 or later)
- [macOS](https://developer.apple.com/macos) >= 14.0

## Development Requirements

- [macOS](https://developer.apple.com/macos) on Apple Silicon (M1 or later)
- [Xcode](https://developer.apple.com/xcode/) (full installation, not just Command Line Tools — the Metal compiler is required)

Verify the Metal compiler is available:

```bash
xcrun -sdk macosx --find metal
```

If this prints a path (e.g., `/Applications/Xcode.app/.../metal`), you're set. If it errors, install Xcode from [developer.apple.com](https://developer.apple.com/xcode/), then switch the active developer directory:

```bash
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
```

----

## Using the MLX Backend

To target the MLX backend during export and lowering, pass an instance of `MLXPartitioner` to `to_edge_transform_and_lower`. The MLX backend also provides a set of graph optimization passes via `get_default_passes()` that should be passed as `transform_passes`. The example below demonstrates this process using MobileNet V2:

```python
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.mlx import MLXPartitioner
from executorch.backends.mlx.passes import get_default_passes
from executorch.exir import to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

et_program = to_edge_transform_and_lower(
    torch.export.export(mobilenet_v2, sample_inputs),
    transform_passes=get_default_passes(),
    partitioner=[MLXPartitioner()],
).to_executorch()

with open("mv2_mlx.pte", "wb") as file:
    et_program.write_to_file(file)
```

`get_default_passes()` includes RMSNorm fusion, consecutive view/permute/dtype-cast collapsing, no-op removal, and common subexpression elimination. These are recommended for all models and required for optimal LLM performance.
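Two of these pass ideas can be pictured on a toy node list. This sketch is purely illustrative: the tuple-based node representation and `simplify` helper are invented, and the real passes operate on the exported graph, not on lists.

```python
# Toy sketch of no-op removal and collapsing consecutive permutes into
# one. Nodes are (op_name, arg) tuples; this is not the real IR.

def compose_perms(p, q):
    """Permutation p applied first, then q -> single equivalent permutation."""
    return [p[i] for i in q]

def simplify(nodes):
    out = []
    for op, arg in nodes:
        if op == "noop":
            continue  # no-op removal: drop the node entirely
        if op == "permute" and out and out[-1][0] == "permute":
            _, prev_arg = out.pop()
            out.append(("permute", compose_perms(prev_arg, arg)))  # collapse
        else:
            out.append((op, arg))
    return out

program = [("permute", [0, 2, 1]), ("noop", None), ("permute", [0, 2, 1])]
print(simplify(program))  # [('permute', [0, 1, 2])]
```

The two swaps compose to the identity permutation `[0, 1, 2]`; a further pass could then remove the identity permute as another no-op.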

::::{note}
The MLX backend is primarily designed for LLM and generative AI workloads on Apple Silicon. The MobileNet V2 example above is shown for simplicity, but in practice you would use this backend for models like Llama, Whisper, and other transformer-based architectures. See [LLM example](https://github.com/pytorch/executorch/tree/main/backends/mlx/examples/llm) for a more representative use case.
::::

See [Partitioner API](mlx-partitioner.md) for a reference on available partitioner options.

----

## Quantization

The MLX backend supports INT4, INT8, and NVFP4 weight quantization via TorchAO for both linear and embedding layers. This is particularly useful for LLM inference. See [MLX Quantization](mlx-quantization.md) for details.

----

## Runtime Integration

### Python (pybindings)

The simplest way to get started is to install ExecuTorch with Python bindings. From the repo root:

```bash
python install_executorch.py
```

On Apple Silicon, when the Metal compiler is available, the MLX backend is automatically included. You can then export models in Python using the MLX partitioner and run them via the ExecuTorch Python API.

### C++ (CMake preset)

To build the C++ runtime with the MLX delegate, use the `mlx-release` CMake workflow preset from the repo root:

```bash
cmake --workflow --preset mlx-release
```

This configures and builds a Release build of the ExecuTorch runtime with the MLX delegate and installs artifacts into `cmake-out/`. The preset enables the MLX delegate along with commonly needed extensions (module, data loader, flat tensor, LLM runner, etc.).

Downstream C++ apps can then `find_package(executorch)` and link against `mlxdelegate` and `mlx`. The `executorch_target_link_options_shared_lib` utility handles cross-platform whole-archive linkage (required for static initializer registration), and `executorch_target_copy_mlx_metallib` copies the Metal kernel library next to the binary so MLX can find it at runtime:

```cmake
# CMakeLists.txt
find_package(executorch REQUIRED)

# Link MLX delegate (with whole-archive for static initializer registration)
target_link_libraries(my_target PRIVATE mlxdelegate mlx)
executorch_target_link_options_shared_lib(mlxdelegate)

# Copy mlx.metallib next to the binary for runtime
executorch_target_copy_mlx_metallib(my_target)
```

No additional steps are necessary to use the backend beyond linking the target. An MLX-delegated `.pte` file will automatically run on the registered backend.

There is also an `mlx-debug` preset useful during development:

```bash
cmake --workflow --preset mlx-debug
```

## Reference

**→{doc}`/backends/mlx/mlx-troubleshooting` — Debug common issues.**

**→{doc}`/backends/mlx/mlx-partitioner` — Partitioner options.**

**→{doc}`/backends/mlx/mlx-quantization` — Supported quantization schemes.**

**→{doc}`/backends/mlx/mlx-op-support` — Supported operators.**

```{toctree}
:maxdepth: 2
:hidden:
:caption: MLX Backend
mlx-troubleshooting
mlx-partitioner
mlx-quantization
mlx-op-support
```
41 changes: 41 additions & 0 deletions docs/source/backends/mlx/mlx-partitioner.md
@@ -0,0 +1,41 @@
# Partitioner API

The partitioner API controls how a model is delegated to the MLX backend. Passing an `MLXPartitioner` instance with no additional parameters runs as much of the model as possible on the MLX backend with default settings; this is the most common use case.

## Usage

```python
import torch
from executorch.backends.mlx import MLXPartitioner
from executorch.exir import to_edge_transform_and_lower

et_program = to_edge_transform_and_lower(
    torch.export.export(model, example_inputs),
    partitioner=[MLXPartitioner()],
).to_executorch()
```

::::{important}
`MLXPartitioner` must be used with `to_edge_transform_and_lower()`. The legacy `to_edge()` + `to_backend()` workflow is **not supported** because it decomposes ops that MLX has optimized implementations for.
::::

## Unsupported Op Logging

During partitioning, the partitioner logs a summary of any unsupported ops. This is useful for understanding what will fall back to CPU:

```
================================================================================
MLX Partitioner: UNSUPPORTED OPS SUMMARY
================================================================================
[UNSUPPORTED x2] aten.some_op.default
Reason: No handler registered
================================================================================
```
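A summary in this shape is straightforward to assemble by counting occurrences per op. The sketch below mirrors the log format above; the `summarize_unsupported` helper and its input format are invented for illustration, not the partitioner's internals.

```python
# Sketch: aggregate (op, reason) pairs into a summary like the one above.
from collections import Counter

def summarize_unsupported(unsupported):
    counts = Counter(op for op, _ in unsupported)
    reasons = dict(unsupported)  # one reason per distinct op
    lines = ["=" * 80, "MLX Partitioner: UNSUPPORTED OPS SUMMARY", "=" * 80]
    for op, n in counts.items():
        lines.append(f"[UNSUPPORTED x{n}] {op}")
        lines.append(f"    Reason: {reasons[op]}")
    lines.append("=" * 80)
    return "\n".join(lines)

print(summarize_unsupported([
    ("aten.some_op.default", "No handler registered"),
    ("aten.some_op.default", "No handler registered"),
]))
```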

If all ops are supported, you'll see:

```
(All call_function nodes are supported!)
```

Set `ET_MLX_DEBUG=1` to see detailed per-node support decisions during partitioning.
88 changes: 88 additions & 0 deletions docs/source/backends/mlx/mlx-quantization.md
@@ -0,0 +1,88 @@
# Quantization

The MLX backend supports weight-only quantization via [TorchAO](https://github.com/pytorch/ao) for reducing model size and improving inference performance, particularly for LLMs on Apple Silicon. Quantization is applied to the eager model in-place **before** `torch.export()`.

## `quantize_`

The MLX backend uses TorchAO's [`quantize_`](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py) API under the hood. You can call it directly for full control over quantization configs and granularity. The key TorchAO configs are:

- [`IntxWeightOnlyConfig`](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py) — for INT2/INT4/INT8 weight-only quantization with per-group granularity (group sizes 32, 64, 128)
- [`ExportableNVFP4Config`](https://github.com/pytorch/executorch/blob/main/extension/llm/export/nvfp4.py) — for NVFP4 weight-only quantization

```python
import torch
from torchao.quantization.quant_api import quantize_, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup

# INT4 weight-only quantization for linear layers (group_size=32)
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
)

# INT8 weight-only quantization for embedding layers (group_size=128)
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerGroup(128)),
    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
```
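To make per-group granularity concrete, here is the numeric idea behind per-group symmetric INT4 weight quantization in a minimal pure-Python sketch. The `quantize_row`/`dequantize_row` helpers are invented for illustration and are not TorchAO's implementation.

```python
# Minimal sketch of per-group symmetric INT4 quantization: each group of
# weights shares one scale = max(|w|) / 7, and each weight is stored as
# a rounded integer in [-8, 7].

def quantize_row(weights, group_size=4, qmin=-8, qmax=7):
    """Quantize one weight row group-by-group; returns (int values, scales)."""
    q, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        scale = max(abs(w) for w in group) / qmax or 1.0  # avoid zero scale
        scales.append(scale)
        q.extend(max(qmin, min(qmax, round(w / scale))) for w in group)
    return q, scales

def dequantize_row(q, scales, group_size=4):
    return [q[i] * scales[i // group_size] for i in range(len(q))]

w = [0.1, -0.2, 0.3, 0.05, 1.0, -0.5, 0.25, 0.0]
q, scales = quantize_row(w)
approx = dequantize_row(q, scales)
# Reconstruction error is bounded by half a quantization step per group.
assert all(
    abs(a - b) <= scales[i // 4] / 2 + 1e-9
    for i, (a, b) in enumerate(zip(w, approx))
)
```

Larger group sizes amortize the per-group scale over more weights (smaller models, coarser scaling); smaller group sizes track local weight ranges more closely.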

## `quantize_model_`

For convenience, ExecuTorch provides `quantize_model_` which wraps `quantize_` with sensible defaults for common LLM quantization configurations:

```python
from executorch.extension.llm.export.quantize import quantize_model_

# Quantize linear layers with INT4, embedding layers with INT8
# Note: 8w defaults to per-axis grouping, which MLX does not support.
# Always pass an explicit group size when using 8w with MLX.
quantize_model_(model, qlinear_config="4w", qembedding_config="8w", qembedding_group_size=128)
```

### Supported configs

| Config | Description |
|--------|-------------|
| `4w` | INT4 weight-only quantization (per-group) |
| `8w` | INT8 weight-only quantization (per-group) |
| `nvfp4` | NVIDIA FP4 weight-only quantization |

These can be applied independently to linear layers and embedding layers.

### Using the LLM Export Script

The simplest way to export a quantized model is via the `export_llm_hf` script, which calls `quantize_model_` internally:

```bash
# INT4 quantization for both linear and embedding layers
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
    --model-id "unsloth/Llama-3.2-1B-Instruct" \
    --output llama_int4.pte \
    --use-custom-sdpa \
    --use-custom-kv-cache \
    --qlinear 4w \
    --qembedding 4w

# INT8 quantization for linear layers only
# Note: --qlinear-group-size is required for 8w (default is per-axis, which MLX does not support)
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
    --model-id "unsloth/Llama-3.2-1B-Instruct" \
    --output llama_int8.pte \
    --use-custom-sdpa \
    --use-custom-kv-cache \
    --qlinear 8w \
    --qlinear-group-size 128
```

### CLI Quantization Options

| Option | Default | Description |
|--------|---------|-------------|
| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) |
| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, `nvfp4`) |
| `--qlinear-group-size` | Depends on config | Group size for linear layer quantization (32, 64, or 128). Defaults to 32 for `4w`, 16 for `nvfp4`. **Required for `8w`** (default is per-axis, which MLX does not support). |
| `--qembedding-group-size` | Depends on config | Group size for embedding layer quantization (32, 64, or 128). Defaults to 32 for `4w`, 16 for `nvfp4`. **Required for `8w`** (default is per-axis, which MLX does not support). |
| `--no-tie-word-embeddings` | False | Disable re-tying lm_head to embedding after quantization |