Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
63e9504
[jaccl] add C shim header for JACCL integration
machiabeli May 26, 2026
f6be878
[jaccl] add C++ shim implementation wrapping JACCL Group API
machiabeli May 26, 2026
65bf0f6
[jaccl] integrate JACCL static lib build into Makefile (opt-in JACCL=1)
machiabeli May 26, 2026
460b9b5
[jaccl] add shim linkage test
machiabeli May 26, 2026
1ff84bf
[jaccl] add distributed state fields to ds4_engine
machiabeli May 26, 2026
f7f0062
[jaccl] init/teardown JACCL in engine lifecycle
machiabeli May 26, 2026
e307a63
[jaccl] add --distributed CLI flag
machiabeli May 26, 2026
bb05afe
[jaccl] distribute expert accumulation with all_sum (CPU path)
machiabeli May 26, 2026
1dcad2e
[jaccl] distribute batch expert accumulation (CPU prefill path)
machiabeli May 26, 2026
cef9457
[jaccl] skip non-owned expert gate/up projections
machiabeli May 26, 2026
158df1c
[jaccl] mask non-owned experts before fused Metal MoE kernel
machiabeli May 26, 2026
30d8e0b
[jaccl] insert all_sum on routed_out after fused Metal MoE kernel
machiabeli May 26, 2026
603a57e
[jaccl] distribute batch Metal MoE path
machiabeli May 26, 2026
78ab6ab
[jaccl] fix thread-safety: move all_sum out of prealloc into callers
machiabeli May 26, 2026
c28c28a
[jaccl] add distributed launch script
machiabeli May 26, 2026
b9dd40c
[jaccl] add distributed correctness test
machiabeli May 26, 2026
8fca82e
[jaccl] document distributed benchmark methodology
machiabeli May 26, 2026
61e5481
[jaccl] fix Metal distributed: flush router kernel before CPU mask read
machiabeli May 26, 2026
3868e24
[jaccl] Metal expert mask kernel — eliminates batch break overhead
machiabeli May 26, 2026
f58da3a
[jaccl] Metal shader early-exit for non-owned experts
machiabeli May 26, 2026
b5a4d21
[jaccl] fix distributed inference: GPU sync + defense-in-depth MoE sa…
machiabeli Jun 7, 2026
4bc4304
[jaccl] fp16 all_sum — halve RDMA payload for 12% generation speedup
machiabeli Jun 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
/ds4_native
/ds4_server_test
/ds4_test
/tests/test_jaccl_shim
/ds4flash.gguf
/TODO.md
/gguf/
*.o
*.dSYM/
/build/
__pycache__/
*.pyc
/misc/
Expand Down
45 changes: 41 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
CC ?= cc
CXX ?= c++
UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S),Darwin)
Expand All @@ -9,14 +10,31 @@ endif

DEBUG_FLAGS ?= -g
CFLAGS ?= -O3 -ffast-math $(DEBUG_FLAGS) $(NATIVE_CPU_FLAG) -Wall -Wextra -std=c99
CXXFLAGS ?= -O3 -ffast-math $(DEBUG_FLAGS) $(NATIVE_CPU_FLAG) -Wall -Wextra -std=c++20
OBJCFLAGS ?= -O3 -ffast-math $(DEBUG_FLAGS) $(NATIVE_CPU_FLAG) -Wall -Wextra -fobjc-arc

