add guide for SGLang-Jax on TPUs #103

# Serve Qwen3-MoE with SGLang-Jax on TPU

SGLang-Jax supports multiple Mixture-of-Experts (MoE) models from the Qwen3 family with varying hardware requirements:

- **[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)**: runs on 4 TPU v6e chips
- **[Qwen3-Coder-480B-A35B-Instruct](https://huggingface.co/Qwen/Qwen3-Coder-480B-A35B-Instruct)**: requires 64 TPU v6e chips (16 nodes × 4 chips)
- Other Qwen3 MoE variants with different scale requirements

**This tutorial focuses on deploying Qwen3-Coder-480B**, the largest model, which requires a multi-node distributed setup. For smaller models such as Qwen3-30B-A3B, follow the same steps with adjusted node counts and parallelism settings.

## Hardware Requirements

Running Qwen3-Coder-480B requires a multi-node TPU cluster:

- **Total nodes**: 16
- **TPU chips per node**: 4 (v6e)
- **Total TPU chips**: 64
- **Tensor Parallelism (TP)**: 32 (for non-MoE layers)
- **Expert Tensor Parallelism (ETP)**: 64 (for MoE experts)
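
The arithmetic behind this layout can be sanity-checked in a few lines. Treat this as an illustration of the topology, not sglang-jax code; in particular, reading `total_chips // tp_size` as the number of independent TP groups is an assumption:

```python
# Cluster topology for Qwen3-Coder-480B on TPU v6e.
nodes = 16
chips_per_node = 4
total_chips = nodes * chips_per_node  # 64

tp_size = 32   # tensor parallelism for non-MoE layers
etp_size = 64  # expert tensor parallelism for MoE experts

# ETP spans every chip in the cluster.
assert etp_size == total_chips
# TP groups must tile the chip count evenly.
assert total_chips % tp_size == 0
print(f"{total_chips} chips, {total_chips // tp_size} TP groups")  # -> 64 chips, 2 TP groups
```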

## Installation

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```

## Launch Distributed Server

### Preparation

1. **Get the node 0 IP address** (coordinator):

   ```bash
   # On node 0
   hostname -I | awk '{print $1}'
   ```

   Save this IP as `NODE_RANK_0_IP`.

2. **Download the model** (use shared storage, or pre-download on all nodes):

   ```bash
   export HF_TOKEN=your_huggingface_token
   huggingface-cli download Qwen/Qwen3-Coder-480B-A35B-Instruct --local-dir /path/to/model
   ```

### Launch Command

Run the following command **on each node**, replacing:

- `<NODE_RANK_0_IP>`: IP address of node 0
- `<NODE_RANK>`: rank of the current node (0-15)
- `<QWEN3_CODER_480B_MODEL_PATH>`: path to the downloaded model

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
  --model-path <QWEN3_CODER_480B_MODEL_PATH> \
  --trust-remote-code \
  --dist-init-addr=<NODE_RANK_0_IP>:10011 \
  --nnodes=16 \
  --tp-size=32 \
  --device=tpu \
  --random-seed=3 \
  --mem-fraction-static=0.8 \
  --chunked-prefill-size=2048 \
  --download-dir=/dev/shm \
  --dtype=bfloat16 \
  --max-running-requests=128 \
  --skip-server-warmup \
  --page-size=128 \
  --tool-call-parser=qwen3_coder \
  --node-rank=<NODE_RANK>
```

### Example for Specific Nodes

**Node 0 (coordinator):**

```bash
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
python3 -u -m sgl_jax.launch_server \
  --model-path /path/to/Qwen3-Coder-480B \
  --trust-remote-code \
  --dist-init-addr=10.0.0.2:10011 \
  --nnodes=16 \
  --tp-size=32 \
  --device=tpu \
  --random-seed=3 \
  --mem-fraction-static=0.8 \
  --chunked-prefill-size=2048 \
  --download-dir=/dev/shm \
  --dtype=bfloat16 \
  --max-running-requests=128 \
  --skip-server-warmup \
  --page-size=128 \
  --tool-call-parser=qwen3_coder \
  --node-rank=0
```

Every other node runs the identical command, changing only `--node-rank`: node 1 uses `--node-rank=1`, node 2 uses `--node-rank=2`, and so on up to `--node-rank=15` on node 15.
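
Editing `--node-rank` by hand on 16 machines is tedious and error-prone. One option is to generate each node's command programmatically; the sketch below is illustrative (`build_launch_command` is a hypothetical helper and shows only a subset of the flags). Alternatively, the Cloud TPU docs describe running one command on every worker via `gcloud compute tpus tpu-vm ssh --worker=all --command=...`, though `--node-rank` would still need to vary per worker.

```python
# Hypothetical helper: build the sgl_jax.launch_server argv for one node.
# Coordinator IP and model path are placeholders for your cluster.

def build_launch_command(node_rank: int,
                         coordinator_ip: str,
                         model_path: str,
                         nnodes: int = 16) -> list[str]:
    """Return the launch argv for the given node rank (subset of flags)."""
    return [
        "python3", "-u", "-m", "sgl_jax.launch_server",
        "--model-path", model_path,
        "--trust-remote-code",
        f"--dist-init-addr={coordinator_ip}:10011",
        f"--nnodes={nnodes}",
        "--tp-size=32",
        "--device=tpu",
        f"--node-rank={node_rank}",
    ]

if __name__ == "__main__":
    # Print the command each node should run (ranks 0-15).
    for rank in range(16):
        print(" ".join(build_launch_command(rank, "10.0.0.2",
                                            "/path/to/Qwen3-Coder-480B")))
```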

## Configuration Parameters

### Distributed Setup

- `--nnodes`: number of nodes in the cluster (16)
- `--node-rank`: rank of the current node (0-15)
- `--dist-init-addr`: address and port of the coordinator node (node 0)

### Model Parallelism

- `--tp-size`: tensor parallelism size for non-MoE layers (32)
- **ETP**: expert tensor parallelism, automatically configured to 64 based on the total chip count

### Memory and Performance

- `--mem-fraction-static`: fraction of memory allocated for static buffers (0.8)
- `--chunked-prefill-size`: prefill chunk size for batching (2048)
- `--max-running-requests`: maximum number of concurrent requests (128)
- `--page-size`: page size for memory management (128)

### Model-Specific

- `--tool-call-parser`: parser for tool calls; set to `qwen3_coder` for this model
- `--dtype`: data type for inference (bfloat16)
- `--random-seed`: random seed for reproducibility (3)

## Verification

Once all nodes are running, the server is accessible via the coordinator node (node 0). You can test it with:

```bash
curl http://<NODE_RANK_0_IP>:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-Coder-480B",
    "prompt": "def fibonacci(n):",
    "max_tokens": 200,
    "temperature": 0
  }'
```
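
The same check can be scripted. The sketch below builds the request body and posts it with the Python standard library; the host, port, and model name are placeholders taken from the curl example above, not values guaranteed by sglang-jax:

```python
import json
import urllib.request

SERVER_URL = "http://NODE_RANK_0_IP:8000"  # placeholder: substitute node 0's IP

def completion_payload(prompt: str,
                       model: str = "Qwen/Qwen3-Coder-480B",
                       max_tokens: int = 200) -> dict:
    """JSON body for an OpenAI-compatible /v1/completions request."""
    return {"model": model, "prompt": prompt,
            "max_tokens": max_tokens, "temperature": 0}

def send_completion(prompt: str) -> dict:
    """POST the payload to the coordinator and return the parsed response."""
    req = urllib.request.Request(
        SERVER_URL + "/v1/completions",
        data=json.dumps(completion_payload(prompt)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# Example (requires the server to be up):
# send_completion("def fibonacci(n):")
```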

# Serve Qwen3 with SGLang-Jax on TPU

This guide demonstrates how to serve [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B) and [Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) using SGLang-Jax on TPU.

## Provision TPU Resources

For **Qwen3-8B**, a single v6e chip is sufficient. For **Qwen3-32B**, use 4 chips or more.

### Option 1: Using gcloud CLI

Install and configure the gcloud CLI by following the [official installation guide](https://cloud.google.com/sdk/docs/install).

**Create TPU VM:**

```bash
gcloud compute tpus tpu-vm create sgl-jax \
  --zone=us-east5-a \
  --version=v2-alpha-tpuv6e \
  --accelerator-type=v6e-4
```

**Connect to TPU VM:**

```bash
gcloud compute tpus tpu-vm ssh sgl-jax --zone us-east5-a
```

### Option 2: Using SkyPilot (Recommended for Development)

SkyPilot simplifies TPU provisioning and offers automatic cost optimization, instance management, and environment setup.

**Prerequisites:**

- [Install SkyPilot](https://docs.skypilot.co/en/latest/getting-started/installation.html)
- [Configure GCP credentials](https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp)

**Create configuration file `sgl-jax.yaml`:**

```yaml
resources:
  accelerators: tpu-v6e-4
  accelerator_args:
    tpu_vm: True
    runtime_version: v2-alpha-tpuv6e

setup: |
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install sglang-jax
```

**Launch TPU cluster:**

```bash
sky launch sgl-jax.yaml \
  --cluster=sgl-jax-skypilot-v6e-4 \
  --infra=gcp \
  -i 30 \
  --down \
  -y \
  --use-spot
```

This command will:

- Find the lowest-cost spot instance across regions
- Automatically shut down the cluster after 30 minutes of idleness
- Set up the SGLang-Jax environment automatically

**Connect to cluster:**

```bash
ssh sgl-jax-skypilot-v6e-4
```

> **Note:** SkyPilot manages the external IP automatically, so you don't need to track it manually.

## Installation

> **Note:** If you used SkyPilot to provision resources, the environment is already set up. Skip to the [Launch Server](#launch-server) section.

For gcloud CLI users, install SGLang-Jax using one of the following methods:

### Option 1: Install from PyPI

```bash
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
```

### Option 2: Install from Source

```bash
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```

## Launch Server

Set the model name and start the SGLang-Jax server:

```bash
export MODEL_NAME="Qwen/Qwen3-8B"  # or "Qwen/Qwen3-32B"

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
uv run python -u -m sgl_jax.launch_server \
  --model-path ${MODEL_NAME} \
  --trust-remote-code \
  --tp-size=4 \
  --device=tpu \
  --mem-fraction-static=0.8 \
  --chunked-prefill-size=2048 \
  --download-dir=/tmp \
  --dtype=bfloat16 \
  --max-running-requests=256 \
  --skip-server-warmup \
  --page-size=128
```
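
Startup includes weight loading and JIT compilation, so the server's port may not accept connections right away. A small generic helper (not part of sglang-jax) can poll until the port is reachable before you send requests; pass whatever port you configured for the server:

```python
import socket
import time

def wait_for_server(host: str, port: int,
                    timeout: float = 600.0, interval: float = 2.0) -> bool:
    """Poll until a TCP connection to (host, port) succeeds, or time out."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True  # port is accepting connections
        except OSError:
            time.sleep(interval)  # not up yet; retry
    return False

# Example: wait_for_server("127.0.0.1", SERVER_PORT) with the port the
# server was launched on, before running benchmarks or sending requests.
```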

### Configuration Parameters

- `--tp-size`: tensor parallelism size; should equal the number of TPU chips in your instance
- `--mem-fraction-static`: fraction of memory allocated for static buffers
- `--chunked-prefill-size`: size of prefill chunks for batching
- `--max-running-requests`: maximum number of concurrent requests

## Run Benchmark

Test serving performance with different workload configurations:

```bash
uv run python -m sgl_jax.bench_serving \
  --backend sgl-jax \
  --dataset-name random \
  --num-prompts 256 \
  --random-input 4096 \
  --random-output 1024 \
  --max-concurrency 64 \
  --random-range-ratio 1 \
  --warmup-requests 0
```

### Benchmark Parameters

- `--backend`: backend engine (use `sgl-jax`)
- `--random-input`: input sequence length (e.g., 1024, 4096, 8192)
- `--random-output`: output sequence length (e.g., 1, 1024)
- `--max-concurrency`: maximum number of concurrent requests (e.g., 8, 16, 32, 64, 128, 256)
- `--num-prompts`: total number of prompts to send

You can test various combinations of input/output lengths and concurrency levels to evaluate throughput and latency characteristics.
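
Such a sweep can be scripted. The sketch below is illustrative, not part of sglang-jax: `bench_command` is a hypothetical helper that mirrors the flags above, and the commented-out `subprocess.run` line shows where each run would actually execute:

```python
import subprocess

def bench_command(concurrency: int,
                  input_len: int = 4096,
                  output_len: int = 1024,
                  num_prompts: int = 256) -> list[str]:
    """Build the sgl_jax.bench_serving argv for one configuration."""
    return [
        "python", "-m", "sgl_jax.bench_serving",
        "--backend", "sgl-jax",
        "--dataset-name", "random",
        "--num-prompts", str(num_prompts),
        "--random-input", str(input_len),
        "--random-output", str(output_len),
        "--max-concurrency", str(concurrency),
        "--random-range-ratio", "1",
        "--warmup-requests", "0",
    ]

if __name__ == "__main__":
    for conc in (8, 32, 128):
        print(" ".join(bench_command(conc)))
        # subprocess.run(bench_command(conc), check=True)  # uncomment to execute
```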

# Serve SGLang-Jax on Trillium TPUs (v6e)

This repository provides examples demonstrating how to deploy and serve SGLang-Jax on Trillium TPUs using GCE (Google Compute Engine) for a select set of models:

- [Qwen3-8B / Qwen3-32B](./Qwen3/README.md)
- [Qwen3-30B-A3B / Qwen3-Coder-480B-A35B-Instruct](./Qwen3-MoE/README.md)

The SGLang-Jax project continues to add support for new models; for the current model list, see https://github.com/sgl-project/sglang-jax/tree/main/python/sgl_jax/srt/models.