Add JACCL expert-parallel distributed inference #351
Open
machiabeli wants to merge 22 commits into
C99-compatible header exposing the JACCL Group API (init, all_sum, barrier, send, recv) as extern "C" functions. Dtype enum values match jaccl::Dtype from upstream group.h (Float32=11, Float16=9). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Thin C++ wrapper (~45 LOC) that bridges jaccl_shim.h to the JACCL standalone lib. Heap-allocates a shared_ptr<Group> behind an opaque void* handle. Verified: compiles with -std=c++20 against upstream jaccl headers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- `make JACCL=1` builds JACCL via CMake into build/jaccl/, compiles jaccl_shim.cpp, and links libjaccl.a + -lc++ into all 5 binaries
- `make` (without JACCL=1) is unchanged -- no new deps, no regressions
- JACCL_SRC defaults to ~/opensource/mlx/mlx/distributed/jaccl/lib
- DS4_JACCL define passed to ds4.c for conditional compilation
- build/ added to .gitignore
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Single-process test that verifies jaccl_is_available(), attempts non-strict init (NULL without env vars), and exercises the init/rank/size/free lifecycle if RDMA is present. Validates all shim symbols link correctly against libjaccl.a. Build: make JACCL=1 tests/test_jaccl_shim Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add bool distributed to ds4_engine_options, conditional jaccl_shim.h include under DS4_JACCL, and jaccl_group/world_size/rank/expert_start/ expert_end fields to struct ds4_engine. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
In ds4_engine_open: after model load, init JACCL group from env vars when opt->distributed is set. Computes expert_start/expert_end for rank-based expert partitioning. Falls back to single-node defaults (world_size=1, all experts owned) when not distributed. In ds4_engine_close: free JACCL group before other teardown. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
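The rank-based partitioning described above can be sketched as follows. This is a minimal illustration, not the PR's code: the helper name `expert_range_for_rank`, the `expert_range` struct, and the rule that leftover experts go to the lowest ranks are all assumptions. Only the half-open `[expert_start, expert_end)` convention and the single-node default (world_size=1, all experts owned) come from the commit message.

```c
#include <assert.h>

/* Hypothetical type: half-open range [start, end) of expert IDs. */
typedef struct {
    int start;
    int end;
} expert_range;

/* Hypothetical helper: split n_expert experts into world_size contiguous
 * ranges. When n_expert is not divisible by world_size, the first
 * (n_expert % world_size) ranks each take one extra expert (assumption). */
static expert_range expert_range_for_rank(int n_expert, int world_size, int rank) {
    assert(n_expert >= 0 && world_size > 0);
    assert(rank >= 0 && rank < world_size);

    int base = n_expert / world_size;
    int rem  = n_expert % world_size;

    expert_range r;
    r.start = rank * base + (rank < rem ? rank : rem);
    r.end   = r.start + base + (rank < rem ? 1 : 0);
    return r;
}
```

With the 256 experts mentioned in the PR summary and two ranks, this gives rank 0 the range [0, 128) and rank 1 the range [128, 256). With world_size=1 the single rank owns [0, 256), which matches the non-distributed fallback.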
Wire --distributed flag through ds4 CLI and ds4-server arg parsing. Sets engine_opts.distributed = true. Only effective when built with JACCL=1; otherwise the flag is accepted but the init block is compiled out. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Core distributed expert parallelism change for the CPU reference path.
- Add file-scope g_jaccl_group/g_expert_start/g_expert_end globals, set in engine_open, read by inner kernels (safe: single instance lock)
- Add selected[] expert IDs and distributed ownership fields to matvec_q2_k_accum_ctx
- In matvec_q2_k_accum_worker: skip experts not owned by this rank (expert_id < expert_start || expert_id >= expert_end)
- After matvec_q2_k_experts_accum_prequant returns in both layer_routed_moe_one and layer_routed_moe_one_prealloc: all_sum the partial output across ranks via JACCL RDMA
When not distributed (g_expert_start=0, g_expert_end=N_EXPERT), all experts pass ownership check and no all_sum is called -- single-node behavior is unchanged.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add expert ownership filter to matvec_q2_k_batch_accum_rows_worker: skip experts outside [g_expert_start, g_expert_end) in the inner loop. Insert all_sum on the batch MoE output (n_tok * DS4_N_EMBD floats) after accumulation completes in layer_routed_moe_batch. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
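The skip-then-sum pattern in the two commits above can be illustrated with a small in-process sketch. Everything here is hypothetical and simplified: `accum_owned_experts` stands in for the quantized accumulate workers, `expert_out` holds already-computed float expert outputs, and `all_sum_two_ranks` simulates a two-rank all-reduce on local buffers instead of calling JACCL. The point is only that per-rank partial sums over disjoint ownership ranges add up to the single-node result.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for the accumulate worker: add weighted expert
 * outputs into out[], skipping experts outside [expert_start, expert_end). */
static void accum_owned_experts(float *out, size_t n_embd,
                                const int *selected, const float *weights,
                                size_t n_used,
                                const float *expert_out, /* n_used x n_embd, row-major */
                                int expert_start, int expert_end) {
    assert(expert_start <= expert_end);
    for (size_t s = 0; s < n_used; s++) {
        int id = selected[s];
        if (id < expert_start || id >= expert_end) {
            continue; /* not owned by this rank */
        }
        for (size_t i = 0; i < n_embd; i++) {
            out[i] += weights[s] * expert_out[s * n_embd + i];
        }
    }
}

/* Simulated all-reduce (sum) for two ranks: afterwards both buffers
 * hold the element-wise sum, as each rank would after a real all_sum. */
static void all_sum_two_ranks(float *rank0, float *rank1, size_t n) {
    for (size_t i = 0; i < n; i++) {
        float sum = rank0[i] + rank1[i];
        rank0[i] = sum;
        rank1[i] = sum;
    }
}
```

Because every selected expert falls in exactly one rank's range, each weighted term is added once across the group, so the reduced buffer equals what a single node owning all experts would compute (up to floating-point summation order).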
In distributed mode, skip the IQ2_XXS gate+up matmul for experts outside [g_expert_start, g_expert_end). Two paths modified:
- Single-token: add skip_slot[] to matvec_iq2_xxs_mid_ctx, early-exit in worker with mid[idx]=0 (saves both matmul and SwiGLU compute)
- Batch prefill: add expert_start/expert_end to batch_mid_ctx, skip non-owned active_expert entries with zero fill
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add metal_graph_mask_non_owned_experts() which reads router_selected and router_weights via ds4_gpu_tensor_contents() (StorageModeShared, zero-copy) and zeros weights for expert IDs outside [expert_start, expert_end). Called after ds4_gpu_router_select_tensor, before the fused MoE kernel. This zeroes output but does NOT skip GPU dispatch (correctness-first, shader early-exit is future optimization). Also define the batch variant for D3. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
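As a rough picture of the masking step, the sketch below zeroes router weights for non-owned experts on plain CPU arrays. The function name `mask_non_owned_weights` and its signature are hypothetical. The real function in this commit reads the tensors through ds4_gpu_tensor_contents(), and its exact signature is not shown in the commit message.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical CPU-side sketch: zero the routing weight of every selected
 * expert whose ID lies outside [expert_start, expert_end). The selected[]
 * IDs are left untouched, so the expert is still dispatched but then
 * contributes nothing to the weighted sum. */
static void mask_non_owned_weights(const int *selected, float *weights,
                                   size_t n_entries,
                                   int expert_start, int expert_end) {
    assert(expert_start <= expert_end);
    for (size_t i = 0; i < n_entries; i++) {
        if (selected[i] < expert_start || selected[i] >= expert_end) {
            weights[i] = 0.0f;
        }
    }
}
```

For the decode path n_entries would be the number of experts used per token. For the batch variant it would be that count multiplied by the number of tokens, with the same loop body.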
After ds4_gpu_routed_moe_one_tensor returns (synchronous — calls ds4_gpu_finish_command_buffer internally), read routed_out via ds4_gpu_tensor_contents and all_sum DS4_N_EMBD floats across ranks. Placed before shared expert down-projection and HC expand. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wire the same mask + all_sum pattern for batch (prefill) Metal path:
- Call metal_graph_mask_non_owned_experts_batch after ds4_gpu_router_select_batch_tensor to zero weights for n_tokens * DS4_N_EXPERT_USED entries outside owned range
- Insert all_sum on batch_routed_out (n_tokens * DS4_N_EMBD floats) after ds4_gpu_routed_moe_batch_tensor, before shared expert section
All 5 binaries link clean with JACCL=1 and without.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
layer_routed_moe_one_prealloc is called from routed_moe_tokens_worker via ds4_parallel_for, meaning multiple threads issue concurrent jaccl_group_all_sum calls on the same rank, which is undefined behavior for JACCL collectives. Remove the all_sum from layer_routed_moe_one_prealloc and add it at each call site after the function returns:
- layer_ffn_one_decode_scratch: single all_sum on scratch->ffn_moe
- layer_ffn_shared_batch: single all_sum on the full moe buffer after both the token-parallel and serial-fallback paths complete
- Two Metal validation paths: all_sum on cpu_routed
The non-prealloc layer_routed_moe_one keeps its own all_sum since it is never called from threads.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Discovers RDMA interfaces via asmi /links on each node, builds the JACCL_IBV_DEVICES NxN matrix, resolves coordinator LAN IP via Tailscale (not TB5 /30 IPs which cause error 60), and SSHs to each node to start ds4-server with --distributed. Supports 2-N node launches, graceful Ctrl-C shutdown, and custom port/context/binary paths. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Three-phase test: automated single-node CPU baseline capture, manual distributed run instructions (requires 2 live nodes), and a --compare mode that diffs the outputs for token-exact match. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Covers single-node baseline, 2-node and 4-node distributed benchmark commands, expected overhead calculations (0.06ms/token generation, higher for batch prefill), per-layer profiling with DS4_*_PROFILE env vars, and a results recording table. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The router select kernel (ds4_gpu_router_select_tensor) runs in a batched command buffer that isn't committed until the batch ends. Our expert mask function reads router_selected/router_weights on CPU immediately after — reading stale GPU data, producing garbage output. Fix: ds4_gpu_end_commands() before mask, ds4_gpu_begin_commands() after. This breaks the Metal batch pipeline (14 tok/s vs 33 tok/s single-node) but produces correct distributed output. Optimization: move masking into a Metal kernel to avoid the CPU round-trip entirely. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace CPU expert ownership masking with two Metal compute kernels
(kernel_expert_mask, kernel_expert_mask_batch) that run in the same
command buffer as the router select kernel. Removes the
end_commands/begin_commands pair that cost ~0.5ms per layer per token
(~21ms total across 43 layers).
- Add expert_mask.metal: decode (n_expert_used threads) and batch
(n_tokens * n_expert_used threads) kernel variants
- Wire into ds4_metal.m runtime source assembly, pipeline init/cleanup
- Add C-callable ds4_gpu_expert_mask() and ds4_gpu_expert_mask_batch()
- Delete CPU metal_graph_mask_non_owned_experts{,_batch} from ds4.c
- Both JACCL=1 and standard builds compile clean, zero warnings
Net: +104 -41 LOC across 4 files.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add early return in pair_swiglu kernels (IQ2_XXS + Q4_K) when expert weight == 0.0f, skipping the gate+up matmul entirely. Currently does not measurably improve single-token decode (GPU is memory-bandwidth bound at 6 experts × 4096 dims). Expected to help with batch prefill and larger expert dimensions where compute dominates. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…fety
Root cause: ds4_gpu_tensor_contents() was read BEFORE the Metal command buffer completed. Each rank's GPU finished at a different time, so jaccl_group_all_sum() transmitted partially-computed MoE outputs over RDMA. Result: different garbage on each rank despite correct all_sum semantics.
Fix: ds4_gpu_end_commands() before reading routed_out for all_sum. This flushes the Metal command buffer so the MoE kernel results are CPU-visible before JACCL reads them for RDMA transmission.
Defense-in-depth (secondary fixes):
- Blit-zero mid buffer before swiglu dispatch (ds4_metal.m) so stale GPU memory can't leak through early-exit paths
- Zero mid[] row in swiglu kernel early-exit (moe.metal IQ2_XXS + Q4_K) when expert_mask zeros the weight for non-owned experts
- Wire kernel_expert_compact dispatch (env-gated DS4_EXPERT_COMPACT=1) as an alternative to expert_mask that rewrites selected[] in-place
Verified: 2-node hub+m3u4 distributed produces correct "Paris" output on both ranks. Single-node regression clean at 36.7 tok/s.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Cast routed_out to float16 before all_sum, cast back after. Halves the RDMA payload from 28KB to 14KB per layer with no quality loss — the downstream HC post accumulates in float32. Benchmark: 14.5 → 16.3 tok/s generation (+12%) on 2-node hub+m3u4. Correctness: "Paris" output verified on both ranks. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
Expert-parallel distributed inference for DeepSeek V4 Flash/Pro via JACCL
(Apple's RDMA collective library for Thunderbolt 5). Each rank owns a subset
of the 256 MoE experts; after each layer's routed MoE forward pass,
jaccl_group_all_sum() merges partial results across ranks.
Opt-in: `make JACCL=1`. Requires JACCL headers from ml-explore/mlx.
What changed (1,559 lines across 17 files)
Core integration (ds4.c, ds4.h, ds4_cli.c, ds4_server.c):
`--distributed` CLI flag enables JACCL init at engine startup
C shim (jaccl_shim.h, jaccl_shim.cpp):
Metal kernels (metal/expert_mask.metal, metal/moe.metal, ds4_metal.m):
Build (Makefile):
`JACCL=1` opt-in flag; `JACCL_SRC` defaults to mlx submodule path
Tools:
Key design decisions
Expert-parallel, not layer-parallel. All ranks run all 61 layers; only MoE
experts are split. The all_sum payload is ~28KB per layer (n_embd x float32),
which completes in <0.3ms at 5.5 GB/s RDMA bandwidth.
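The payload and bandwidth figures above can be checked with a short back-of-the-envelope calculation. The sketch below is illustrative only: the helper names are hypothetical, and n_embd = 7168 is an assumption inferred from 28KB / 4 bytes per float32 (taking 1KB = 1024 bytes), not a value stated in this PR. It computes raw wire time for the payload and ignores collective protocol overhead and per-message latency.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: bytes in one all_sum payload of n_elems elements. */
static size_t payload_bytes(size_t n_elems, size_t elem_size) {
    return n_elems * elem_size;
}

/* Hypothetical helper: raw transfer time in microseconds for a payload at
 * a given bandwidth (GB/s, with 1 GB = 1e9 bytes). This is a lower bound
 * on wire time only; it excludes latency and reduction overhead. */
static double wire_time_us(double bytes, double bandwidth_gb_per_s) {
    assert(bandwidth_gb_per_s > 0.0);
    return bytes / (bandwidth_gb_per_s * 1e9) * 1e6;
}
```

Under these assumptions, `payload_bytes(7168, 4)` is 28672 bytes, and `wire_time_us(28672, 5.5)` comes to roughly 5 microseconds, so the quoted <0.3ms bound leaves a wide margin over raw transfer time. The float16 variant from the later commit halves the payload to 14336 bytes and the raw wire time with it.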
GPU sync before RDMA read. Metal GPU kernels are asynchronous. Without
ds4_gpu_end_commands() before reading routed_out, each rank transmits
partially-computed data. This was the root cause of the garbage output bug.
Defense-in-depth for expert masking. Three layers: (A) blit-zero the mid
buffer, (B) zero mid[] in kernel early-exit, (C) expert_compact rewrites
selected[] to owned-only. Any one suffices; all three prevent regressions.
Env var device matrix. JACCL_IBV_DEVICES points to a JSON file (not
inline JSON) containing the NxN RDMA device connectivity matrix.
Test plan