LDLIBS ?= -lm -pthread
METAL_SRCS := $(wildcard metal/*.metal)

# --- JACCL (opt-in: make JACCL=1) ---
# JACCL_SRC is assigned with ?= so it can be overridden from the environment
# or the command line; the remaining paths are derived from it.
JACCL_SRC ?= $(HOME)/opensource/mlx/mlx/distributed/jaccl/lib
JACCL_INCLUDE := $(JACCL_SRC)
JACCL_BUILD_DIR := build/jaccl
JACCL_LIB := $(JACCL_BUILD_DIR)/libjaccl.a

# Start with everything empty (JACCL disabled) ...
JACCL_CFLAGS :=
JACCL_OBJS :=
JACCL_LDLIBS :=

# ... and fill in the define, shim object and link libraries only when the
# build was requested with JACCL=1.
ifeq ($(JACCL),1)
JACCL_CFLAGS := -DDS4_JACCL
JACCL_OBJS := jaccl_shim.o
JACCL_LDLIBS := $(JACCL_LIB) -lc++
endif

ifeq ($(UNAME_S),Darwin)
METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal
CORE_OBJS = ds4.o ds4_metal.o
METAL_LDLIBS := $(LDLIBS) -framework Foundation -framework Metal $(JACCL_LDLIBS)
CORE_OBJS = ds4.o ds4_metal.o $(JACCL_OBJS)
CPU_CORE_OBJS = ds4_cpu.o
else
CFLAGS += -D_GNU_SOURCE -fno-finite-math-only
Expand All @@ -41,6 +59,7 @@ all: ds4 ds4-server ds4-bench ds4-eval ds4-agent
help:
@echo "DS4 build targets:"
@echo " make Build Metal ./ds4, ./ds4-server, ./ds4-bench, ./ds4-eval, and ./ds4-agent"
@echo " make JACCL=1 Build with JACCL distributed support (requires macOS SDK >= 26.2)"
@echo " make cpu Build CPU-only ./ds4, ./ds4-server, ./ds4-bench, ./ds4-eval, and ./ds4-agent"
@echo " make test Build and run tests"
@echo " make clean Remove build outputs"
Expand Down Expand Up @@ -122,7 +141,7 @@ cuda-regression: tests/cuda_long_context_smoke
endif

ds4.o: ds4.c ds4.h ds4_gpu.h
$(CC) $(CFLAGS) -c -o $@ ds4.c
$(CC) $(CFLAGS) $(JACCL_CFLAGS) -c -o $@ ds4.c

# Compile ds4_cli.c; the object is also rebuilt when ds4.h or linenoise.h change.
# $< is the first prerequisite (ds4_cli.c), $@ is the target.
ds4_cli.o: ds4_cli.c ds4.h linenoise.h
	$(CC) $(CFLAGS) -c $< -o $@
Expand Down Expand Up @@ -157,6 +176,23 @@ rax.o: rax.c rax.h rax_malloc.h
# Compile linenoise.c; the object is also rebuilt when linenoise.h changes.
# $< is the first prerequisite (linenoise.c), $@ is the target.
linenoise.o: linenoise.c linenoise.h
	$(CC) $(CFLAGS) -c $< -o $@

# --- JACCL shim + static lib ---
# C++ shim object. $(JACCL_LIB) is listed as a prerequisite, so the CMake build
# of JACCL runs before the shim is compiled against headers in $(JACCL_INCLUDE).
jaccl_shim.o: jaccl_shim.cpp jaccl_shim.h $(JACCL_LIB)
	$(CXX) $(CXXFLAGS) -I$(JACCL_INCLUDE) -c -o $@ jaccl_shim.cpp

# C test object; it only needs the shim's C header.
tests/test_jaccl_shim.o: tests/test_jaccl_shim.c jaccl_shim.h
	$(CC) $(CFLAGS) -c -o $@ tests/test_jaccl_shim.c

# The test is linked with $(CC), so the C++ runtime is added explicitly (-lc++).
tests/test_jaccl_shim: tests/test_jaccl_shim.o jaccl_shim.o $(JACCL_LIB)
	$(CC) $(CFLAGS) -o $@ tests/test_jaccl_shim.o jaccl_shim.o $(JACCL_LIB) -lc++

# Configure and build libjaccl.a with CMake.
# CMake output is kept out of the normal build log by writing it to
# configure.log / build.log under $(JACCL_BUILD_DIR); when a step fails the
# corresponding log is replayed on stderr so the real error is visible instead
# of a bare "Error 1".
$(JACCL_LIB):
	@mkdir -p $(JACCL_BUILD_DIR)
	@test -d "$(JACCL_SRC)" || { echo "ERROR: JACCL sources not found at $(JACCL_SRC) (override with JACCL_SRC=...)" >&2; exit 1; }
	cmake -S $(JACCL_SRC) -B $(JACCL_BUILD_DIR) -DCMAKE_BUILD_TYPE=Release > $(JACCL_BUILD_DIR)/configure.log 2>&1 \
		|| { cat $(JACCL_BUILD_DIR)/configure.log >&2; exit 1; }
	cmake --build $(JACCL_BUILD_DIR) --config Release -j$$(sysctl -n hw.ncpu) > $(JACCL_BUILD_DIR)/build.log 2>&1 \
		|| { cat $(JACCL_BUILD_DIR)/build.log >&2; exit 1; }
	@test -f $(JACCL_LIB) || { echo "ERROR: libjaccl.a not found after build" >&2; exit 1; }
	@echo "Built JACCL static library: $(JACCL_LIB)"

# CPU-only object: the same ds4.c source compiled with -DDS4_NO_GPU.
# $< is the first prerequisite (ds4.c), $@ is the target.
ds4_cpu.o: ds4.c ds4.h ds4_gpu.h
	$(CC) $(CFLAGS) -DDS4_NO_GPU -c $< -o $@

Expand Down Expand Up @@ -196,4 +232,5 @@ test: ds4_test ds4-eval
./ds4_test

clean:
rm -f ds4 ds4-server ds4-bench ds4-eval ds4-agent ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o
rm -f ds4 ds4-server ds4-bench ds4-eval ds4-agent ds4_cpu ds4_native ds4_server_test ds4_test *.o tests/cuda_long_context_smoke tests/cuda_long_context_smoke.o tests/test_jaccl_shim
rm -rf build/
208 changes: 208 additions & 0 deletions distributed_launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#!/usr/bin/env bash
# distributed_launch.sh — Launch ds4-server across multiple nodes via JACCL
#
# Discovers RDMA interfaces via asmi, generates JACCL env vars, and SSHs
# to each node to start ds4-server with --distributed.
#
# Usage:
# ./distributed_launch.sh --nodes hub,m3u4 --model gguf/ds4flash.gguf [--ctx 32768] [--port 8080]

set -euo pipefail

# This script uses associative arrays (`declare -A`, bash >= 4.0) and
# `wait -n` (bash >= 4.3). Fail early with a clear message when an older
# bash is first on PATH instead of dying later with a cryptic builtin error.
if (( BASH_VERSINFO[0] < 4 || (BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] < 3) )); then
  echo "Error: bash >= 4.3 is required (found $BASH_VERSION)" >&2
  exit 1
fi

# --- Defaults ---
# Overridden by the matching command-line options parsed below.
NODES=""                 # --nodes  (required)
MODEL=""                 # --model  (required)
CTX=32768                # --ctx
PORT=8080                # --port
DS4_BIN="./ds4-server"   # --bin
EXTRA_ARGS=""            # --extra

# Print the help text and exit.
#   $1  exit status to terminate with (default: 0).
# With a non-zero status the help is written to stderr, so a usage error is
# both visible to the user and detectable by callers through the exit code.
usage() {
  local status="${1:-0}"
  if (( status != 0 )); then
    exec 1>&2
  fi
  cat <<USAGE
Usage: $(basename "$0") --nodes node1,node2 --model <path> [OPTIONS]

Required:
--nodes node1,node2 Comma-separated list of nodes (hostnames or Tailscale names)
--model <path> Path to GGUF model file (must exist on all nodes)

Options:
--ctx N Context length (default: 32768)
--port P Server port (default: 8080)
--bin <path> Path to ds4-server binary (default: ./ds4-server)
--extra "<args>" Extra args to pass to ds4-server
-h, --help Show this help

Environment:
The script uses asmi (port 9090) on each node to discover RDMA interfaces.
The coordinator IP is rank 0's LAN IP (from Tailscale, NOT TB5 /30 IPs).

Examples:
# 2-node launch
$(basename "$0") --nodes hub,m3u4 --model gguf/ds4flash.gguf --ctx 32768

# 4-node launch with custom port
$(basename "$0") --nodes hub,m3u1,m3u3,m3u4 --model gguf/ds4flash.gguf --port 9090
USAGE
  exit "$status"
}

# Abort with a usage error unless the option named $1 is followed by a value.
#   $1  option name (for the message)
#   $2  number of remaining command-line words, including the option itself
# Without this check `set -u` would kill the script with "unbound variable".
require_value() {
  if (( $2 < 2 )); then
    echo "Error: $1 requires a value" >&2
    usage 1
  fi
}

# --- Parse args ---
while [[ $# -gt 0 ]]; do
  case "$1" in
    --nodes) require_value "$1" "$#"; NODES="$2"; shift 2 ;;
    --model) require_value "$1" "$#"; MODEL="$2"; shift 2 ;;
    --ctx) require_value "$1" "$#"; CTX="$2"; shift 2 ;;
    --port) require_value "$1" "$#"; PORT="$2"; shift 2 ;;
    --bin) require_value "$1" "$#"; DS4_BIN="$2"; shift 2 ;;
    --extra) require_value "$1" "$#"; EXTRA_ARGS="$2"; shift 2 ;;
    -h|--help) usage ;;
    *) echo "Unknown option: $1" >&2; usage 1 ;;
  esac
done

# Both required options must have been supplied.
if [[ -z "$NODES" || -z "$MODEL" ]]; then
  echo "Error: --nodes and --model are required" >&2
  echo "" >&2
  usage 1
fi

# --- Split nodes into array ---
# NODE_LIST holds one entry per comma-separated name from --nodes; its length
# is the world size used for the rest of the launch.
IFS=',' read -ra NODE_LIST <<< "$NODES"
WORLD_SIZE="${#NODE_LIST[@]}"

if (( WORLD_SIZE < 2 )); then
  echo "Error: distributed mode requires at least 2 nodes"
  exit 1
fi

# Summarize the launch configuration, followed by a blank separator line.
printf '%s\n' \
  "=== ds4 distributed launch ===" \
  "Nodes: ${NODE_LIST[*]}" \
  "World size: $WORLD_SIZE" \
  "Model: $MODEL" \
  "Context: $CTX" \
  "Port: $PORT" \
  ""

# --- Resolve coordinator IP ---
# The first node in the list (rank 0) is the JACCL coordinator. Its address is
# looked up over SSH on that node: `tailscale ip -4` is tried first, and if it
# fails the first address printed by `hostname -I` is used instead.
# Use the LAN IP (10.x.x.x), NOT TB5 /30 IPs which cause JACCL error 60.
COORD_NODE="${NODE_LIST[0]}"
echo "Resolving coordinator LAN IP for $COORD_NODE..."

# `|| true`: under `set -e` a failed ssh would otherwise abort the script right
# here, with its stderr discarded, before the explanatory error below is shown.
coord_raw=$(ssh -o ConnectTimeout=5 "$COORD_NODE" \
  "tailscale ip -4 2>/dev/null || hostname -I 2>/dev/null | awk '{print \$1}'" 2>/dev/null) || true

# Keep only the first whitespace-separated word of the first line so stray
# extra output cannot end up inside JACCL_COORDINATOR.
COORD_IP=""
read -r COORD_IP _ <<< "$coord_raw" || true

if [[ -z "$COORD_IP" ]]; then
  echo "Error: could not resolve LAN IP for coordinator $COORD_NODE" >&2
  exit 1
fi
echo "Coordinator: $COORD_IP ($COORD_NODE)"
echo ""

# --- Discover RDMA interfaces via asmi ---
# Build the JACCL_IBV_DEVICES JSON matrix.
# For N nodes, this is an NxN matrix where [i][j] is the RDMA interface on node i
# that connects to node j (null for self).
#
# Example for 2 nodes: [[null, "rdma_en7"], ["rdma_en5", null]]

# Print the RDMA device that the asmi /links JSON on stdin reports for a peer.
#   $1     peer name; a link matches when this is a substring of its
#          peer_hostname (or peer) field
#   stdin  JSON array of link objects
# Prints the first match's rdma_device (or interface) field, or nothing when
# no link matches, the field is empty/null, or the JSON cannot be parsed.
# The peer name is passed as an argument rather than spliced into the Python
# source, so a host name containing quotes cannot alter the program.
rdma_device_for_peer() {
  python3 -c '
import json, sys

wanted = sys.argv[1]
for link in json.load(sys.stdin):
    peer = link.get("peer_hostname", link.get("peer", "")) or ""
    if wanted in peer:
        print(link.get("rdma_device", link.get("interface", "")) or "")
        break
' "$1" 2>/dev/null || true
}

echo "Discovering RDMA interfaces via asmi..."

declare -A RDMA_IFACE # RDMA_IFACE[i,j] = interface name on node i for link to node j

for (( i=0; i<WORLD_SIZE; i++ )); do
  node="${NODE_LIST[$i]}"
  echo " Querying $node:9090/links..."

  # asmi /links returns JSON array of RDMA link objects with peer_hostname and rdma_device.
  # Any curl failure is treated as "no links"; the connect timeout keeps an
  # unreachable node from stalling the launch indefinitely.
  links_json=$(curl -sf --connect-timeout 5 "http://${node}:9090/links" 2>/dev/null || echo "[]")

  if [[ "$links_json" == "[]" ]]; then
    echo " Warning: no RDMA links found on $node (asmi may not be running)"
  fi

  for (( j=0; j<WORLD_SIZE; j++ )); do
    # A node has no link to itself.
    if [[ $i -eq $j ]]; then
      RDMA_IFACE[$i,$j]="null"
      continue
    fi

    peer="${NODE_LIST[$j]}"
    # Extract the RDMA device name for the link to this peer
    iface=$(rdma_device_for_peer "$peer" <<< "$links_json")

    if [[ -n "$iface" ]]; then
      # Stored with quotes so the value can be dropped straight into JSON.
      RDMA_IFACE[$i,$j]="\"$iface\""
    else
      echo " Warning: no RDMA interface found on $node for peer $peer"
      RDMA_IFACE[$i,$j]="null"
    fi
  done
done

# --- Build JACCL_IBV_DEVICES JSON ---
# Serialize RDMA_IFACE into a JSON array of arrays. Entries are already valid
# JSON tokens (null or a quoted name), so they only need to be joined with
# ", " and wrapped in brackets.
json_rows=()
for (( r=0; r<WORLD_SIZE; r++ )); do
  row_cells=()
  for (( c=0; c<WORLD_SIZE; c++ )); do
    row_cells+=("${RDMA_IFACE[$r,$c]}")
  done
  # printf repeats the format per cell, giving ", a, b"; drop the leading ", ".
  joined_cells=$(printf ', %s' "${row_cells[@]}")
  json_rows+=("[${joined_cells:2}]")
done
joined_rows=$(printf ', %s' "${json_rows[@]}")
IBV_JSON="[${joined_rows:2}]"

printf '%s\n' "" "JACCL_IBV_DEVICES=$IBV_JSON" ""

# --- Launch on each node ---
# One backgrounded ssh session is started per rank; its local PID is recorded
# in PIDS so the sessions can be waited on and killed later.
#
# The remote commands are sent as a script on ssh's stdin (`bash -s`) instead
# of being nested inside a quoted `bash -c` argument. Every value is quoted
# with printf %q, so paths or JSON containing spaces or quotes reach the
# remote side intact.
PIDS=()
for (( rank=0; rank<WORLD_SIZE; rank++ )); do
  node="${NODE_LIST[$rank]}"
  echo "Launching rank $rank on $node..."

  remote_script=$(
    printf 'export JACCL_RANK=%q\n' "$rank"
    printf 'export JACCL_WORLD_SIZE=%q\n' "$WORLD_SIZE"
    printf 'export JACCL_COORDINATOR=%q\n' "$COORD_IP"
    printf 'export JACCL_IBV_DEVICES=%q\n' "$IBV_JSON"
    printf 'ds4_bin=%q\n' "$DS4_BIN"
    printf 'ds4_port=%q\n' "$PORT"
    printf 'ds4_ctx=%q\n' "$CTX"
    printf 'ds4_model=%q\n' "$MODEL"
    cat <<'REMOTE'
echo "ds4: starting rank $JACCL_RANK/$JACCL_WORLD_SIZE on $(hostname)"
echo " JACCL_COORDINATOR=$JACCL_COORDINATOR"
echo " JACCL_IBV_DEVICES=$JACCL_IBV_DEVICES"
# Run from the binary's directory when the path names one. After a successful
# cd the binary is addressed as ./name, so a relative path such as
# build/ds4-server is not resolved a second time against the new directory.
# A bare command name (no slash) is left for PATH lookup.
case "$ds4_bin" in
  */*)
    if cd "$(dirname "$ds4_bin")" 2>/dev/null; then
      ds4_bin="./$(basename "$ds4_bin")"
    fi
    ;;
