MLX Implementation Guide: Porting Triton Kernels to Metal

Practical Code Examples & Patterns


1. MVMR Kernel: From Triton to Metal

Your Current Triton Implementation (Conceptual)

# From your mvmr_triton_kernel.py pattern:
# Input:  a(T, G, M, C), a_idx(T,), b(B, G, C), b_idx(T,) with values in [0, B)
# Output: o(K, G, M), o_idx(T,) with values in [0, K)
#
# Operation: for each t in T (a per-group contraction over C):
#   out[o_idx[t], g, m] += sum_c a[t, g, m, c] * b[b_idx[t], g, c]
#   i.e. out[o_idx[t]] += einsum('gmc,gc->gm', a[t], b[b_idx[t]])
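
A minimal NumPy reference of these semantics (a hypothetical helper, useful as a ground truth in the tests of §7):

import numpy as np

def mvmr_reference(a, b, b_idx, o_idx, n_o):
    """out[o_idx[t]] += einsum('gmc,gc->gm', a[t], b[b_idx[t]]) for each t."""
    T, G, M, C = a.shape
    out = np.zeros((n_o, G, M), dtype=a.dtype)
    for t in range(T):
        out[o_idx[t]] += np.einsum('gmc,gc->gm', a[t], b[b_idx[t]])
    return out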

MLX Metal Kernel Version

Option 1: Using mlx.fast.metal_kernel() (Python-based)

# mlx_ops/mvmr.py
import mlx.core as mx

# mx.fast.metal_kernel takes only the kernel *body*: MLX generates the
# [[kernel]] signature from input_names/output_names, and thread attributes
# such as thread_position_in_grid are injected when referenced in the source.
MVMR_SOURCE = """
    uint elem = thread_position_in_grid.x;

    // dims layout: [T, G, M, C, K, B]
    uint T = dims[0], G = dims[1], M = dims[2], C = dims[3];

    // Each thread handles one (t, g, m) triple
    if (elem >= T * G * M) return;

    uint t = elem / (G * M);
    uint gm = elem % (G * M);
    uint g = gm / M;
    uint m = gm % M;

    // Get indices
    uint b_pos = b_idx[t];    // which b vector to use
    uint out_pos = o_idx[t];  // where to accumulate in the output

    // Dot product a[t,g,m,:] . b[b_pos,g,:] over the channel dimension
    float sum = 0.0;
    for (uint c = 0; c < C; c++) {
        uint a_offset = (((t * G + g) * M + m) * C) + c;
        uint b_offset = ((b_pos * G + g) * C) + c;
        sum += a[a_offset] * b[b_offset];
    }

    // Atomic add to output (handles duplicates in o_idx); with
    // atomic_outputs=True the output buffer is device atomic<float>*.
    uint out_offset = ((out_pos * G + g) * M) + m;
    atomic_fetch_add_explicit(&o[out_offset], sum, memory_order_relaxed);
"""

_mvmr_kernel = mx.fast.metal_kernel(
    name="mvmr_kernel",
    input_names=["a", "b", "b_idx", "o_idx", "dims"],
    output_names=["o"],
    source=MVMR_SOURCE,
    atomic_outputs=True,
)

def sparse_matrix_vector_multiplication_reduction(
    a: mx.array,       # (T, G, M, C)
    a_idx: mx.array,   # (T,) - unused here, mirrors the PyTorch signature
    b: mx.array,       # (B, G, C)
    b_idx: mx.array,   # (T,) - values in [0, B)
    o_idx: mx.array,   # (T,) - values in [0, K)
    n_o: int           # output size K
) -> mx.array:
    """
    Sparse matrix-vector multiplication with reduction.

    For each t in range(T):
        out[o_idx[t]] += einsum('gmc,gc->gm', a[t], b[b_idx[t]])
    """

    T, G, M, C = a.shape
    B = b.shape[0]

    # Grid is specified in threads; round up to a multiple of the group size
    # (the kernel bounds-checks, so overshoot is harmless)
    threads_per_group = 256
    total = T * G * M
    grid_x = ((total + threads_per_group - 1) // threads_per_group) * threads_per_group

    # Execute kernel; init_value=0 zero-initializes the atomic output
    (o,) = _mvmr_kernel(
        inputs=[a, b, b_idx, o_idx,
                mx.array([T, G, M, C, n_o, B], dtype=mx.uint32)],
        grid=(grid_x, 1, 1),
        threadgroup=(threads_per_group, 1, 1),
        output_shapes=[(n_o, G, M)],
        output_dtypes=[mx.float32],
        init_value=0,
    )

    return o
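
A quick smoke test of the wrapper (shapes follow the comments above; index dtypes must match the kernel's uint32 buffers):

a = mx.random.normal((100, 8, 16, 32))
b = mx.random.normal((64, 8, 32))
b_idx = mx.random.randint(0, 64, (100,)).astype(mx.uint32)
o_idx = mx.random.randint(0, 64, (100,)).astype(mx.uint32)
out = sparse_matrix_vector_multiplication_reduction(a, None, b, b_idx, o_idx, 64)
print(out.shape)  # (64, 8, 16)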

Option 2: Using C++ Extensions (More Control)

// mlx_ext/mvmr.cpp -- schematic sketch following the pattern of MLX's
// custom-extension example (the axpby tutorial); exact helper names may
// differ across MLX versions.
#include "mlx/mlx.h"
#include "mlx/backend/metal/device.h"

using namespace mlx::core;

void mvmr_gpu(
    const array& a,           // (T, G, M, C)
    const array& b,           // (B, G, C)
    const array& b_idx,       // (T,)
    const array& o_idx,       // (T,)
    array& output             // (K, G, M), pre-zeroed
) {
    // Schematic: production code schedules this inside a Primitive's
    // eval_gpu, which receives the stream
    auto s = default_stream(Device::gpu);
    auto& d = metal::device(s.device);

    // Load the precompiled kernel and bind it to a compute encoder
    auto kernel = d.get_kernel("mvmr_kernel");
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);

    // Bind buffers in the order the kernel declares them
    compute_encoder.set_input_array(a, 0);
    compute_encoder.set_input_array(b, 1);
    compute_encoder.set_input_array(b_idx, 2);
    compute_encoder.set_input_array(o_idx, 3);
    compute_encoder.set_output_array(output, 4);

    // One thread per (t, g, m) triple
    int total = a.shape(0) * a.shape(1) * a.shape(2);
    MTL::Size group_dims = MTL::Size(256, 1, 1);
    MTL::Size grid_dims = MTL::Size((total + 255) / 256, 1, 1);
    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

// Python binding (nanobind in current MLX; pybind11 shown here, matching
// older extension examples)
void init_mvmr(py::module& m) {
    m.def("sparse_matrix_vector_multiplication_reduction",
          &mvmr_gpu,
          "Sparse matrix-vector multiply with reduction");
}

Key Considerations

  1. Buffer Management

    • MLX ensures row-contiguous inputs by default (ensure_row_contiguous=True), so flat row-major offsets are safe
    • Flatten high-dimensional indexing into linear offsets inside the kernel
    • Use uint32 for indices (GPU-compatible); see the dtype sketch after this list
  2. Atomic Operations

    // Metal atomic for float32 reduction (the cast is unnecessary when the
    // buffer is already declared atomic, e.g. with atomic_outputs=True)
    atomic_fetch_add_explicit(
        (device atomic_float*)&output[idx],
        value,
        memory_order_relaxed
    );
  3. Grid Calculation

    # Match your Triton pattern:
    # grid = (cdiv(T, L) * cdiv(G, BG) * cdiv(M, BM) * cdiv(C, BC),)
    
    # Simplified for Metal:
    total_work = T * G * M  # work items
    threads_per_block = 256
    blocks = (total_work + threads_per_block - 1) // threads_per_block
  4. Gradient Support

    • mx.fast.metal_kernel does not differentiate through custom kernels automatically
    • Wrap the op with mx.custom_function and register a vjp (see §4)
    • Reference the VVOR kernel for the backward-pass pattern
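
A minimal sketch of the index handling assumed above (the uint32 casts match the kernel's uint32_t buffer declarations):

import mlx.core as mx

o_idx = mx.random.randint(0, 64, (100,))   # default integer dtype
o_idx = o_idx.astype(mx.uint32)            # match the kernel's uint32_t buffer

# With ensure_row_contiguous=True (the metal_kernel default), a 4-D tensor
# can be addressed with flat offsets such as ((t*G + g)*M + m)*C + c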

2. Implementing Segment Reduce

Your Pattern (Triton)

# From large_segment_reduce_triton.py
# Input: x (T, C), lengths (K,)
# Output: y (K, C) - each segment's reduction
# Uses repeat_interleave to expand lengths to indices

MLX Implementation Option A: Python with Scatter

# mlx_ops/segment_reduce.py
import mlx.core as mx

def segment_reduce(
    x: mx.array,           # (T, C)
    lengths: mx.array,     # (K,) - length of each segment
    operation: str = "sum" # or "max", "mean", "min"
) -> mx.array:
    """
    Reduce contiguous segments of x based on lengths.

    Example:
        x = [a, b, c, d, e]
        lengths = [2, 3]  # first 2 elements, then 3
        sum: [a+b, c+d+e]
    """

    T, C = x.shape
    K = lengths.shape[0]

    # Build segment ids [0, 0, 1, 1, 1]. mx.repeat takes a scalar repeat
    # count, so per-element repeats are built with a cumsum trick instead.
    starts = mx.cumsum(lengths)[:-1]                       # segment boundaries
    marks = mx.zeros((T,), dtype=mx.int32).at[starts].add(1)
    indices = mx.cumsum(marks)                             # (T,) segment ids

    # Scatter-reduce with mx.array.at, which accumulates correctly for
    # repeated indices. (mlx-graphs also ships a scatter utility for this.)
    if operation == "sum":
        y = mx.zeros((K, C), dtype=x.dtype).at[indices].add(x)
    elif operation == "max":
        y = mx.full((K, C), float("-inf"), dtype=x.dtype).at[indices].maximum(x)
    elif operation == "min":
        y = mx.full((K, C), float("inf"), dtype=x.dtype).at[indices].minimum(x)
    elif operation == "mean":
        sums = mx.zeros((K, C), dtype=x.dtype).at[indices].add(x)
        # Guard against zero-length segments
        y = sums / mx.expand_dims(mx.maximum(lengths, 1).astype(x.dtype), 1)
    else:
        raise ValueError(f"Unknown operation: {operation}")

    return y
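
A quick check against the docstring example:

x = mx.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
lengths = mx.array([2, 3])
print(segment_reduce(x, lengths, "sum"))   # [[3.0], [12.0]]
print(segment_reduce(x, lengths, "mean"))  # [[1.5], [4.0]]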

MLX Implementation Option B: Custom Metal Kernel

# For large segments or performance-critical paths. As in §1, only the kernel
# *body* is given; mx.fast.metal_kernel generates the signature from the
# input/output names, and shape information is passed in a dims buffer.
SEGMENT_REDUCE_SOURCE = """
    uint elem = thread_position_in_grid.x;

    // dims layout: [T, C, K]
    uint C = dims[1];
    uint K = dims[2];

    // Each thread handles one output element y[k, c]
    if (elem >= K * C) return;

    uint k = elem / C;
    uint c = elem % C;

    uint start = offsets[k];   // first row of segment k
    uint len = lengths[k];     // number of rows in segment k

    // Sum the segment
    float sum = 0.0;
    for (uint i = 0; i < len; i++) {
        sum += x[(start + i) * C + c];
    }

    y[k * C + c] = sum;
"""

Critical Difference: Sorted vs Unsorted Segments

# SORTED segments (what MLX scatter assumes):
# Input: x = [a, b, c, d, e], lengths = [2, 3]
# Works: segments are contiguous

# UNSORTED/INDEXED segments (what your code might need):
# Input: x = values, indices = [0, 1, 0, 2, 1]
# Needs: indexed_segment_reduce with index mapping

# MLX options for unsorted ids:
# 1. Scatter with mx.array.at (handles duplicate indices directly)
# 2. Custom Metal kernel
# 3. Pre-sort data + track the permutation for the backward pass
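
A minimal unsorted-segment sum via option 1 (sorting first, as in option 3, mainly improves memory locality):

def unsorted_segment_sum(x: mx.array, segment_ids: mx.array, num_segments: int) -> mx.array:
    # mx.array.at accumulates correctly even when segment_ids repeats
    y = mx.zeros((num_segments, x.shape[1]), dtype=x.dtype)
    return y.at[segment_ids].add(x)

x = mx.ones((5, 2))
ids = mx.array([0, 1, 0, 2, 1])
print(unsorted_segment_sum(x, ids, 3))  # rows: [2, 2], [2, 2], [1, 1]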

3. Indexed Distance Computation (Replacing cdist)

PyTorch Original

# Compute pairwise distances
distances = torch.cdist(points_a, points_b)  # (N, M)

MLX Implementation

from typing import Optional

def indexed_distance(
    a: mx.array,        # (N, D)
    b: mx.array,        # (M, D)
    a_indices: Optional[mx.array] = None,  # optional gather indices into a
    b_indices: Optional[mx.array] = None,  # optional gather indices into b
    metric: str = "euclidean"
) -> mx.array:
    """
    Compute pairwise distances with optional indexing.

    For point clouds, often you want:
    distances[i, j] = ||a[a_indices[i]] - b[b_indices[j]]||
    """

    # Gather first if index arrays are given
    a_sel = a if a_indices is None else a[a_indices]  # (N, D)
    b_sel = b if b_indices is None else b[b_indices]  # (M, D)

    a_expanded = mx.expand_dims(a_sel, 1)  # (N, 1, D)
    b_expanded = mx.expand_dims(b_sel, 0)  # (1, M, D)
    diff = a_expanded - b_expanded          # (N, M, D)

    if metric == "euclidean":
        # ||a - b||_2 = sqrt(sum((a-b)^2))
        distances = mx.sqrt(mx.sum(diff**2, axis=-1))
    elif metric == "squared_euclidean":
        distances = mx.sum(diff**2, axis=-1)
    elif metric == "cosine":
        # 1 - <a, b> / (||a|| ||b||), computed from the selected points
        a_norm = mx.sqrt(mx.sum(a_sel**2, axis=-1, keepdims=True))  # (N, 1)
        b_norm = mx.sqrt(mx.sum(b_sel**2, axis=-1, keepdims=True))  # (M, 1)
        ab = a_sel @ b_sel.T                                         # (N, M)
        distances = 1.0 - ab / (a_norm * b_norm.T)
    else:
        raise ValueError(f"Unknown metric: {metric}")

    return distances
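
A quick sanity check against the definition (a sketch, not part of the original code):

a = mx.random.normal((8, 3))
b = mx.random.normal((5, 3))
d = indexed_distance(a, b)
assert d.shape == (8, 5)
# distance of each point to itself is ~0 on the diagonal
d_self = indexed_distance(a, a)
print(mx.max(mx.abs(mx.diagonal(d_self))))  # ~0 up to float error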

Performance Note

  • The dense implementation above materializes an (N, M, D) difference tensor: O(N·M·D) memory
  • For large point clouds, this may exceed GPU memory
  • Alternative: compute in blocks (below) or use a custom Metal kernel

Blocked Distance Computation (Memory-Efficient)

def indexed_distance_blocked(
    a: mx.array,
    b: mx.array,
    block_size: int = 512,
    metric: str = "euclidean"
) -> mx.array:
    """Process rows of a in blocks, capping peak memory at O(block_size * M * D)"""

    N, D = a.shape
    M = b.shape[0]

    b_exp = mx.expand_dims(b, 0)  # (1, M, D), reused across blocks
    blocks = []

    for i in range(0, N, block_size):
        a_block = a[i:min(i + block_size, N)]
        a_exp = mx.expand_dims(a_block, 1)  # (block, 1, D)

        diff = a_exp - b_exp

        if metric == "euclidean":
            distances = mx.sqrt(mx.sum(diff**2, axis=-1))
        else:
            distances = mx.sum(diff**2, axis=-1)

        blocks.append(distances)

    return mx.concatenate(blocks, axis=0)  # (N, M)

4. Custom Autograd Implementation

Basic Pattern: Custom Operation with Gradients

import mlx.core as mx

# MLX's hook for manual gradients is mx.custom_function: decorate the
# forward, then register a reverse-mode rule with the .vjp decorator.
@mx.custom_function
def my_matmul(x, w):
    # Forward pass (this is where a custom Metal kernel would be called)
    return x @ w

@my_matmul.vjp
def my_matmul_vjp(primals, cotangent, output):
    """Vector-Jacobian product (backward): one gradient per input"""
    x, w = primals
    return cotangent @ w.T, x.T @ cotangent

Example: MVMR with Custom Gradient

def mvmr_forward(a, a_idx, b, b_idx, o_idx, n_o):
    """Forward pass: call the Metal kernel from §1"""
    return sparse_matrix_vector_multiplication_reduction(
        a, a_idx, b, b_idx, o_idx, n_o
    )

@mx.custom_function
def mvmr_with_grad(a, a_idx, b, b_idx, o_idx, n_o):
    return mvmr_forward(a, a_idx, b, b_idx, o_idx, n_o)

@mvmr_with_grad.vjp
def mvmr_vjp(primals, cotangent, output):
    """Backward pass: compute gradients w.r.t. a and b"""
    a, a_idx, b, b_idx, o_idx, n_o = primals

    # grad_a via the VVOR pattern:
    #   grad_a[t] = einsum('gm,gc->gmc', cotangent[o_idx[t]], b[b_idx[t]])
    grad_a = None  # sparse_vector_vector_outer_product_reduction(...)

    # grad_b via a transposed MVMR:
    #   grad_b[j] += einsum('gmc,gm->gc', a[t], cotangent[o_idx[t]]) where b_idx[t] == j
    grad_b = None  # sparse_matrix_vector_multiplication_reduction(...)

    # Index inputs and n_o are not differentiable
    return grad_a, None, grad_b, None, None, None
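
Once the vjp is registered (with real kernels in place of the None placeholders), mx.grad composes as usual:

a = mx.random.normal((100, 8, 16, 32))
b = mx.random.normal((10, 8, 32))
b_idx = mx.random.randint(0, 10, (100,)).astype(mx.uint32)
o_idx = mx.random.randint(0, 10, (100,)).astype(mx.uint32)

loss = lambda a_: mx.sum(mvmr_with_grad(a_, None, b, b_idx, o_idx, 10))
grad_a = mx.grad(loss)(a)  # shape (100, 8, 16, 32)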

5. MLX-Compatible Data Structures for Point Clouds

Ragged Tensors (Variable-Length Sequences)

class PointCloudBatch:
    """Represent variable-length point clouds without padding"""

    def __init__(self, positions, features=None, lengths=None, indices=None):
        self.positions = positions      # (T, 3) - flattened points
        self.features = features        # (T, C) - flattened features
        self.lengths = lengths          # (B,) - points per cloud
        self.batch_size = lengths.shape[0]
        self.total_points = positions.shape[0]

        if indices is None:
            # Per-point cloud ids [0,0,...,1,1,...]; mx.repeat takes a scalar
            # count, so build them with the cumsum trick from §2
            starts = mx.cumsum(lengths)[:-1]
            marks = mx.zeros((self.total_points,), dtype=mx.int32).at[starts].add(1)
            indices = mx.cumsum(marks)
        self.indices = indices          # (T,) cloud id of each point

    def to_dense(self, max_len=None):
        """Convert to a padded dense tensor (B, max_len, 3)"""
        if max_len is None:
            max_len = int(mx.max(self.lengths).item())

        # Python ints are needed for slicing
        offsets = [0] + mx.cumsum(self.lengths).tolist()
        rows = []
        for b in range(self.batch_size):
            pts = self.positions[offsets[b]:offsets[b + 1]]
            pad = mx.zeros((max_len - pts.shape[0], pts.shape[1]), dtype=pts.dtype)
            rows.append(mx.concatenate([pts, pad], axis=0))
        return mx.stack(rows)

    def gather_by_index(self, idx):
        """Gather points by flattened index"""
        return self.positions[idx]

    def segment_reduce(self, values, operation="sum"):
        """Reduce per-point values down to one row per cloud"""
        return segment_reduce(values, self.lengths, operation)

# Usage:
# pc_batch = PointCloudBatch(
#     positions=mx.random.normal((1000, 3)),
#     lengths=mx.array([300, 400, 300])  # 3 point clouds
# )

Offset-Based Indexing (Better Performance)

class OffsetPointCloudBatch:
    """Even more efficient: use offsets instead of lengths"""

    def __init__(self, positions, features=None, offsets=None):
        self.positions = positions      # (T, 3)
        self.features = features        # (T, C)

        if offsets is None:
            # Single cloud spanning all points
            offsets = mx.array([0, positions.shape[0]])

        self.offsets = offsets          # (B+1,) cumulative, starting at 0
        self.batch_size = offsets.shape[0] - 1

    def get_points(self, batch_idx):
        """Get all points for one cloud"""
        # Slicing needs Python ints, not 0-d arrays
        start = int(self.offsets[batch_idx].item())
        end = int(self.offsets[batch_idx + 1].item())
        return self.positions[start:end]

    def as_lengths(self):
        """Convert offsets to lengths"""
        return self.offsets[1:] - self.offsets[:-1]
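
Hypothetical usage with three clouds stored contiguously:

pc = OffsetPointCloudBatch(
    positions=mx.random.normal((1000, 3)),
    offsets=mx.array([0, 300, 700, 1000]),
)
print(pc.get_points(1).shape)  # (400, 3)
print(pc.as_lengths())         # [300, 400, 300]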

6. Integration with MLX Models

MLX Module Pattern

import mlx.nn as nn
import mlx.core as mx

class PointCNNConv(nn.Module):
    """MLX version of a (simplified) PointCNN convolution"""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # MLX uses simple linear layers
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.ReLU(),
            nn.Linear(64, out_channels)
        )

    # MLX modules implement __call__ rather than forward
    def __call__(
        self,
        x: mx.array,              # (T, C_in) point features
        coords: mx.array,         # (T, 3) point coordinates
        neighbors_idx: mx.array,  # (T, K) indices of K nearest neighbors
        neighbors_rel_pos: mx.array,  # (T, K, 3) relative neighbor positions
        batch_indices: mx.array = None
    ) -> mx.array:
        T, K = neighbors_idx.shape

        # Gather features of neighbors via advanced indexing
        neighbors_x = x[neighbors_idx]  # (T, K, C_in)

        # Apply the MLP to each neighbor. This is simplified: real PointCNN
        # learns an X-transformation from neighbors_rel_pos and applies it
        # before the pointwise MLP.
        neighbor_features = self.mlp(neighbors_x)  # (T, K, C_out)

        # Aggregate (max pool over neighbors)
        aggregated = mx.max(neighbor_features, axis=1)  # (T, C_out)

        return aggregated

class PointCNNEncoder(nn.Module):
    """Full PointCNN encoder stack"""

    def __init__(self, num_layers: int = 3):
        super().__init__()

        # nn.Sequential cannot thread the extra neighbor arguments through,
        # so keep the layers in a list and loop explicitly
        self.layers = [
            PointCNNConv(3 if i == 0 else 64, 64)
            for i in range(num_layers)
        ]

    def __call__(self, x, coords, neighbors_idx, neighbors_rel_pos):
        # In a real implementation, recompute k-NN before each layer
        # (see the brute-force sketch below)
        for layer in self.layers:
            x = layer(x, coords, neighbors_idx, neighbors_rel_pos)
        return x
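
A brute-force k-NN sketch (an assumption, not part of the original code) built on the blocked distances from §3; real pipelines would use a spatial data structure:

def knn(coords: mx.array, k: int = 16):
    """Return (T, k) neighbor indices and (T, k, 3) relative positions."""
    d = indexed_distance_blocked(coords, coords)     # (T, T) pairwise
    idx = mx.argsort(d, axis=1)[:, 1:k + 1]          # skip self at column 0
    rel = coords[idx] - mx.expand_dims(coords, 1)    # (T, k, 3)
    return idx, rel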

Training Loop

def train_pointcnn(model, data_loader, optimizer, loss_fn, epochs=10):
    """MLX training pattern"""

    # Batch keys here are assumptions matching the encoder's signature
    def step_loss(model, batch):
        logits = model(batch['features'], batch['positions'],
                       batch['neighbors_idx'], batch['neighbors_rel_pos'])
        return loss_fn(logits, batch['labels'])

    # nn.value_and_grad differentiates w.r.t. the model's parameters
    loss_and_grad = nn.value_and_grad(model, step_loss)

    for epoch in range(epochs):
        total_loss = 0.0

        for batch in data_loader:
            # Forward + backward
            loss, grads = loss_and_grad(model, batch)

            # Update parameters and force the lazy graph to evaluate
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(data_loader)}")

7. Testing & Validation

Correctness Testing: MLX vs PyTorch

import numpy as np
import torch
import mlx.core as mx

def test_mvmr_correctness():
    """Verify the MLX kernel matches a PyTorch reference"""

    # Random test data
    T, G, M, C, K = 100, 8, 16, 32, 64

    a_pt = torch.randn(T, G, M, C)
    b_pt = torch.randn(K, G, C)
    b_idx_pt = torch.randint(0, K, (T,))
    o_idx_pt = torch.randint(0, K, (T,))

    # PyTorch reference: per-group contraction over the channel dimension
    o_pt = torch.zeros(K, G, M)
    for t in range(T):
        o_pt[o_idx_pt[t]] += torch.einsum('gmc,gc->gm', a_pt[t], b_pt[b_idx_pt[t]])

    # MLX version (indices cast to the kernel's uint32)
    a_mx = mx.array(a_pt.numpy())
    b_mx = mx.array(b_pt.numpy())
    b_idx_mx = mx.array(b_idx_pt.numpy().astype(np.uint32))
    o_idx_mx = mx.array(o_idx_pt.numpy().astype(np.uint32))

    o_mx = sparse_matrix_vector_multiplication_reduction(
        a_mx, None, b_mx, b_idx_mx, o_idx_mx, K
    )

    # Compare elementwise via NumPy; tolerance is loose because the atomic
    # accumulation order varies between runs
    error = float(np.max(np.abs(np.array(o_mx) - o_pt.numpy())))
    print(f"Max error: {error}")
    assert error < 1e-4, f"Correctness check failed: {error}"

def test_gradient_flow():
    """Verify gradients flow (requires the vjp from §4 to be registered)"""

    # Create test data
    a = mx.random.normal((100, 8, 16, 32))
    b = mx.random.normal((10, 8, 32))
    b_idx = mx.random.randint(0, 10, (100,)).astype(mx.uint32)
    o_idx = mx.random.randint(0, 10, (100,)).astype(mx.uint32)

    def forward(a_):
        return mvmr_with_grad(a_, None, b, b_idx, o_idx, 10)

    # Compute gradient w.r.t. a
    grad_a = mx.grad(lambda x: mx.sum(forward(x)))(a)

    assert grad_a.shape == a.shape
    print(f"Gradient shape correct: {grad_a.shape}")

Performance Profiling

import time

def benchmark_mvmr():
    """Measure kernel performance"""

    T, G, M, C, K = 100, 8, 16, 32, 64
    a_mx = mx.random.normal((T, G, M, C))
    b_mx = mx.random.normal((K, G, C))
    b_idx_mx = mx.random.randint(0, K, (T,)).astype(mx.uint32)
    o_idx_mx = mx.random.randint(0, K, (T,)).astype(mx.uint32)

    # Warm up (also triggers kernel compilation)
    for _ in range(10):
        mx.eval(sparse_matrix_vector_multiplication_reduction(
            a_mx, None, b_mx, b_idx_mx, o_idx_mx, K
        ))

    # Time
    times = []
    for _ in range(100):
        start = time.time()
        out = sparse_matrix_vector_multiplication_reduction(
            a_mx, None, b_mx, b_idx_mx, o_idx_mx, K
        )
        mx.eval(out)  # force evaluation (MLX is lazy)
        times.append(time.time() - start)

    mean_s = sum(times) / len(times)
    std_s = (sum((t - mean_s) ** 2 for t in times) / len(times)) ** 0.5

    print(f"MVMR: {mean_s * 1e3:.2f} ± {std_s * 1e3:.2f} ms")
    # 2 FLOPs (multiply + add) per (t, g, m, c) element
    print(f"Throughput: {2 * T * G * M * C / mean_s / 1e9:.2f} GFLOP/s")

Key Takeaways

  1. Metal kernels require explicit grid management - unlike Triton's auto-tuning
  2. MLX's scatter-style operations are improving - mx.array.at and mlx-graphs cover most cases, but test thoroughly
  3. Ragged tensors need index-based patterns - MLX has no native ragged type
  4. Custom Metal kernels are not differentiated automatically - register a vjp via mx.custom_function
  5. Performance comes from optimization, not just correctness - profile early
  6. Block-wise computation is useful for memory-constrained operations

Recommended Starting Point

Start with Option 1 (mx.fast.metal_kernel) for rapid prototyping, then migrate to Option 2 (C++ extensions) once you understand the performance bottlenecks.