# From your mvmr_triton_kernel.py pattern:
# Input:  a(T, G, M, C), a_idx(T,), b(B, G, C), b_idx(T,)
# Output: o(K, G, M)  (o_idx(T,) selects the output row for each t)
#
# Operation: for each t in T:
#     out_pos = o_idx[t]
#     out[out_pos, :, :] += a[t] @ b[b_idx[t]].T

# mlx_ops/mvmr.py
import mlx.core as mx

def mvmr_metal_kernel_source():
    """Metal kernel body for sparse matrix-vector multiplication with reduction.

    mx.fast.metal_kernel generates the [[kernel]] signature from input_names /
    output_names, so this is the body only. Inputs: a (T, G, M, C) flattened,
    b (B, G, C) flattened, b_idx (T,), o_idx (T,), dims = [T, G, M, C].
    Output: o (K, G, M) flattened.
    """
    return """
        uint linear_idx = thread_position_in_grid.x;
        uint T = dims[0], G = dims[1], M = dims[2], C = dims[3];
        // Each thread handles one (t, g, m) triple
        if (linear_idx >= T * G * M) return;
        uint t = linear_idx / (G * M);
        uint gm = linear_idx % (G * M);
        uint g = gm / M;
        uint m = gm % M;
        // Get indices
        uint b_pos = b_idx[t];   // which b vector to use
        uint out_pos = o_idx[t]; // which output row to accumulate into
        // Compute dot product a[t, g, m, :] . b[b_pos, g, :]
        float sum = 0.0f;
        for (uint c = 0; c < C; c++) {
            uint a_offset = ((t * G + g) * M + m) * C + c;
            uint b_offset = (b_pos * G + g) * C + c;
            sum += a[a_offset] * b[b_offset];
        }
        // With atomic_outputs=True, o is declared as atomic<float>; the
        // atomic add handles duplicate positions in o_idx
        uint out_offset = (out_pos * G + g) * M + m;
        atomic_fetch_add_explicit(&o[out_offset], sum, memory_order_relaxed);
    """

def sparse_matrix_vector_multiplication_reduction(
    a: mx.array,      # (T, G, M, C)
    a_idx: mx.array,  # (T,) - unused here; kept to mirror the PyTorch signature
    b: mx.array,      # (B, G, C)
    b_idx: mx.array,  # (T,) - which row of b to use for each t
    o_idx: mx.array,  # (T,) - which output row to accumulate into for each t
    n_o: int          # output size K
) -> mx.array:
    """
    Sparse matrix-vector multiplication with reduction.
    For each t in range(T):
        out[o_idx[t], :, :] += a[t] @ b[b_idx[t]].T
    """
    T, G, M, C = a.shape
    kernel = mx.fast.metal_kernel(
        name="mvmr_kernel",
        input_names=["a", "b", "b_idx", "o_idx", "dims"],
        output_names=["o"],
        source=mvmr_metal_kernel_source(),
        atomic_outputs=True,  # o is accumulated with atomic adds
    )
    # One thread per (t, g, m); the grid is specified in threads, not threadgroups
    total_threads = T * G * M
    (o,) = kernel(
        inputs=[
            a,
            b,
            b_idx.astype(mx.uint32),
            o_idx.astype(mx.uint32),
            mx.array([T, G, M, C], dtype=mx.uint32),
        ],
        grid=(total_threads, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[(n_o, G, M)],
        output_dtypes=[mx.float32],
        init_value=0,  # zero-initialize o before accumulation
    )
    return o
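
For debugging, the same contraction can be written as a slow pure-MLX loop. This is a sketch; mvmr_reference is our own helper name, not part of any library:

def mvmr_reference(a, b, b_idx, o_idx, n_o):
    """Naive loop reference for correctness checks (no kernel involved)."""
    T, G, M, C = a.shape
    o = mx.zeros((n_o, G, M), dtype=a.dtype)
    for t in range(T):
        # Contract over C: (G, M, C) * (G, 1, C) summed over C -> (G, M)
        contrib = mx.sum(a[t] * mx.expand_dims(b[int(b_idx[t].item())], 1), axis=-1)
        out = int(o_idx[t].item())
        o[out] = o[out] + contrib
    return o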

// mlx_ext/mvmr.cpp
// Option B sketch: a C++ extension. The device/encoder calls below are
// schematic; follow MLX's official extension example for the exact API.
#include <mlx/mlx.h>
#include <pybind11/pybind11.h>

using namespace mlx::core;
namespace py = pybind11;
void mvmr_gpu(
const array& a, // (T, G, M, C)
const array& a_idx, // (T,)
const array& b, // (B, G, C)
const array& b_idx, // (B,)
const array& o_idx, // (T,)
array& output, // (K, G, M)
int n_o // output size K
) {
// Get device and command queue
metal::Device& device = metal::device();
auto& queue = device.get_command_queue();
// Create Metal command encoder
auto encoder = device.get_command_encoder();
  // Load the compiled kernel (then bind its pipeline state to the encoder)
  auto kernel = device.get_kernel("mvmr_kernel");
// Set up buffers
encoder.setBuffer(a.data<float>(), 0); // buffer(0)
encoder.setBuffer(a_idx.data<uint32_t>(), 1);
encoder.setBuffer(b.data<float>(), 2);
encoder.setBuffer(b_idx.data<uint32_t>(), 3);
encoder.setBuffer(o_idx.data<uint32_t>(), 4);
encoder.setBuffer(output.data<float>(), 5);
// Calculate grid
int T = a.shape(0), G = a.shape(1), M = a.shape(2);
int total = T * G * M;
MTLSize threadsPerGroup = MTLSizeMake(256, 1, 1);
MTLSize numThreadgroups = MTLSizeMake(
(total + 255) / 256, 1, 1
);
// Dispatch
encoder.dispatchThreadgroups(numThreadgroups, threadsPerGroup);
encoder.endEncoding();
queue.commit();
}
// Python binding
void init_mvmr(py::module& m) {
m.def("sparse_matrix_vector_multiplication_reduction",
&mvmr_gpu,
"Sparse matrix-vector multiply with reduction");
}

- Buffer Management
  - MLX handles contiguity checks (ensure_row_contiguous)
  - Flatten high-dimensional tensors for Metal indexing
  - Use uint32_t for indices (GPU-compatible)
- Atomic Operations
  // Metal atomic for float32 reduction
  atomic_fetch_add_explicit(
      (device atomic_float*)&output[idx],
      value,
      memory_order_relaxed
  );
- Grid Calculation
  # Match your Triton pattern:
  # grid = (cdiv(T, L) * cdiv(G, BG) * cdiv(M, BM) * cdiv(C, BC),)
  # Simplified for Metal:
  total_work = T * G * M  # work items
  threads_per_block = 256
  blocks = (total_work + threads_per_block - 1) // threads_per_block
- Gradient Support
  - mx.fast.metal_kernel does not differentiate custom kernels automatically;
    wrap them with mx.custom_function and provide a vjp
  - Reference the VVOR kernel for the backward pass pattern
# From large_segment_reduce_triton.py
# Input: x (T, C), lengths (K,)
# Output: y (K, C) - each segment's reduction
# Uses repeat_interleave to expand lengths to indices

# mlx_ops/segment_reduce.py
import mlx.core as mx
from mlx_graphs.utils.scatter import scatter
def _repeat_interleave(values: mx.array, repeats: mx.array) -> mx.array:
    """torch.repeat_interleave equivalent (mx.repeat only takes an int count).

    Segment id of position t = number of segment end offsets <= t.
    """
    ends = mx.cumsum(repeats)              # (K,) segment end offsets
    pos = mx.arange(int(ends[-1].item()))  # (T,)
    seg_ids = mx.sum(pos[:, None] >= ends[None, :], axis=1).astype(mx.uint32)
    return values[seg_ids]

def segment_reduce(
    x: mx.array,            # (T, C)
    lengths: mx.array,      # (K,) - length of each segment
    operation: str = "sum"  # or "max", "mean", "min"
) -> mx.array:
    """
    Reduce segments of x based on lengths.
    Example:
        x = [a, b, c, d, e]
        lengths = [2, 3]  # first 2 elements, then 3
        sum: [a+b, c+d+e]
    """
    T, C = x.shape
    K = lengths.shape[0]
    # Create segment indices: [0, 0, 1, 1, 1]
    indices = _repeat_interleave(mx.arange(K), lengths)
    if operation not in ("sum", "max", "mean", "min"):
        raise ValueError(f"Unknown operation: {operation}")
    # Kwarg names follow the torch_scatter-style call used here; check them
    # against the installed mlx-graphs version
    y = scatter(src=x, index=indices, dim=0, dim_size=K, reduce=operation)
    return y
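
A quick sanity check of segment_reduce (values illustrative):

x = mx.array([[1.0], [2.0], [3.0], [4.0], [5.0]])  # (5, 1)
lengths = mx.array([2, 3])
print(segment_reduce(x, lengths, "sum"))   # [[3.0], [12.0]]
print(segment_reduce(x, lengths, "mean"))  # [[1.5], [4.0]]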

# For large segments or performance-critical paths
def large_segment_reduce_metal_kernel():
    """Body-only source for mx.fast.metal_kernel (signature is generated).

    Inputs: x (T, C) flattened, offsets (K,), lengths (K,), dims = [C, K].
    Output: y (K, C) flattened. Metal buffers are raw pointers, so shape
    information comes in through the dims buffer.
    """
    return """
        uint linear_idx = thread_position_in_grid.x;
        uint C = dims[0], K = dims[1];
        // Each thread handles one output element y[k, c]
        if (linear_idx >= K * C) return;
        uint k = linear_idx / C;
        uint c = linear_idx % C;
        uint start = offsets[k];
        uint len = lengths[k];
        // Sum the segment
        float sum = 0.0f;
        for (uint i = 0; i < len; i++) {
            sum += x[(start + i) * C + c];
        }
        y[k * C + c] = sum;
    """
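
A host-side launcher for this kernel might look like the following sketch (large_segment_reduce is our name; it assumes the body-only source above and the mx.fast.metal_kernel calling convention):

def large_segment_reduce(x: mx.array, lengths: mx.array) -> mx.array:
    T, C = x.shape
    K = lengths.shape[0]
    lengths_u32 = lengths.astype(mx.uint32)
    # Exclusive prefix sum: start offset of each segment
    offsets = mx.concatenate([mx.zeros((1,), dtype=mx.uint32),
                              mx.cumsum(lengths_u32)[:-1]])
    kernel = mx.fast.metal_kernel(
        name="segment_reduce_sum",
        input_names=["x", "offsets", "lengths", "dims"],
        output_names=["y"],
        source=large_segment_reduce_metal_kernel(),
    )
    (y,) = kernel(
        inputs=[x, offsets, lengths_u32, mx.array([C, K], dtype=mx.uint32)],
        grid=(K * C, 1, 1),  # one thread per output element
        threadgroup=(256, 1, 1),
        output_shapes=[(K, C)],
        output_dtypes=[x.dtype],
    )
    return y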
"""# SORTED segments (what MLX scatter assumes):
# Input: x = [a, b, c, d, e], lengths = [2, 3]
# Works: segments are contiguous
# UNSORTED/INDEXED segments (what your code might need):
# Input: x = values, indices = [0, 1, 0, 2, 1]
# Needs: indexed_segment_reduce with index mapping
# MLX solution for unsorted:
# Use custom Metal kernel OR
#    Pre-sort data + track permutation for backward pass

# Compute pairwise distances - PyTorch reference:
#   distances = torch.cdist(points_a, points_b)  # (N, M)

def indexed_distance(
a: mx.array, # (N, D) or variable-length
b: mx.array, # (M, D) or variable-length
a_indices: mx.array = None, # if variable-length
b_indices: mx.array = None,
metric: str = "euclidean"
) -> mx.array:
"""
Compute pairwise distances with optional indexing.
For point clouds, often you want:
distances[i, j] = ||a[a_indices[i]] - b[b_indices[j]]||
"""
    # Gather with indexing when indices are provided, else use points directly
    a_selected = a if a_indices is None else a[a_indices]  # (N, D)
    b_selected = b if b_indices is None else b[b_indices]  # (M, D)
    a_expanded = mx.expand_dims(a_selected, 1)  # (N, 1, D)
    b_expanded = mx.expand_dims(b_selected, 0)  # (1, M, D)
    diff = a_expanded - b_expanded              # (N, M, D)
    if metric == "euclidean":
        # ||a - b||_2 = sqrt(sum((a - b)^2))
        distances = mx.sqrt(mx.sum(diff**2, axis=-1))
    elif metric == "squared_euclidean":
        distances = mx.sum(diff**2, axis=-1)
    elif metric == "cosine":
        # 1 - <a, b> / (||a|| ||b||), on the (possibly gathered) points
        ab = mx.sum(a_expanded * b_expanded, axis=-1)     # (N, M)
        a_norm = mx.sqrt(mx.sum(a_expanded**2, axis=-1))  # (N, 1)
        b_norm = mx.sqrt(mx.sum(b_expanded**2, axis=-1))  # (1, M)
        distances = 1.0 - ab / (a_norm * b_norm)
    else:
        raise ValueError(f"Unknown metric: {metric}")
    return distances
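
Usage, with and without index gathering (shapes illustrative):

points_a = mx.random.normal((128, 3))
points_b = mx.random.normal((256, 3))
d = indexed_distance(points_a, points_b)  # (128, 256)
sel = mx.array([0, 5, 7])
d_sel = indexed_distance(points_a, points_b,
                         a_indices=sel, b_indices=mx.arange(256))  # (3, 256)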

- The dense implementation above needs O(N*M*D) memory for the diff tensor
- For large point clouds this may exceed GPU memory
- Alternative: compute in blocks, or use a custom kernel

def indexed_distance_blocked(
a: mx.array,
b: mx.array,
block_size: int = 512,
metric: str = "euclidean"
) -> mx.array:
"""Process in blocks to reduce memory usage"""
N, D = a.shape
M = b.shape[0]
    blocks = []
    for i in range(0, N, block_size):
        end_i = min(i + block_size, N)
        a_block = a[i:end_i]
        a_exp = mx.expand_dims(a_block, 1)  # (block, 1, D)
        b_exp = mx.expand_dims(b, 0)        # (1, M, D)
        diff = a_exp - b_exp
        if metric == "euclidean":
            distances = mx.sqrt(mx.sum(diff**2, axis=-1))
        else:
            distances = mx.sum(diff**2, axis=-1)
        blocks.append(distances)
    # Concatenate block results instead of in-place slice assignment
    return mx.concatenate(blocks, axis=0)
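
A quick consistency check between the dense and blocked paths:

a = mx.random.normal((1000, 3))
b = mx.random.normal((800, 3))
d_dense = indexed_distance(a, b, metric="squared_euclidean")
d_block = indexed_distance_blocked(a, b, block_size=256, metric="squared_euclidean")
assert mx.allclose(d_dense, d_block).item()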

# mlx_ops/custom_grad.py
from mlx import core as mx
from typing import Callable
def custom_sparse_op_with_grad(
forward_fn: Callable,
backward_fn: Callable,
*args
):
"""Wrapper for custom ops with manual gradient implementation"""
class CustomOp:
def __init__(self, forward_fn, backward_fn):
self.forward_fn = forward_fn
self.backward_fn = backward_fn
self._saved_tensors = None
def __call__(self, *args):
# Save for backward
self._saved_tensors = args
# Forward pass
return self.forward_fn(*args)
def vjp(self, *grad_outputs):
"""Vector-Jacobian product (backward)"""
return self.backward_fn(self._saved_tensors, grad_outputs[0])
op = CustomOp(forward_fn, backward_fn)
    return op(*args)

def mvmr_forward(a, a_idx, b, b_idx, o_idx, n_o):
"""Forward pass: sparse mat-vec multiply"""
    # Placeholder - call the Metal kernel here, e.g.:
    # return sparse_matrix_vector_multiplication_reduction(a, a_idx, b, b_idx, o_idx, n_o)
    return mx.zeros((n_o, a.shape[1], a.shape[2]))
def mvmr_backward(saved_tensors, grad_output):
"""Backward pass: compute gradients w.r.t. a and b"""
a, a_idx, b, b_idx, o_idx, n_o = saved_tensors
# grad_a computed via VVOR pattern
# grad_b computed via transposed MVMR
grad_a = None # sparse_vector_vector_outer_product_reduction(...)
grad_b = None # sparse_matrix_vector_multiplication_reduction(...)
return (grad_a, None, grad_b, None, None, None)

# MLX's supported hook for custom gradients is mx.custom_function
@mx.custom_function
def mvmr_with_grad(a, a_idx, b, b_idx, o_idx, n_o):
    return mvmr_forward(a, a_idx, b, b_idx, o_idx, n_o)

@mvmr_with_grad.vjp
def mvmr_with_grad_vjp(primals, cotangent, output):
    # primals are the positional inputs; return one gradient per input
    a, a_idx, b, b_idx, o_idx, n_o = primals
    return mvmr_backward((a, a_idx, b, b_idx, o_idx, n_o), cotangent)
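
With the vjp registered, the op composes with mx.grad as usual. Shapes here are illustrative, and since the forward above is still a stub this only exercises the plumbing:

a = mx.random.normal((100, 8, 16, 32))
b = mx.random.normal((10, 8, 32))
b_idx = mx.random.randint(0, 10, (100,))
o_idx = mx.random.randint(0, 10, (100,))
loss = lambda a_: mx.sum(mvmr_with_grad(a_, mx.arange(100), b, b_idx, o_idx, 10))
grad_a = mx.grad(loss)(a)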

class PointCloudBatch:
"""Represent variable-length point clouds without padding"""
    def __init__(self, positions, features=None, lengths=None, indices=None):
        self.positions = positions  # (T, 3) - flattened points
        self.features = features    # (T, C) - flattened features
        self.lengths = lengths      # (B,) - points per cloud
        if indices is None:
            # `x or y` is unreliable on arrays; build segment ids with the
            # helper from mlx_ops/segment_reduce.py
            indices = _repeat_interleave(mx.arange(len(lengths)), lengths)
        self.indices = indices
        self.batch_size = len(lengths)
        self.total_points = positions.shape[0]
def to_dense(self, max_len=None):
"""Convert to padded dense tensor"""
if max_len is None:
max_len = mx.max(self.lengths).item()
B, C = self.batch_size, self.positions.shape[1]
dense = mx.zeros((B, max_len, C))
        for b in range(B):
            start = int(mx.sum(self.lengths[:b]).item())
            n = int(self.lengths[b].item())
            # MLX supports indexed assignment on arrays
            dense[b, :n] = self.positions[start:start + n]
return dense
def gather_by_index(self, idx):
"""Gather points by flattened index"""
return self.positions[idx]
def segment_reduce(self, values, operation="sum"):
"""Reduce values by point cloud"""
return segment_reduce(values, self.lengths, operation)
# Usage:
# pc_batch = PointCloudBatch(
# positions=mx.random.normal((1000, 3)),
# lengths=mx.array([300, 400, 300]) # 3 point clouds
# )
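
For example, reducing per-point features to one vector per cloud (names reuse the classes and helpers above):

feats = mx.random.normal((1000, 16))
pc_batch = PointCloudBatch(
    positions=mx.random.normal((1000, 3)),
    features=feats,
    lengths=mx.array([300, 400, 300]),
)
per_cloud = pc_batch.segment_reduce(feats, operation="mean")  # (3, 16)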
"""Even more efficient: use offsets instead of lengths"""
def __init__(self, positions, features=None, offsets=None):
self.positions = positions # (T, 3)
self.features = features # (T, C)
        if offsets is None:
            # Default: one segment spanning all points
            offsets = mx.array([0, positions.shape[0]])
self.offsets = offsets # (B+1,) cumulative
self.batch_size = len(offsets) - 1
def get_points(self, batch_idx):
"""Get all points for one cloud"""
        start = int(self.offsets[batch_idx].item())
        end = int(self.offsets[batch_idx + 1].item())
        return self.positions[start:end]
def as_lengths(self):
"""Convert offsets to lengths"""
        # Elementwise diff of consecutive offsets
        return self.offsets[1:] - self.offsets[:-1]
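
Offsets are just a cumulative sum of lengths with a leading zero; a small construction helper (offsets_from_lengths is our own name):

def offsets_from_lengths(lengths: mx.array) -> mx.array:
    """(B,) lengths -> (B+1,) cumulative offsets starting at 0."""
    return mx.concatenate([mx.zeros((1,), dtype=lengths.dtype),
                           mx.cumsum(lengths)])

# batch = OffsetPointCloudBatch(positions, offsets=offsets_from_lengths(lengths))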

# mlx_ops/pointcnn.py
import mlx.nn as nn
import mlx.core as mx
class PointCNNConv(nn.Module):
"""MLX version of PointCNN convolution"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
        # MLX uses simple linear layers; +3 input dims for relative positions
        self.mlp = nn.Sequential(
            nn.Linear(in_channels + 3, 64),
            nn.ReLU(),
            nn.Linear(64, out_channels)
        )
    def __call__(
        self,
        x: mx.array,                  # (T, C_in)
        coords: mx.array,             # (T, 3)
        neighbors_idx: mx.array,      # (T, K) - indices of K neighbors
        neighbors_rel_pos: mx.array,  # (T, K, 3)
        batch_indices: mx.array = None
    ) -> mx.array:
        """
        MLX modules implement __call__ rather than forward.
        x: point features
        coords: point coordinates
        neighbors_idx: k-nearest neighbors
        neighbors_rel_pos: relative positions to neighbors
        """
        T, K = neighbors_idx.shape
        # Gather features of neighbors
        neighbors_x = x[neighbors_idx]  # (T, K, C_in)
        # Concatenate relative positions so the MLP sees local geometry
        # (simplified - real PointCNN learns an X-transformation instead)
        neighbor_in = mx.concatenate([neighbors_x, neighbors_rel_pos], axis=-1)
        neighbor_features = self.mlp(neighbor_in)  # (T, K, C_out)
        # Aggregate (max pool over neighbors)
        aggregated = mx.max(neighbor_features, axis=1)  # (T, C_out)
        return aggregated
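
The conv needs neighbor indices and relative positions; a hedged k-NN helper using mx.argpartition and the dense-distance pattern from earlier (knn_indices is our own name; note each point appears as its own nearest neighbor, which PointCNN-style convs often keep):

def knn_indices(coords: mx.array, k: int) -> mx.array:
    """Indices of the k nearest neighbors of each point (dense O(T^2) distances)."""
    diff = mx.expand_dims(coords, 1) - mx.expand_dims(coords, 0)  # (T, T, 3)
    d2 = mx.sum(diff ** 2, axis=-1)                               # (T, T)
    # argpartition puts the k smallest first (unordered among themselves)
    return mx.argpartition(d2, kth=k, axis=-1)[:, :k]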
class PointCNNEncoder(nn.Module):
    """Full PointCNN encoder stack"""
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.layers = [
            PointCNNConv(64 if i > 0 else 3, 64)
            for i in range(num_layers)
        ]
    def __call__(self, x: mx.array, coords: mx.array, k: int = 16) -> mx.array:
        # Compute k-NN once; a real implementation may recompute per layer
        # after downsampling
        idx = knn_indices(coords, k)                   # (T, k)
        rel = coords[idx] - mx.expand_dims(coords, 1)  # (T, k, 3)
        for layer in self.layers:
            x = layer(x, coords, idx, rel)
        return x
"""MLX training pattern"""
for epoch in range(epochs):
total_loss = 0.0
for batch in data_loader:
# Forward pass
            def loss_fn_wrapper(params):
                model.update(params)
                # Encoder signature is (features, coords)
                logits = model(batch['features'], batch['positions'])
                loss = loss_fn(logits, batch['labels'])
                return loss
# Compute gradients
loss, grads = mx.value_and_grad(loss_fn_wrapper)(
model.parameters()
)
# Update parameters
optimizer.update(model, grads)
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(data_loader)}")import torch

# tests/test_mvmr.py
import torch
import numpy as np
import mlx.core as mx
def test_mvmr_correctness():
"""Verify MLX kernel matches PyTorch reference"""
# Random test data
T, G, M, C, K = 100, 8, 16, 32, 64
# Create test arrays
a_pt = torch.randn(T, G, M, C)
a_idx_pt = torch.arange(T)
b_pt = torch.randn(K, G, C)
b_idx_pt = torch.randint(0, K, (T,))
o_idx_pt = torch.randint(0, K, (T,))
# PyTorch reference
o_pt = torch.zeros(K, G, M)
    for t in range(T):
        # Contract over C: (G, M, C) x (G, C) -> (G, M)
        o_pt[o_idx_pt[t]] += torch.einsum('gmc,gc->gm', a_pt[t], b_pt[b_idx_pt[t]])
    # MLX version
    a_mx = mx.array(a_pt.numpy())
    a_idx_mx = mx.array(a_idx_pt.numpy().astype(np.uint32))
    b_mx = mx.array(b_pt.numpy())
    b_idx_mx = mx.array(b_idx_pt.numpy().astype(np.uint32))
    o_idx_mx = mx.array(o_idx_pt.numpy().astype(np.uint32))
    o_mx = sparse_matrix_vector_multiplication_reduction(
        a_mx, a_idx_mx, b_mx, b_idx_mx, o_idx_mx, K
    )
    # Compare elementwise via NumPy
    error = float(np.abs(np.array(o_mx) - o_pt.numpy()).max())
    print(f"Max error: {error}")
    # float32 atomics accumulate in nondeterministic order, so allow some slack
    assert error < 1e-4, f"Correctness check failed: {error}"
def test_gradient_flow():
"""Verify gradients flow correctly"""
    def forward(a, b, b_idx, o_idx):
        return sparse_matrix_vector_multiplication_reduction(
            a, mx.arange(100, dtype=mx.uint32), b, b_idx, o_idx, 10
        )
    # Create test data
    a = mx.random.normal((100, 8, 16, 32))
    b = mx.random.normal((10, 8, 32))
    b_idx = mx.random.randint(0, 10, (100,))
    o_idx = mx.random.randint(0, 10, (100,))
    # Compute gradient (assumes the op is wrapped with a custom vjp;
    # mx.fast.metal_kernel outputs are not differentiated automatically)
    grad_fn = mx.grad(lambda x: mx.sum(forward(x, b, b_idx, o_idx)))
grad_a = grad_fn(a)
assert grad_a is not None
assert grad_a.shape == a.shape
print(f"Gradient shape correct: {grad_a.shape}")import time
def benchmark_mvmr(device="metal"):
"""Measure kernel performance"""
# Warm up
    for _ in range(10):
        out = sparse_matrix_vector_multiplication_reduction(
            a_mx, a_idx_mx, b_mx, b_idx_mx, o_idx_mx, K
        )
        mx.eval(out)  # force evaluation during warm-up too
# Time
times = []
for _ in range(100):
start = time.time()
        out = sparse_matrix_vector_multiplication_reduction(
            a_mx, a_idx_mx, b_mx, b_idx_mx, o_idx_mx, K
        )
mx.eval(out) # Force evaluation
times.append(time.time() - start)
    mean_s = sum(times) / len(times)
    std_s = (sum((t - mean_s) ** 2 for t in times) / len(times)) ** 0.5
    print(f"MVMR: {mean_s * 1000:.2f} ± {std_s * 1000:.2f} ms")
    # 2 FLOPs (multiply + add) per element of the inner C loop
    print(f"Throughput: {2 * T * G * M * C / mean_s / 1e9:.2f} GFLOP/s")

- Metal kernels require explicit grid management - unlike Triton's auto-tuning
- MLX's scatter operations are improving - use mlx-graphs but test thoroughly
- Ragged tensors need index-based patterns - no native ragged type
- Gradients require careful context setup - especially for custom Metal kernels
- Performance comes from optimization, not just correctness - profile early
- Block-wise computation - useful for memory-constrained operations
Start with Option A (mlx.fast.metal_kernel) for rapid prototyping, then migrate to Option B (C++ extensions) once you understand performance bottlenecks.