esac
REMOTE
    # The --extra string is inserted verbatim so the remote shell splits and
    # unquotes it into separate ds4-server arguments.
    printf 'exec "$ds4_bin" --distributed --metal --host 0.0.0.0 --port "$ds4_port" --ctx "$ds4_ctx" --model "$ds4_model" %s\n' "$EXTRA_ARGS"
  )

  ssh -o ConnectTimeout=10 "$node" bash -s <<< "$remote_script" &
  PIDS+=($!)
done

echo ""
echo "=== All $WORLD_SIZE ranks launched ==="
echo "PIDs: ${PIDS[*]}"
echo "Press Ctrl-C to stop all nodes"
echo ""

# --- Wait for any child to exit ---
# Kill every launched ssh session and wait until all of them have finished.
cleanup() {
  # Disarm the traps first so a second signal, or the normal exit path after a
  # signal, cannot run the shutdown sequence twice.
  trap - INT TERM
  echo ""
  echo "Shutting down all ranks..."
  for pid in "${PIDS[@]}"; do
    # A rank that already exited makes kill fail; that is not an error here.
    kill "$pid" 2>/dev/null || true
  done
  wait || true
  echo "All ranks stopped."
}

# On Ctrl-C / TERM: shut down and exit with the conventional 128+signal status
# instead of falling through to the "a rank exited" path below.
trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM

# Block until the first rank terminates and remember its exit status, so a
# crashed rank makes the whole launch report failure.
rank_status=0
wait -n "${PIDS[@]}" 2>/dev/null || rank_status=$?
echo "A rank exited. Shutting down remaining..."
cleanup
exit "$rank_status"
Loading