
arch(metal): organize Metal shader library as standalone module #86

@ohdearquant

Context

All Metal shaders currently live inline in crates/inference/src/forward/metal_qwen35.rs as Rust string constants (~14K lines, growing). This worked for bootstrapping but is hitting limits:

  • No MSL syntax highlighting or linting — shaders are opaque strings to the Rust compiler
  • No reuse across models: GEMV/softmax/RMSNorm are generic, but they live in a Qwen3.5-specific file
  • Hard to diff — shader changes appear as string diffs, not structured code changes
  • Can't benchmark shaders in isolation — coupled to the full forward pass
  • Can't study/adapt MLX techniques easily — no clean place to put reference implementations

Proposal

Extract Metal shaders into a dedicated module with .metal source files:

crates/inference/src/metal/
  shaders/
    gemv_q8.metal           # Q8_0 quantized matrix-vector multiply
    gemv_q4.metal           # Q4_0 quantized GEMV
    attention.metal         # decode attention (flash GQA)
    softmax.metal           # online softmax
    norm.metal              # RMSNorm, LayerNorm
    elementwise.metal       # SiLU, residual add
    lm_head.metal           # logit projection + optional argmax
  include/
    common.metal            # shared types, dequant helpers, SIMD-group utils
    quantize.metal          # Q8/Q4 block format definitions
  mod.rs                    # compile-time include_str! or build.rs metal-lib compilation
  dispatch.rs               # kernel launch wrappers (buffer binding, grid sizing)
  pipeline.rs               # MTLComputePipelineState caching
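With the compile-time include_str! option listed for mod.rs, each .metal file is embedded as text and handed to the runtime compiler. A minimal sketch of one wrinkle, assuming the library is compiled at runtime from a source string: that string has no file path, so quoted `#include "common.metal"` lines pointing into include/ would need to be inlined first. `resolve_includes` and the header table are hypothetical names, not existing code in the crate.

```rust
use std::collections::HashMap;

/// Hypothetical mod.rs helper: textually inline `#include "name.metal"` lines
/// from an in-memory table so a kernel and its shared headers become one
/// source string. Assumes headers do not include each other; an unknown
/// header name is returned as an error instead of being silently dropped.
fn resolve_includes(src: &str, headers: &HashMap<&str, &str>) -> Result<String, String> {
    let mut out = String::with_capacity(src.len());
    for line in src.lines() {
        let trimmed = line.trim_start();
        if let Some(rest) = trimmed.strip_prefix("#include \"") {
            // Header name is everything up to the closing quote.
            let name = rest.split('"').next().unwrap_or("");
            match headers.get(name) {
                Some(body) => {
                    out.push_str(body);
                    if !body.ends_with('\n') {
                        out.push('\n');
                    }
                }
                None => return Err(format!("unknown include: {name}")),
            }
        } else {
            // Everything else passes through, including `#include <metal_stdlib>`.
            out.push_str(line);
            out.push('\n');
        }
    }
    Ok(out)
}

fn main() {
    // In the real module these would be include_str!("include/common.metal")
    // and include_str!("shaders/elementwise.metal"); literals keep this self-contained.
    let common = "#include <metal_stdlib>\nusing namespace metal;\n";
    let kernel = "#include \"common.metal\"\nkernel void silu(device float* x [[buffer(0)]]) {}\n";
    let headers = HashMap::from([("common.metal", common)]);
    let full = resolve_includes(kernel, &headers).expect("all includes known");
    print!("{full}");
}
```

Once the build.rs metallib path replaces runtime compilation, the Metal compiler resolves includes itself and this step goes away.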

Benefits

  • .metal files get IDE syntax highlighting + Metal validation (xcrun metal -c)
  • Each kernel is individually reviewable (PR diffs show actual MSL, not string changes)
  • Generic kernels (GEMV, softmax) are reusable across model architectures
  • Can add a metal-shader-bench binary that benchmarks individual kernels
  • MLX kernel adaptations (issue #85, perf(metal): study MLX Metal kernels for GEMV/attention optimization) have a clean landing zone
  • Build-time xcrun metal compilation catches shader errors before runtime
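The build-time compilation above and the CI validation step (`xcrun metal -c`) could share one build.rs. A minimal sketch, assuming Apple's two-step command-line flow (`metal -c` to an .air file, then `metallib` to link): the helper names and the `shaders.metallib` output name are hypothetical, and the script does nothing on non-macOS targets. The argument builders are kept separate from `Command` so they can be checked without Xcode installed.

```rust
use std::path::{Path, PathBuf};
use std::process::Command;

/// Args for `xcrun -sdk macosx metal -c <src> -o <air>` (compile one file to .air).
fn metal_compile_args(src: &Path, air: &Path) -> Vec<String> {
    vec![
        "-sdk".into(), "macosx".into(), "metal".into(), "-c".into(),
        src.display().to_string(), "-o".into(), air.display().to_string(),
    ]
}

/// Args for `xcrun -sdk macosx metallib <air...> -o <lib>` (link .air files).
fn metallib_args(airs: &[PathBuf], lib: &Path) -> Vec<String> {
    let mut args = vec!["-sdk".to_string(), "macosx".into(), "metallib".into()];
    args.extend(airs.iter().map(|p| p.display().to_string()));
    args.push("-o".into());
    args.push(lib.display().to_string());
    args
}

/// Run one xcrun invocation; a non-zero exit fails the build.
fn run_xcrun(args: &[String]) {
    let status = Command::new("xcrun").args(args).status().expect("failed to spawn xcrun");
    assert!(status.success(), "xcrun {:?} failed", args);
}

fn main() {
    // Hypothetical build.rs body: only meaningful when targeting macOS.
    if std::env::var("CARGO_CFG_TARGET_OS").as_deref() != Ok("macos") {
        return;
    }
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR is set by cargo"));
    let shaders = ["gemv_q8", "gemv_q4", "attention", "softmax", "norm", "elementwise", "lm_head"];
    println!("cargo:rerun-if-changed=src/metal/include");
    let mut airs = Vec::new();
    for name in shaders {
        let src = PathBuf::from(format!("src/metal/shaders/{name}.metal"));
        println!("cargo:rerun-if-changed={}", src.display());
        let air = out_dir.join(format!("{name}.air"));
        run_xcrun(&metal_compile_args(&src, &air));
        airs.push(air);
    }
    run_xcrun(&metallib_args(&airs, &out_dir.join("shaders.metallib")));
}
```

The crate could then embed the result with `include_bytes!(concat!(env!("OUT_DIR"), "/shaders.metallib"))` and load it through metal-rs (presumably `Device::new_library_with_data`), so a shader syntax error fails `cargo build` rather than the first forward pass.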

Migration plan

  1. Extract existing inline MSL strings into .metal files (no behavior change)
  2. Replace include_str! with metal-rs pipeline compilation
  3. Add xcrun metal -c validation to CI (macOS runner only)
  4. Refactor metal_qwen35.rs to use the shader library via dispatch.rs
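Step 4 routes every launch through dispatch.rs. A minimal sketch of the grid-sizing half of such a wrapper, kept free of metal-rs types so it runs anywhere: `Dispatch`, `dispatch_1d`, and `dispatch_gemv` are hypothetical, and the GEMV layout (each SIMD-group computing a fixed number of output rows) is an assumption, not necessarily how the current kernels are organized. The two limits passed to `dispatch_1d` correspond to `threadExecutionWidth` and `maxTotalThreadsPerThreadgroup` on `MTLComputePipelineState`.

```rust
/// The two 3-D sizes handed to Metal's dispatchThreadgroups:threadsPerThreadgroup:.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct Dispatch {
    threadgroups: [u64; 3],
    threads_per_threadgroup: [u64; 3],
}

/// 1-D launch for elementwise kernels (e.g. SiLU, residual add).
/// Threadgroup width is the largest multiple of the SIMD width that fits the
/// pipeline limit; the grid is rounded up, so the kernel must check `gid < n`.
fn dispatch_1d(n: u64, thread_execution_width: u64, max_total_threads: u64) -> Dispatch {
    assert!(thread_execution_width > 0 && max_total_threads >= thread_execution_width);
    let width = (max_total_threads / thread_execution_width) * thread_execution_width;
    Dispatch {
        threadgroups: [(n + width - 1) / width, 1, 1],
        threads_per_threadgroup: [width, 1, 1],
    }
}

/// Row-blocked GEMV launch under an assumed layout: `simdgroups` SIMD-groups
/// per threadgroup, each producing `rows_per_simdgroup` output rows.
fn dispatch_gemv(rows: u64, simd_width: u64, simdgroups: u64, rows_per_simdgroup: u64) -> Dispatch {
    let rows_per_tg = simdgroups * rows_per_simdgroup;
    assert!(rows_per_tg > 0 && simd_width > 0);
    Dispatch {
        threadgroups: [(rows + rows_per_tg - 1) / rows_per_tg, 1, 1],
        threads_per_threadgroup: [simd_width * simdgroups, 1, 1],
    }
}

fn main() {
    println!("{:?}", dispatch_1d(5000, 32, 1024));
    println!("{:?}", dispatch_gemv(4096, 32, 2, 4));
}
```

In metal-rs the two arrays would likely become the two `MTLSize` arguments of the compute encoder's `dispatch_thread_groups`. Keeping the sizing logic pure makes it unit-testable without a GPU and reusable by the metal-shader-bench binary.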

Non-goals

  • Not adding new kernels in this issue — just restructuring existing ones
  • Not changing the Metal backend API surface
  • Not supporting non-Apple GPUs (WGPU backend is separate)

Priority

P2 — prerequisite for sustainable Metal kernel development (issue #85) but not blocking current work.

