Conversation
Pull request overview
This PR adds 1-bit affine quantization support (scale + bias) to MLX, extending the existing affine quantization bit-width options and validating the new behavior via tests and benchmarks.
Changes:
- Enable `bits=1` for affine quantization in the core quantize path and documentation.
- Add 1-bit support to CPU and Metal quantization / quantized-matmul kernels and instantiations.
- Extend Python tests and comparative benchmarks to cover and measure 1-bit quantization.
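For context, a minimal numpy sketch of a 1-bit affine round trip. The per-group rule used here (bias = group min, scale = group max minus min) and the unpacked codes are assumptions for illustration; MLX's actual scale/bias computation and bit packing in `mlx/ops.cpp` may differ.

```python
import numpy as np

def quantize_1bit(w, group_size=32):
    # One code per weight: q in {0, 1}; dequantized value is q * scale + bias.
    rows, cols = w.shape
    g = w.reshape(rows, cols // group_size, group_size)
    biases = g.min(axis=-1)           # code 0 -> group min (assumed rule)
    scales = g.max(axis=-1) - biases  # code 1 -> group max (assumed rule)
    mid = biases + scales / 2
    q = (g > mid[..., None]).astype(np.uint8)  # round to the nearer level
    return q.reshape(rows, cols), scales, biases

def dequantize_1bit(q, scales, biases, group_size=32):
    rows, cols = q.shape
    g = q.reshape(rows, cols // group_size, group_size).astype(np.float64)
    return (g * scales[..., None] + biases[..., None]).reshape(rows, cols)
```

Every dequantized element is exactly `bias` or `bias + scale`, so the worst-case rounding error per weight is `scale / 2` of its group.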
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `mlx/ops.cpp` | Allows affine quantization `bits=1` and adds a 1-bit-specific scale/bias computation. |
| `mlx/backend/cpu/quantized.cpp` | Adds 1-bit dispatch for quantized matmul and 1-bit affine scale/bias computation during quantize. |
| `mlx/backend/metal/kernels/quantized.h` | Adds 1-bit paths to quantized dot/outer/dequantize helpers and template guards. |
| `mlx/backend/metal/kernels/quantized_nax.h` | Adds 1-bit paths analogous to `quantized.h` for the NAX variant. |
| `mlx/backend/metal/kernels/quantized.metal` | Adds 1-bit affine quantize/dequantize support in the Metal kernels. |
| `mlx/backend/metal/kernels/quantized_nax.metal` | Instantiates quantized kernels for `bits=1`. |
| `python/src/ops.cpp` | Updates the user-facing quantization-mode table to list affine `bits=1`. |
| `python/tests/test_quantized.py` | Expands existing tests to include `bits=1` and adds a dedicated 1-bit test suite. |
| `python/tests/cuda_skip.py` | Skips the new 1-bit test on CUDA. |
| `benchmarks/python/comparative/compare.py` | Adds an MLX-only comparison helper to compare quantized matmul across bit widths. |
| `benchmarks/python/comparative/bench_mlx.py` | Adds 1-bit quantized matmul benchmarks and quantizes weights inside the benchmark runner. |
```python
for i in range(scales.shape[0]):
    for j in range(scales.shape[1]):
        s = scales[i, j].item()
        b = biases[i, j].item()
        row_start = j * gs
        row_end = row_start + gs
        vals = w_hat[i, row_start:row_end]
        mx.eval(vals)
        for v in vals.tolist():
            self.assertTrue(
                abs(v - b) < 1e-5 or abs(v - (b + s)) < 1e-5,
                f"Value {v} not in {{bias={b}, bias+scale={b+s}}}",
            )
```
This new 1-bit test does a large number of scalar .item() fetches and per-slice .tolist() conversions inside nested Python loops. On GPU/Metal this typically triggers repeated device synchronizations and host transfers, which can make the test suite significantly slower/flakier. Prefer a vectorized assertion (e.g., broadcast biases/scales to w_hat shape and check isclose(w_hat, bias) OR isclose(w_hat, bias+scale) in one or a few MLX ops) to keep the test on-device.
Suggested change (replaces the nested loops above):

```python
# Vectorized check: broadcast biases/scales to w_hat shape and
# verify each element equals either bias or bias + scale.
rows, cols = w_hat.shape
num_groups = scales.shape[1]
self.assertEqual(cols, num_groups * gs)
# Expand biases/scales so that each group value is repeated gs times
# along the column dimension, matching w_hat's shape.
biases_expanded = mx.expand_dims(biases, -1)
biases_tiled = mx.broadcast_to(
    biases_expanded, (biases.shape[0], biases.shape[1], gs)
)
biases_full = biases_tiled.reshape(rows, cols)
scales_expanded = mx.expand_dims(scales, -1)
scales_tiled = mx.broadcast_to(
    scales_expanded, (scales.shape[0], scales.shape[1], gs)
)
scales_full = scales_tiled.reshape(rows, cols)
diff_bias = (w_hat - biases_full).abs()
diff_bias_plus_scale = (w_hat - (biases_full + scales_full)).abs()
mask = (diff_bias < 1e-5) | (diff_bias_plus_scale < 1e-5)
all_ok = mx.all(mask)
self.assertTrue(
    bool(all_ok.item()),
    "Some dequantized values are not in {bias, bias + scale}",
)
```
```rst
====== ====================== ============================== ============= =====
mode   group size             bits                           scale type    bias
====== ====================== ============================== ============= =====
affine 32, 64\ :sup:`*`, 128  1, 2, 3, 4\ :sup:`*`, 5, 6, 8  same as input yes
mxfp4  32\ :sup:`*`           4\ :sup:`*`                    e8m0          no
mxfp8  32\ :sup:`*`           8\ :sup:`*`                    e8m0          no
nvfp4  16\ :sup:`*`           4\ :sup:`*`                    e4m3          no
```
The docs now state affine quantization supports 1-bit. If 1-bit affine is not implemented across all GPU backends (notably CUDA), this table becomes misleading for users on those platforms. Either ensure CUDA supports 1-bit affine quantize/dequantize, or qualify the documentation to indicate backend-specific availability.
```diff
 w = mx.random.normal(shape=(128, 512))
 for gs in [32, 64, 128]:
-    for b in [2, 3, 5, 6, 4, 8]:
+    for b in [1, 2, 3, 5, 6, 4, 8]:
         with self.subTest(gs=gs, b=b):
             w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
             w_hat = mx.dequantize(w_q, scales, biases, gs, b)
```
test_quantize_dequantize now exercises bits=1, but the CUDA backend for affine quantize/dequantize does not currently dispatch/handle 1-bit (e.g., CUDA dispatch_bits lacks case 1 and dequant bit-unpacking only handles 2/4/8). This will likely break CUDA test runs where this test is not skipped. Either add CUDA support for 1-bit affine quantize/dequantize, or conditionally skip bits=1 in this test when running on CUDA / GPU backends that don't support it yet.
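One way to express the conditional skip is a small helper that drops `bits=1` on backends without the 1-bit kernels. The `is_cuda` flag is a hypothetical stand-in for a backend check; in the MLX test suite the analogous mechanism is the `cuda_skip.py` list.

```python
def supported_affine_bits(is_cuda: bool) -> list:
    # Affine bit widths exercised by the PR; 1-bit is excluded on CUDA
    # until its quantize/dequantize dispatch handles it (assumption).
    bits = [2, 3, 4, 5, 6, 8]
    return bits if is_cuda else [1] + bits
```

The test loop would then iterate `supported_affine_bits(...)` instead of a hard-coded list, so the same test body runs everywhere without a backend-specific skip of the whole test.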
```diff
 if (bits < 1 || bits > 8 || bits == 7) {
   std::ostringstream msg;
   msg << "[quantize] The requested number of bits " << bits
-      << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
+      << " is not supported. The supported bits are 1, 2, 3, 4, 5, 6 and 8.";
   throw std::invalid_argument(msg.str());
```
The new bits=1 path is enabled globally here, but CUDA’s affine quantize/dequantize implementation still only dispatches bits {2,3,4,5,6,8}. As-is, calling mx.quantize(..., bits=1) on CUDA/GPU will likely be unsupported at runtime despite this check allowing it. Consider either (a) implementing the 1-bit CUDA kernels/dispatch, or (b) restricting bits=1 to the backends that support it (and updating the error message/docs accordingly).
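The guard under discussion accepts bit widths 1 through 8 except 7. A Python mirror of that condition (a sketch of the logic, not MLX code) makes the accepted set explicit:

```python
def validate_affine_bits(bits: int) -> None:
    # Mirrors the C++ condition: reject anything outside 1..8, and also 7
    # (7 is excluded in the original check).
    if bits < 1 or bits > 8 or bits == 7:
        raise ValueError(
            f"[quantize] The requested number of bits {bits} is not supported."
            " The supported bits are 1, 2, 3, 4, 5, 6 and 8."
        )
```

Restricting `bits=1` to supported backends would mean threading a backend capability into this check (or checking at dispatch time), rather than widening the global range alone.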
Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

Checklist

Put an `x` in the boxes that apply.

- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes