Add JACCL expert-parallel distributed inference #351

Open
machiabeli wants to merge 22 commits into antirez:main from machiabeli:feat/jaccl-distributed

@machiabeli

Summary

Expert-parallel distributed inference for DeepSeek V4 Flash/Pro via JACCL
(Apple's RDMA collective library for Thunderbolt 5). Each rank owns a subset
of the 256 MoE experts; after each layer's routed MoE forward pass,
jaccl_group_all_sum() merges partial results across ranks.

Opt-in: make JACCL=1. Requires JACCL headers from ml-explore/mlx.

What changed (1,559 lines across 17 files)

Core integration (ds4.c, ds4.h, ds4_cli.c, ds4_server.c):

  • --distributed CLI flag enables JACCL init at engine startup
  • Expert range [start, end) partitioned by rank; 256/N experts per rank
  • all_sum inserted after every routed MoE output (CPU decode, CPU batch, Metal graph)
  • GPU sync (ds4_gpu_end_commands) before all_sum to ensure Metal kernel completion
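
For reviewers, a minimal sketch of the rank-based partitioning described above. The helper names (`expert_range_for_rank`, `expert_is_owned`) are hypothetical and not taken from ds4.c, and the handling of a non-divisible expert count is an assumption: the PR only describes the evenly divisible 256/N case.

```c
#include <assert.h>

/* Half-open range [start, end) of expert IDs owned by one rank. */
typedef struct {
    int start;
    int end;
} expert_range;

/*
 * Hypothetical helper: split n_expert experts across world_size ranks.
 * Assumption: if n_expert is not divisible by world_size, the first
 * (n_expert % world_size) ranks each take one extra expert.
 */
expert_range expert_range_for_rank(int n_expert, int world_size, int rank) {
    assert(n_expert > 0 && world_size > 0);
    assert(rank >= 0 && rank < world_size);

    int base  = n_expert / world_size;
    int extra = n_expert % world_size;

    expert_range r;
    r.start = rank * base + (rank < extra ? rank : extra);
    r.end   = r.start + base + (rank < extra ? 1 : 0);
    return r;
}

/* Ownership test: does this rank compute the given expert? */
int expert_is_owned(expert_range r, int expert_id) {
    return expert_id >= r.start && expert_id < r.end;
}
```

With 256 experts and 4 ranks, rank 1 gets [64, 128). With `world_size` 1 the single rank gets [0, 256), which matches the single-node fallback where every expert is owned.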

C shim (jaccl_shim.h, jaccl_shim.cpp):

  • 46-line C wrapper around JACCL's C++ Group API (all_sum, barrier, send, recv)
  • Statically linked via libjaccl.a (built by CMake in build/jaccl/)

Metal kernels (metal/expert_mask.metal, metal/moe.metal, ds4_metal.m):

  • expert_mask: zero router weights for non-owned experts
  • expert_compact: rewrite selected[] to owned-only (env-gated DS4_EXPERT_COMPACT)
  • Blit-zero mid buffer before swiglu to prevent stale GPU memory reads
  • Zero mid[] in swiglu early-exit path (IQ2_XXS + Q4_K kernel variants)

Build (Makefile):

  • JACCL=1 opt-in flag; JACCL_SRC defaults to mlx submodule path
  • CMake builds libjaccl.a; C++ shim links against it

Tools:

  • distributed_launch.sh: multi-node launcher with asmi topology discovery
  • tests/test_distributed_correctness.sh: automated 2-node correctness test
  • tests/test_jaccl_shim.c: shim linkage smoke test
  • docs/distributed-benchmark.md: methodology and baseline results

Key design decisions

  1. Expert-parallel, not layer-parallel. All ranks run all 61 layers; only MoE
    experts are split. The all_sum payload is ~28KB per layer (n_embd x float32),
    which completes in <0.3ms at 5.5 GB/s RDMA bandwidth.

  2. GPU sync before RDMA read. Metal GPU kernels are asynchronous. Without
    ds4_gpu_end_commands() before reading routed_out, each rank transmits
    partially-computed data. This was the root cause of the garbage output bug.

  3. Defense-in-depth for expert masking. Three layers: (A) blit-zero the mid
    buffer, (B) zero mid[] in kernel early-exit, (C) expert_compact rewrites
    selected[] to owned-only. Any one suffices; all three prevent regressions.

  4. Env var device matrix. JACCL_IBV_DEVICES points to a JSON file (not
    inline JSON) containing the NxN RDMA device connectivity matrix.
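
The payload figure in decision 1 can be checked with a few lines of arithmetic. This is a sketch under stated assumptions: the function names are hypothetical, and `n_embd = 7168` is only what a 28 KiB float32 payload would imply, not a value quoted in this PR.

```c
#include <assert.h>
#include <stddef.h>

/* Bytes moved by one all_sum over n_elems elements of elem_size bytes. */
size_t all_sum_payload_bytes(size_t n_elems, size_t elem_size) {
    return n_elems * elem_size;
}

/*
 * Pure wire time in microseconds at a given bandwidth in bytes/second.
 * Ignores RDMA setup latency and cross-rank synchronization, which are
 * likely to dominate for a payload this small.
 */
double wire_time_us(size_t bytes, double bytes_per_sec) {
    assert(bytes_per_sec > 0.0);
    return (double)bytes / bytes_per_sec * 1e6;
}
```

`all_sum_payload_bytes(7168, 4)` is 28672 bytes (28 KiB), and `wire_time_us(28672, 5.5e9)` comes to roughly 5.2 µs, so the quoted <0.3ms bound leaves most of its budget for latency rather than data transfer.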

Test plan

  • Single-node regression: correct output at 36.7 tok/s
  • 2-node hub+m3u4: correct output ("Paris") on both ranks at 14.5 tok/s
  • 4-node full mesh: 256/4 = 64 experts per rank
  • DeepSeek V4 Pro Q4 (384 experts, 838GB GGUF)
  • Sustained generation benchmark (100+ tokens)

machiabeli and others added 22 commits May 27, 2026 02:42
C99-compatible header exposing the JACCL Group API (init, all_sum, barrier,
send, recv) as extern "C" functions. Dtype enum values match jaccl::Dtype
from upstream group.h (Float32=11, Float16=9).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Thin C++ wrapper (~45 LOC) that bridges jaccl_shim.h to the JACCL
standalone lib. Heap-allocates a shared_ptr<Group> behind an opaque
void* handle. Verified: compiles with -std=c++20 against upstream
jaccl headers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- `make JACCL=1` builds JACCL via CMake into build/jaccl/, compiles
  jaccl_shim.cpp, and links libjaccl.a + -lc++ into all 5 binaries
- `make` (without JACCL=1) is unchanged -- no new deps, no regressions
- JACCL_SRC defaults to ~/opensource/mlx/mlx/distributed/jaccl/lib
- DS4_JACCL define passed to ds4.c for conditional compilation
- build/ added to .gitignore

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Single-process test that verifies jaccl_is_available(), attempts
non-strict init (NULL without env vars), and exercises the
init/rank/size/free lifecycle if RDMA is present. Validates all
shim symbols link correctly against libjaccl.a.

Build: make JACCL=1 tests/test_jaccl_shim

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add bool distributed to ds4_engine_options, conditional jaccl_shim.h
include under DS4_JACCL, and jaccl_group/world_size/rank/expert_start/
expert_end fields to struct ds4_engine.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
In ds4_engine_open: after model load, init JACCL group from env vars
when opt->distributed is set. Computes expert_start/expert_end for
rank-based expert partitioning. Falls back to single-node defaults
(world_size=1, all experts owned) when not distributed.

In ds4_engine_close: free JACCL group before other teardown.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wire --distributed flag through ds4 CLI and ds4-server arg parsing.
Sets engine_opts.distributed = true. Only effective when built with
JACCL=1; otherwise the flag is accepted but the init block is compiled
out.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Core distributed expert parallelism change for the CPU reference path.

- Add file-scope g_jaccl_group/g_expert_start/g_expert_end globals,
  set in engine_open, read by inner kernels (safe: single instance lock)
- Add selected[] expert IDs and distributed ownership fields to
  matvec_q2_k_accum_ctx
- In matvec_q2_k_accum_worker: skip experts not owned by this rank
  (expert_id < expert_start || expert_id >= expert_end)
- After matvec_q2_k_experts_accum_prequant returns in both
  layer_routed_moe_one and layer_routed_moe_one_prealloc:
  all_sum the partial output across ranks via JACCL RDMA

When not distributed (g_expert_start=0, g_expert_end=N_EXPERT), all
experts pass the ownership check and no all_sum is called, so single-node
behavior is unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
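
The skip-then-sum pattern in this commit can be illustrated with a toy model. This sketch is not the ds4.c code: `toy_expert_out`, `moe_accum_owned`, and `sum_partials` are hypothetical names, the embedding width is shrunk to 4, and the element-wise add stands in for what jaccl_group_all_sum() would do across ranks.

```c
#include <assert.h>
#include <string.h>

enum { SK_N_EMBD = 4 };   /* toy embedding width, not the model's */

/* Toy stand-in for one expert's output vector (hypothetical). */
static void toy_expert_out(int expert_id, float out[SK_N_EMBD]) {
    for (int i = 0; i < SK_N_EMBD; i++)
        out[i] = (float)(expert_id + 1) * (float)(i + 1);
}

/* Accumulate the routed-MoE output for one token, skipping experts
 * outside this rank's [expert_start, expert_end) range. */
void moe_accum_owned(const int *selected, const float *weights, int n_used,
                     int expert_start, int expert_end,
                     float out[SK_N_EMBD]) {
    assert(expert_start <= expert_end);
    memset(out, 0, sizeof(float) * SK_N_EMBD);
    for (int s = 0; s < n_used; s++) {
        int e = selected[s];
        if (e < expert_start || e >= expert_end)
            continue;                 /* not owned: another rank adds it */
        float tmp[SK_N_EMBD];
        toy_expert_out(e, tmp);
        for (int i = 0; i < SK_N_EMBD; i++)
            out[i] += weights[s] * tmp[i];
    }
}

/* In-process stand-in for the element-wise sum of two ranks' partials. */
void sum_partials(const float *a, const float *b, float *dst, int n) {
    for (int i = 0; i < n; i++)
        dst[i] = a[i] + b[i];
}
```

Running `moe_accum_owned` once with [0, 256) and once per rank with [0, 128) and [128, 256), then adding the two partials, gives the same vector for inputs whose sums are exact in float32. In general the two paths add terms in a different order, so results may differ in the last bits.
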
Add expert ownership filter to matvec_q2_k_batch_accum_rows_worker:
skip experts outside [g_expert_start, g_expert_end) in the inner loop.
Insert all_sum on the batch MoE output (n_tok * DS4_N_EMBD floats)
after accumulation completes in layer_routed_moe_batch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
In distributed mode, skip the IQ2_XXS gate+up matmul for experts
outside [g_expert_start, g_expert_end). Two paths modified:

- Single-token: add skip_slot[] to matvec_iq2_xxs_mid_ctx, early-exit
  in worker with mid[idx]=0 (saves both matmul and SwiGLU compute)
- Batch prefill: add expert_start/expert_end to batch_mid_ctx, skip
  non-owned active_expert entries with zero fill

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add metal_graph_mask_non_owned_experts() which reads router_selected
and router_weights via ds4_gpu_tensor_contents() (StorageModeShared,
zero-copy) and zeros weights for expert IDs outside [expert_start,
expert_end). Called after ds4_gpu_router_select_tensor, before the
fused MoE kernel. This zeroes output but does NOT skip GPU dispatch
(correctness first; shader early-exit is a future optimization).

Also define the batch variant for D3.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
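
A CPU reference for the masking step might look like the sketch below. The function name, the `int` element type for the selected IDs, and the returned count are assumptions for illustration; the real function reads the buffers through ds4_gpu_tensor_contents().

```c
#include <assert.h>

/*
 * Zero the router weight of every selected expert outside
 * [expert_start, expert_end). selected[] is left untouched, so the MoE
 * kernel is still dispatched for those slots but contributes nothing.
 * Returns how many slots were masked (a convenience added for this sketch).
 */
int mask_non_owned_experts(const int *selected, float *weights,
                           int n_used, int expert_start, int expert_end) {
    assert(expert_start <= expert_end);
    int n_masked = 0;
    for (int s = 0; s < n_used; s++) {
        if (selected[s] < expert_start || selected[s] >= expert_end) {
            weights[s] = 0.0f;
            n_masked++;
        }
    }
    return n_masked;
}
```

A batch variant can reuse the same loop by passing n_tokens * n_expert_used as `n_used` over the flattened arrays.
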
After ds4_gpu_routed_moe_one_tensor returns (synchronous — calls
ds4_gpu_finish_command_buffer internally), read routed_out via
ds4_gpu_tensor_contents and all_sum DS4_N_EMBD floats across ranks.
Placed before shared expert down-projection and HC expand.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wire the same mask + all_sum pattern for batch (prefill) Metal path:

- Call metal_graph_mask_non_owned_experts_batch after
  ds4_gpu_router_select_batch_tensor to zero weights for
  n_tokens * DS4_N_EXPERT_USED entries outside owned range
- Insert all_sum on batch_routed_out (n_tokens * DS4_N_EMBD floats)
  after ds4_gpu_routed_moe_batch_tensor, before shared expert section

All 5 binaries link clean with JACCL=1 and without.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
layer_routed_moe_one_prealloc is called from routed_moe_tokens_worker
via ds4_parallel_for, meaning multiple threads issue concurrent
jaccl_group_all_sum calls on the same rank — undefined behavior for
JACCL collectives.

Remove the all_sum from layer_routed_moe_one_prealloc and add it at
each call site after the function returns:
- layer_ffn_one_decode_scratch: single all_sum on scratch->ffn_moe
- layer_ffn_shared_batch: single all_sum on the full moe buffer
  after both the token-parallel and serial-fallback paths complete
- Two Metal validation paths: all_sum on cpu_routed

The non-prealloc layer_routed_moe_one keeps its own all_sum since it
is never called from threads.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Discovers RDMA interfaces via asmi /links on each node, builds the
JACCL_IBV_DEVICES NxN matrix, resolves coordinator LAN IP via
Tailscale (not TB5 /30 IPs which cause error 60), and SSHs to each
node to start ds4-server with --distributed.

Supports 2-N node launches, graceful Ctrl-C shutdown, and custom
port/context/binary paths.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Three-phase test: automated single-node CPU baseline capture, manual
distributed run instructions (requires 2 live nodes), and a --compare
mode that diffs the outputs for token-exact match.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Covers single-node baseline, 2-node and 4-node distributed benchmark
commands, expected overhead calculations (0.06ms/token generation,
higher for batch prefill), per-layer profiling with DS4_*_PROFILE env
vars, and a results recording table.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The router select kernel (ds4_gpu_router_select_tensor) runs in a batched
command buffer that isn't committed until the batch ends. Our expert mask
function reads router_selected/router_weights on CPU immediately after —
reading stale GPU data, producing garbage output.

Fix: ds4_gpu_end_commands() before mask, ds4_gpu_begin_commands() after.
This breaks the Metal batch pipeline (14 tok/s vs 33 tok/s single-node)
but produces correct distributed output. Optimization: move masking into
a Metal kernel to avoid the CPU round-trip entirely.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace CPU expert ownership masking with two Metal compute kernels
(kernel_expert_mask, kernel_expert_mask_batch) that run in the same
command buffer as the router select kernel. Removes the
end_commands/begin_commands pair that cost ~0.5ms per layer per token
(~21ms total across 43 layers).

- Add expert_mask.metal: decode (n_expert_used threads) and batch
  (n_tokens * n_expert_used threads) kernel variants
- Wire into ds4_metal.m runtime source assembly, pipeline init/cleanup
- Add C-callable ds4_gpu_expert_mask() and ds4_gpu_expert_mask_batch()
- Delete CPU metal_graph_mask_non_owned_experts{,_batch} from ds4.c
- Both JACCL=1 and standard builds compile clean, zero warnings

Net: +104 -41 LOC across 4 files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add early return in pair_swiglu kernels (IQ2_XXS + Q4_K) when expert
weight == 0.0f, skipping the gate+up matmul entirely. Currently does
not measurably improve single-token decode (GPU is memory-bandwidth
bound at 6 experts × 4096 dims). Expected to help with batch prefill
and larger expert dimensions where compute dominates.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…fety

Root cause: ds4_gpu_tensor_contents() was read BEFORE the Metal command
buffer completed — each rank's GPU finished at a different time, so
jaccl_group_all_sum() transmitted partially-computed MoE outputs over RDMA.
Result: different garbage on each rank despite correct all_sum semantics.

Fix: ds4_gpu_end_commands() before reading routed_out for all_sum. This
flushes the Metal command buffer so the MoE kernel results are CPU-visible
before JACCL reads them for RDMA transmission.
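
The ordering problem can be reproduced with a toy model of a deferred command buffer. Everything here is hypothetical (`toy_cmd_buffer`, `toy_encode_store`, `toy_flush`, `toy_read_for_all_sum`) and only mimics the idea that encoded GPU work is not host-visible until the buffer is flushed. On real hardware the unflushed state can also be partially written rather than untouched.

```c
#include <assert.h>
#include <stddef.h>

enum { SK_MAX_CMDS = 8 };

/* Toy command buffer: stores are recorded at encode time and only
 * executed when the buffer is flushed. */
typedef struct {
    float *dst[SK_MAX_CMDS];
    float  val[SK_MAX_CMDS];
    int    n;
} toy_cmd_buffer;

/* Record "write val to *dst"; nothing is written yet. */
void toy_encode_store(toy_cmd_buffer *cb, float *dst, float val) {
    assert(cb->n < SK_MAX_CMDS);
    cb->dst[cb->n] = dst;
    cb->val[cb->n] = val;
    cb->n++;
}

/* Execute everything recorded so far, making results host-visible. */
void toy_flush(toy_cmd_buffer *cb) {
    for (int i = 0; i < cb->n; i++)
        *cb->dst[i] = cb->val[i];
    cb->n = 0;
}

/* Value a host-side reader (such as a collective) would observe.
 * With flush_first == 0 it sees whatever memory held before. */
float toy_read_for_all_sum(toy_cmd_buffer *cb, const float *buf, int flush_first) {
    if (flush_first)
        toy_flush(cb);
    return *buf;
}
```

Reading without the flush returns the stale value, which is the analogue of each rank transmitting different garbage. Flushing first returns the encoded result.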

Defense-in-depth (secondary fixes):
- Blit-zero mid buffer before swiglu dispatch (ds4_metal.m) so stale GPU
  memory can't leak through early-exit paths
- Zero mid[] row in swiglu kernel early-exit (moe.metal IQ2_XXS + Q4_K)
  when expert_mask zeros the weight for non-owned experts
- Wire kernel_expert_compact dispatch (env-gated DS4_EXPERT_COMPACT=1)
  as an alternative to expert_mask that rewrites selected[] in-place

Verified: 2-node hub+m3u4 distributed produces correct "Paris" output
on both ranks. Single-node regression clean at 36.7 tok/s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
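
The early-exit zeroing can be sketched on the CPU as follows. The names and the toy width are hypothetical, and `toy_gate_up_swiglu` merely stands in for the real gate+up matmul and SwiGLU. One plausible reason the zeroing matters, not stated in the commit, is that a stale row may hold NaN or Inf, and multiplying those by a zero weight still yields NaN.

```c
#include <assert.h>
#include <string.h>

enum { SK_MID_DIM = 4 };  /* toy intermediate width */

/* Toy stand-in for gate+up matmul followed by SwiGLU (hypothetical). */
static void toy_gate_up_swiglu(int expert_id, float mid_row[SK_MID_DIM]) {
    for (int i = 0; i < SK_MID_DIM; i++)
        mid_row[i] = (float)(expert_id + i);
}

/*
 * For each selected slot, either compute its mid row or, when the slot
 * was masked (weight == 0), zero the row so a later down-projection
 * cannot pick up whatever the buffer held before.
 */
void fill_mid(const int *selected, const float *weights, int n_used,
              float mid[][SK_MID_DIM]) {
    assert(n_used >= 0);
    for (int s = 0; s < n_used; s++) {
        if (weights[s] == 0.0f) {
            memset(mid[s], 0, sizeof(mid[s]));  /* early exit, row cleared */
            continue;
        }
        toy_gate_up_swiglu(selected[s], mid[s]);
    }
}
```

Pre-filling `mid` with a sentinel and calling `fill_mid` with one masked slot shows the masked row coming back as zeros instead of the sentinel.
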
Cast routed_out to float16 before all_sum, cast back after. Halves the
RDMA payload from 28KB to 14KB per layer with no observed quality loss:
the downstream HC post accumulates in float32.

Benchmark: 14.5 → 16.3 tok/s generation (+12%) on 2-node hub+m3u4.
Correctness: "Paris" output verified on both ranks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
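
To make the cast concrete, here is a portable software float32/float16 conversion. It is a sketch only: the real change presumably uses the platform's native half type together with the shim's Float16 dtype, all function names below are hypothetical, and half-precision subnormals are flushed to zero for brevity.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* float32 -> IEEE 754 binary16 bits, round to nearest, ties to even.
 * Simplification: values below the normal half range become signed zero. */
static uint16_t f32_to_f16_bits(float f) {
    uint32_t x;
    memcpy(&x, &f, sizeof x);
    uint32_t sign = (x >> 16) & 0x8000u;
    uint32_t e32  = (x >> 23) & 0xFFu;
    uint32_t mant = x & 0x007FFFFFu;

    if (e32 == 0xFFu)                               /* Inf or NaN */
        return (uint16_t)(sign | 0x7C00u | (mant ? 0x0200u : 0u));

    int32_t e16 = (int32_t)e32 - 127 + 15;          /* rebias exponent */
    if (e16 >= 31) return (uint16_t)(sign | 0x7C00u);  /* overflow -> Inf */
    if (e16 <= 0)  return (uint16_t)sign;              /* underflow -> zero */

    uint32_t h   = sign | ((uint32_t)e16 << 10) | (mant >> 13);
    uint32_t rem = mant & 0x1FFFu;                  /* 13 dropped bits */
    if (rem > 0x1000u || (rem == 0x1000u && (h & 1u)))
        h++;                                        /* carry may bump exponent */
    return (uint16_t)h;
}

/* binary16 bits -> float32 (subnormal inputs are treated as zero). */
static float f16_bits_to_f32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t e16  = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x03FFu;
    uint32_t x;
    if (e16 == 0)        x = sign;
    else if (e16 == 31)  x = sign | 0x7F800000u | (mant << 13);
    else                 x = sign | ((e16 - 15 + 127) << 23) | (mant << 13);
    float f;
    memcpy(&f, &x, sizeof f);
    return f;
}

/* Pack n floats into half-width words before the collective. */
void pack_f16(const float *src, uint16_t *dst, size_t n) {
    assert(src && dst);
    for (size_t i = 0; i < n; i++) dst[i] = f32_to_f16_bits(src[i]);
}

/* Unpack after the collective, back to float32 for downstream code. */
void unpack_f16(const uint16_t *src, float *dst, size_t n) {
    assert(src && dst);
    for (size_t i = 0; i < n; i++) dst[i] = f16_bits_to_f32(src[i]);
}
```

Each element shrinks from 4 bytes to 2, which is where the 28KB to 14KB reduction comes from. Values exactly representable in half precision, such as 1.0, 0.5, and -2.0, survive the round trip unchanged, while others are rounded to 11 significant bits.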