1bit affine quantization #1

Closed
khosravipasha wants to merge 3 commits into main from 1bit-affine-quantization

Conversation

@khosravipasha
Collaborator

Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@khosravipasha khosravipasha marked this pull request as ready for review February 24, 2026 17:52

Copilot AI left a comment

Pull request overview

This PR adds 1-bit affine quantization support (scale + bias) to MLX, extending the existing affine quantization bit-width options and validating the new behavior via tests and benchmarks.

Changes:

  • Enable bits=1 for affine quantization in the core quantize path and documentation.
  • Add 1-bit support to CPU and Metal quantization / quantized-matmul kernels and instantiations.
  • Extend Python tests and comparative benchmarks to cover/measure 1-bit quantization.
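For intuition: the new 1-bit mode stores one bit per weight plus a per-group scale and bias, so every dequantized element is either `bias` or `bias + scale`. A minimal pure-Python sketch of that structure (an illustration only, not MLX's kernel code; the particular scale/bias choice of `bias = min` and `scale = max - min` is an assumption):

```python
def quantize_1bit(row, group_size=32):
    # Toy 1-bit affine quantization of one row: per group, bias = min and
    # scale = max - min, so the representable values are bias and bias + scale.
    codes, scales, biases = [], [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        bias = min(group)
        scale = max(group) - bias
        for v in group:
            q = 0 if scale == 0 else round((v - bias) / scale)
            codes.append(max(0, min(1, q)))  # clamp to the 1-bit range {0, 1}
        scales.append(scale)
        biases.append(bias)
    return codes, scales, biases


def dequantize_1bit(codes, scales, biases, group_size=32):
    # Map each 0/1 code back to bias or bias + scale for its group.
    return [
        biases[i // group_size] + scales[i // group_size] * q
        for i, q in enumerate(codes)
    ]
```

MLX may pick scale and bias differently (e.g., to minimize rounding error); this sketch only shows the `{bias, bias + scale}` structure that the tests in this PR verify.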

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.

File Description
mlx/ops.cpp Allows affine quantization bits=1 and adds a 1-bit-specific scale/bias computation.
mlx/backend/cpu/quantized.cpp Adds 1-bit dispatch for quantized matmul and 1-bit affine scale/bias computation during quantize.
mlx/backend/metal/kernels/quantized.h Adds 1-bit paths to quantized dot/outer/dequantize helpers and template guards.
mlx/backend/metal/kernels/quantized_nax.h Adds 1-bit paths analogous to quantized.h for the NAX variant.
mlx/backend/metal/kernels/quantized.metal Adds 1-bit affine quantize/dequantize support in the Metal kernels.
mlx/backend/metal/kernels/quantized_nax.metal Instantiates quantized kernels for bits=1.
python/src/ops.cpp Updates the user-facing quantization-mode table to list affine bits=1.
python/tests/test_quantized.py Expands existing tests to include bits=1 and adds a dedicated 1-bit test suite.
python/tests/cuda_skip.py Skips the new 1-bit test on CUDA.
benchmarks/python/comparative/compare.py Adds an MLX-only comparison helper to compare quantized matmul across bit widths.
benchmarks/python/comparative/bench_mlx.py Adds 1-bit quantized matmul benchmarks and quantizes weights inside the benchmark runner.
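The kernel changes listed above have to unpack eight 1-bit codes per byte. A hedged pure-Python sketch of one plausible storage layout (LSB-first packing; MLX's actual on-device layout may differ):

```python
def pack_1bit(codes):
    # Pack 0/1 codes into bytes, eight codes per byte, least-significant
    # bit first. This layout is an assumption for illustration.
    packed = bytearray((len(codes) + 7) // 8)
    for i, q in enumerate(codes):
        packed[i // 8] |= (q & 1) << (i % 8)
    return bytes(packed)


def unpack_1bit(packed, n):
    # Recover the first n codes from the packed bytes.
    return [(packed[i // 8] >> (i % 8)) & 1 for i in range(n)]
```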


Comment on lines +211 to +224
for i in range(scales.shape[0]):
    for j in range(scales.shape[1]):
        s = scales[i, j].item()
        b = biases[i, j].item()
        row_start = j * gs
        row_end = row_start + gs
        vals = w_hat[i, row_start:row_end]
        mx.eval(vals)
        for v in vals.tolist():
            self.assertTrue(
                abs(v - b) < 1e-5 or abs(v - (b + s)) < 1e-5,
                f"Value {v} not in {{bias={b}, bias+scale={b+s}}}",
            )

Copilot AI Feb 24, 2026

This new 1-bit test does a large number of scalar .item() fetches and per-slice .tolist() conversions inside nested Python loops. On GPU/Metal this typically triggers repeated device synchronizations and host transfers, which can make the test suite significantly slower/flakier. Prefer a vectorized assertion (e.g., broadcast biases/scales to w_hat shape and check isclose(w_hat, bias) OR isclose(w_hat, bias+scale) in one or a few MLX ops) to keep the test on-device.

Suggested change
for i in range(scales.shape[0]):
    for j in range(scales.shape[1]):
        s = scales[i, j].item()
        b = biases[i, j].item()
        row_start = j * gs
        row_end = row_start + gs
        vals = w_hat[i, row_start:row_end]
        mx.eval(vals)
        for v in vals.tolist():
            self.assertTrue(
                abs(v - b) < 1e-5 or abs(v - (b + s)) < 1e-5,
                f"Value {v} not in {{bias={b}, bias+scale={b+s}}}",
            )
# Vectorized check: broadcast biases/scales to w_hat shape and
# verify each element equals either bias or bias + scale.
rows, cols = w_hat.shape
num_groups = scales.shape[1]
self.assertEqual(cols, num_groups * gs)
# Expand biases/scales so that each group value is repeated gs times
# along the column dimension, matching w_hat's shape.
biases_expanded = mx.expand_dims(biases, -1)
biases_tiled = mx.broadcast_to(
    biases_expanded, (biases.shape[0], biases.shape[1], gs)
)
biases_full = biases_tiled.reshape(rows, cols)
scales_expanded = mx.expand_dims(scales, -1)
scales_tiled = mx.broadcast_to(
    scales_expanded, (scales.shape[0], scales.shape[1], gs)
)
scales_full = scales_tiled.reshape(rows, cols)
diff_bias = (w_hat - biases_full).abs()
diff_bias_plus_scale = (w_hat - (biases_full + scales_full)).abs()
mask = (diff_bias < 1e-5) | (diff_bias_plus_scale < 1e-5)
all_ok = mx.all(mask)
self.assertTrue(
    bool(all_ok.item()),
    "Some dequantized values are not in {bias, bias + scale}",
)

Comment on lines +4351 to +4357
====== ====================== ============================== ============= =====
mode   group size             bits                           scale type    bias
====== ====================== ============================== ============= =====
affine 32, 64\ :sup:`*`, 128  1, 2, 3, 4\ :sup:`*`, 5, 6, 8  same as input yes
mxfp4  32\ :sup:`*`           4\ :sup:`*`                    e8m0          no
mxfp8  32\ :sup:`*`           8\ :sup:`*`                    e8m0          no
nvfp4  16\ :sup:`*`           4\ :sup:`*`                    e4m3          no
Copilot AI Feb 24, 2026

The docs now state affine quantization supports 1-bit. If 1-bit affine is not implemented across all GPU backends (notably CUDA), this table becomes misleading for users on those platforms. Either ensure CUDA supports 1-bit affine quantize/dequantize, or qualify the documentation to indicate backend-specific availability.

Comment on lines 12 to 17
w = mx.random.normal(shape=(128, 512))
for gs in [32, 64, 128]:
    for b in [2, 3, 5, 6, 4, 8]:
    for b in [1, 2, 3, 5, 6, 4, 8]:
        with self.subTest(gs=gs, b=b):
            w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
            w_hat = mx.dequantize(w_q, scales, biases, gs, b)
Copilot AI Feb 24, 2026

test_quantize_dequantize now exercises bits=1, but the CUDA backend for affine quantize/dequantize does not currently dispatch/handle 1-bit (e.g., CUDA dispatch_bits lacks case 1 and dequant bit-unpacking only handles 2/4/8). This will likely break CUDA test runs where this test is not skipped. Either add CUDA support for 1-bit affine quantize/dequantize, or conditionally skip bits=1 in this test when running on CUDA / GPU backends that don't support it yet.
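One way such a conditional skip could be structured (a sketch only: the backend probe and names below are assumptions for illustration, not MLX API, and the real suite uses python/tests/cuda_skip.py):

```python
import unittest

# Hypothetical set of backends lacking 1-bit affine kernels.
BACKENDS_WITHOUT_1BIT = {"cuda"}


def current_backend():
    # Placeholder: a real suite would query the active device/backend.
    return "cpu"


class TestQuantizeDequantize(unittest.TestCase):
    def test_bit_widths(self):
        for b in [1, 2, 3, 4, 5, 6, 8]:
            with self.subTest(bits=b):
                if b == 1 and current_backend() in BACKENDS_WITHOUT_1BIT:
                    self.skipTest("1-bit affine not implemented on this backend")
                # Stand-in for the real quantize/dequantize assertions.
                self.assertIn(b, range(1, 9))
```

Using `subTest` plus `skipTest` keeps the other bit widths running on every backend while recording the skipped combination in the test report.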

Comment on lines +4543 to 4547
if (bits < 1 || bits > 8 || bits == 7) {
  std::ostringstream msg;
  msg << "[quantize] The requested number of bits " << bits
      << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8.";
      << " is not supported. The supported bits are 1, 2, 3, 4, 5, 6 and 8.";
  throw std::invalid_argument(msg.str());
Copilot AI Feb 24, 2026

The new bits=1 path is enabled globally here, but CUDA’s affine quantize/dequantize implementation still only dispatches bits {2,3,4,5,6,8}. As-is, calling mx.quantize(..., bits=1) on CUDA/GPU will likely be unsupported at runtime despite this check allowing it. Consider either (a) implementing the 1-bit CUDA kernels/dispatch, or (b) restricting bits=1 to the backends that support it (and updating the error message/docs accordingly).

@khosravipasha khosravipasha deleted the 1bit-affine-quantization branch March 9, 2026 19:50
@khosravipasha khosravipasha restored the 1bit-affine-quantization branch March 31, 2026 19:25