From 08ce80bbc1853823520b10ae482f7e7d5553e449 Mon Sep 17 00:00:00 2001 From: SamitHuang <285365963@qq.com> Date: Mon, 2 Mar 2026 16:32:32 +0800 Subject: [PATCH 01/10] temp save rfc Signed-off-by: SamitHuang <285365963@qq.com> --- docs/en/advanced/rfc-vllm-rollout-backend.md | 360 +++++++++++++++++++ 1 file changed, 360 insertions(+) create mode 100644 docs/en/advanced/rfc-vllm-rollout-backend.md diff --git a/docs/en/advanced/rfc-vllm-rollout-backend.md b/docs/en/advanced/rfc-vllm-rollout-backend.md new file mode 100644 index 0000000000..5c3cbd20f5 --- /dev/null +++ b/docs/en/advanced/rfc-vllm-rollout-backend.md @@ -0,0 +1,360 @@ +# RFC: Add vLLM as a Rollout Backend in Slime + +- **Author**: \ +- **Status**: Draft +- **Audience**: Slime rollout/runtime maintainers, RL training maintainers +- **Last Updated**: 2026-03-02 + +## 1. Summary + +This RFC proposes adding **vLLM** as a first-class rollout backend in Slime while preserving current SGLang behavior and avoiding regressions in GRPO workflows. + +The design is based on: + +1. A backend-agnostic rollout request/response contract. +2. Backend adapters (`SGLangClient`, `VLLMClient`) that isolate protocol differences. +3. Capability-aware behavior for non-parity features (abort, routed experts, prompt logprobs). +4. Managed vLLM mode from day one (Slime manages vLLM process lifecycle inside Ray, same as SGLang path). +5. Weight sync via NCCL broadcast (GPU direct transfer, no disk I/O). + +## 2. Why this is needed + +Slime currently assumes SGLang behavior in multiple places: + +- rollout generation response schema (`meta_info.finish_reason.type`, `output_token_logprobs`, optional `routed_experts`) +- router control-plane interactions (`/workers`, `/list_workers`, `/abort_request`) +- rollout server startup flow in Ray + +Supporting vLLM is therefore not a URL replacement. It requires a compatibility layer across both data plane and control plane semantics. + +## 3. 
Goals and Non-goals + +### Goals + +- Add `--rollout-backend {sglang,vllm}`. +- Keep SGLang path unchanged by default. +- Keep trainer/algorithm interfaces stable. +- Support GRPO rollout with explicit compatibility semantics and observability. + +### Non-goals (initial phase) + +- Full parity for all SGLang-specific features. +- Colocate mode (training and rollout sharing same GPUs). +- R3 routed expert replay on vLLM. +- Multi-instance vLLM with router load balancing. + +## 4. Architecture and Interface Changes (Explicit) + +This section is the key communication point for Slime maintainers. + +### 4.1 Architecture changes + +#### End-to-end architecture (string diagram) + +```text + +--------------------+ + | TrainerLoop | + +---------+----------+ + | + v + +--------------------+ + | RolloutFunction | + +---------+----------+ + | + v + +-------------------------------+ + | CanonicalRolloutRequest | + | (input_ids, sampling_params, | + | return_logprob, prompt_text) | + +---------------+---------------+ + | + v + +--------------------------+ + | RolloutBackendClient | + +------------+-------------+ + | + +-----------------+-----------------+ + | | + v v + +------------------------+ +------------------------+ + | SGLangClient | | VLLMClient | + +-----------+------------+ +------------+-----------+ + | | + v v + +-------------------------------+ +------------------------------+ + | SGLangRouter or SlimeRouter | | SlimeRouter(generic) or | + | (SGLang control-plane aware) | | direct vLLM endpoint | + +---------------+---------------+ +--------------+---------------+ + | | + v v + +----------------------+ +----------------------+ + | SGLang workers | | vLLM workers | + +----------+-----------+ +----------+-----------+ + \ / + \ / + v v + +-----------------------------------+ + | CanonicalRolloutResponse | + | (text, token_ids, token_logprobs, | + | finish_reason, backend_raw) | + +----------------+------------------+ + | + v + +-----------------------------------+ + | 
SampleUpdate + Training Pipeline | + | (backend-agnostic consumption) | + +-----------------------------------+ +``` + +#### Component responsibility map + +| Component | Responsibility before RFC | Responsibility after RFC | +|---|---|---| +| `RolloutFunction` | Contains generic rollout + SGLang protocol details | Contains generic rollout orchestration only | +| Backend protocol layer | Implicit in rollout logic | Explicit via `RolloutBackendClient` adapters | +| `SGLangClient` | N/A (scattered logic) | Owns SGLang request/response/control-plane specifics | +| `VLLMClient` | N/A | Owns vLLM request/response mapping and retries | +| Trainer/sample pipeline | Consumes SGLang-shaped fields indirectly | Consumes canonical fields, backend-agnostic | +| Router integration | Mixed generic + SGLang-specific assumptions | SGLang-specific control-plane isolated to SGLang adapter | + +#### Control-plane behavior split + +| Control-plane behavior | SGLang path | vLLM initial path | +|---|---|---| +| Worker registration/discovery APIs | Supported | Not required in generic path | +| Worker-level abort | Supported (`/abort_request`) | Fallback to timeout/cancel semantics | +| Routed experts replay metadata | Supported | Explicitly unsupported (capability-gated) | +| Health/load-balance routing | Existing SGLang/SlimeRouter behavior | SlimeRouter generic mode or direct endpoint | + +#### A) Rollout backend abstraction layer (new) + +- Introduce a backend client interface: + - `RolloutBackendClient` + - backend capability descriptor +- Add concrete adapters: + - `SGLangClient` (existing behavior extraction) + - `VLLMClient` (new) + +**Impact**: rollout logic calls a unified backend interface instead of directly embedding SGLang HTTP semantics. + +#### B) Rollout execution path refactor + +- `sglang_rollout.generate` no longer directly depends on SGLang HTTP payload/response shape. +- It builds a canonical request and consumes a canonical response. 
+ +**Impact**: trainer-side logic remains stable while backend-specific details move into adapters. + +#### C) Startup path split in RolloutManager + +- Existing path (SGLang): keep current managed startup behavior. +- vLLM path: managed mode -- Slime creates `VLLMEngine` Ray actors that launch and manage local vLLM server processes, just like SGLang path uses `SGLangEngine`. + +**Impact**: vLLM gets the same lifecycle management as SGLang (process startup, health check, shutdown). + +#### D) Weight sync: VLLMEngine with same interface as SGLangEngine + +Training-side weight updater (`UpdateWeightFromDistributed`) calls engine methods via Ray remote. The core call chain with source locations: + +```text +Training side (Megatron actor, UNCHANGED) + UpdateWeightFromDistributed.update_weights() + -> engine.pause_generation.remote() # Ray remote call + -> VLLMEngine.pause_generation() # Ray actor method + -> requests.post("http://localhost:8000/sleep?level=2") + -> requests.post("http://localhost:8000/wake_up?tags=weights") + + -> engine.flush_cache.remote() + -> VLLMEngine.flush_cache() # no-op, sleep level 2 covers this + + -> engine.init_weights_update_group.remote() + -> VLLMEngine.init_weights_update_group() + -> requests.post("http://localhost:8000/collective_rpc", + json={"method": "init_weight_update_group", + "master_address": ..., "master_port": ..., + "rank_offset": ..., "world_size": ...}) + + -> dist.broadcast(param, src=0, group=nccl_group) # training process NCCL broadcast + -> engine.update_weights_from_distributed.remote() # tell vLLM "receive these params" + -> VLLMEngine.update_weights_from_distributed() + -> requests.post("http://localhost:8000/collective_rpc", + json={"method": "update_weight", + "name": ..., "dtype_name": ..., "shape": ...}) + -> vLLM worker: NCCL broadcast recv + model.load_weights() + + -> engine.continue_generation.remote() + -> VLLMEngine.continue_generation() + -> requests.post("http://localhost:8000/wake_up?tags=kv_cache") + 
+ +Rollout generation side + sglang_rollout.generate() + -> VLLMClient.generate() + -> httpx.post("http://localhost:8000/v1/completions") # direct to local vLLM +``` + +`VLLMEngine` exposes **the same method signatures** as `SGLangEngine`. Endpoint mapping: + +| Method (same signature) | SGLangEngine internal | VLLMEngine internal | +|---|---|---| +| `pause_generation()` | `POST /pause_generation` | `POST /sleep?level=2` + `POST /wake_up?tags=weights` | +| `flush_cache()` | `GET /flush_cache` | no-op (sleep level 2 covers this) | +| `init_weights_update_group(...)` | `POST /init_weights_update_group` | `POST /collective_rpc {"method":"init_weight_update_group",...}` | +| `update_weights_from_distributed(...)` | `POST /update_weights_from_distributed` | `POST /collective_rpc {"method":"update_weight",...}` | +| `continue_generation()` | `POST /continue_generation` | `POST /wake_up?tags=kv_cache` | + +**Impact**: training-side code (`UpdateWeightFromDistributed`, `train.py`, actor code) requires **zero changes**. Weight sync uses NCCL broadcast (GPU direct transfer), same efficiency as SGLang path. 
+ +Source links: +- [train.py#L88](../../../train.py#L88), [actor_group.py#L126](../../../slime/ray/actor_group.py#L126), [actor.py#L532](../../../slime/backends/megatron_utils/actor.py#L532) +- [update_weight_from_distributed.py#L82](../../../slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py#L82), [#L89](../../../slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py#L89), [#L90](../../../slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py#L90), [#L280](../../../slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py#L280), [#L321](../../../slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py#L321), [#L139](../../../slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py#L139) +- [sglang_engine.py#L415](../../../slime/backends/sglang_utils/sglang_engine.py#L415), [#L296](../../../slime/backends/sglang_utils/sglang_engine.py#L296), [#L373](../../../slime/backends/sglang_utils/sglang_engine.py#L373), [#L398](../../../slime/backends/sglang_utils/sglang_engine.py#L398), [#L420](../../../slime/backends/sglang_utils/sglang_engine.py#L420) +- [sglang_rollout.py#L47](../../../slime/rollout/sglang_rollout.py#L47), [#L72](../../../slime/rollout/sglang_rollout.py#L72), [#L165](../../../slime/rollout/sglang_rollout.py#L165), [#L311](../../../slime/rollout/sglang_rollout.py#L311) +- [rollout.py#L477](../../../slime/ray/rollout.py#L477), [#L1028](../../../slime/ray/rollout.py#L1028), [#L1041](../../../slime/ray/rollout.py#L1041) + +#### E) Router decision: no router for vLLM initial phase + +- SGLang Model Gateway: only supports SGLang workers, not applicable. +- SlimeRouter: only needed for R3 / radix-tree caching; Qwen2.5-0.5B is not MoE and uses token-in/token-out. +- Single vLLM instance, `VLLMClient` connects directly to local vLLM server port. 
+ +### 4.2 Interface changes + +#### A) New CLI interfaces + +- `--rollout-backend {sglang,vllm}` +- `--vllm-base-url` +- `--vllm-api-mode` (e.g. OpenAI-compatible completion mode) +- `--vllm-model` +- `--vllm-max-retries` + +#### B) New internal canonical interfaces + +Add canonical rollout contract types: + +- `RolloutBackendRequest` +- `RolloutBackendResponse` + +These carry backend-neutral fields such as: + +- input token ids +- sampling params +- output token ids/logprobs +- canonical finish reason (`stop|length|abort`) +- backend raw response for debugging + +#### C) Capability-gated interface behavior + +Backends declare capabilities (abort support, routed experts support, prompt logprobs support). Unsupported features are explicitly gated, logged, and/or failed fast. + +## 5. File-level Change Map (for maintainers) + +### New files + +- `slime/rollout/backends/base_client.py` -- backend client interface + capability model +- `slime/rollout/backends/sglang_client.py` -- extracted SGLang rollout client +- `slime/rollout/backends/vllm_client.py` -- vLLM rollout client +- `slime/rollout/backends/__init__.py` -- adapter exports +- `slime/backends/vllm_utils/vllm_engine.py` -- `VLLMEngine` Ray actor (analogous to `SGLangEngine`) +- `slime/backends/vllm_utils/worker_extension.py` -- vLLM WorkerExtension for NCCL weight sync + +### Modified files + +- `slime/rollout/base_types.py` -- add canonical backend request/response types +- `slime/rollout/sglang_rollout.py` -- use backend adapters instead of hardcoded SGLang protocol +- `slime/utils/arguments.py` -- add `--rollout-backend`, vLLM args; skip SGLang parse when vLLM +- `slime/ray/rollout.py` -- startup flow split: create `VLLMEngine` when `rollout_backend=vllm` + +### Unchanged files (by design) + +- `train.py` -- training loop does not know about backend +- `slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py` -- engine method signatures are identical +- 
`slime/backends/megatron_utils/actor.py` -- calls engine methods polymorphically + +## 6. Compatibility Matrix (initial) + +| Capability | SGLang | vLLM (initial) | Handling | +|---|---|---|---| +| Token-level response logprobs | Yes | Partial/endpoint-dependent | Adapter normalization | +| Finish reason mapping (`stop|length|abort`) | Native | Different enums likely | Canonical mapping | +| Worker-level abort | Yes | Typically no direct equivalent | timeout/cancel fallback strategy | +| Prompt logprobs (OPD-related) | Yes | Partial/unknown by endpoint | capability-gated | +| Routed experts replay (R3) | Yes | No | explicit unsupported gate | +| SGLang worker API dependency | Yes | No | isolate in SGLang adapter | + +## 7. Phased Plan + +### Phase 1: Managed vLLM GRPO training (small-scale validation) + +**Goal**: Qwen2.5-0.5B GRPO 8-GPU sync training on GSM8K, loss/reward convergence comparable to SGLang. + +Key deliverables: +- `VLLMEngine` Ray actor (process lifecycle, sleep/wake_up, weight sync via NCCL) +- `VLLMClient` rollout adapter (request/response mapping, logprob/finish_reason normalization) +- `RolloutBackendRequest/Response` canonical contract +- `SGLangClient` extraction (isolate SGLang protocol details) +- `start_rollout_servers` branching for vLLM path +- Argument parsing split (`--rollout-backend vllm`) + +Technical decisions: +- No router (single vLLM instance, direct connection) +- Weight sync: NCCL broadcast via `VLLMEngine` with same method signatures as `SGLangEngine` +- Non-colocate mode (separate training and rollout GPUs) +- Rollout function: reuse `sglang_rollout.py` (backend-agnostic orchestration) +- Training-side code: zero changes + +Acceptance criteria: +- `num_rollout=3` smoke passes without errors +- Weight sync correctness: rollout output changes with each training update +- reward mean delta < 5% vs SGLang path (same seed/config, 20 steps) +- Existing SGLang tests remain green + +### Phase 2: Performance, scale, and advanced 
features + +- Colocate mode support (GPU IPC weight transfer) +- Multi-instance vLLM with load balancing +- Abort/cancel strategy refinement +- Prompt logprob support (OPD scenarios) +- Deterministic computation verification +- Larger model validation (e.g. Qwen3-4B) + +## 8. Validation strategy + +### Unit + +- finish reason normalization +- token/logprob alignment +- capability-gated behavior + +### Integration + +- SGLang vs vLLM response schema comparisons +- timeout/retry/non-crash behavior checks + +### End-to-end GRPO + +- smoke (short rollout) +- stability (medium horizon) +- stress (higher concurrency / latency pressure) + +## 9. Risks and mitigations + +- **Logprob semantic mismatch**: strict adapter checks + canonicalization + tests. +- **Abort mismatch**: capability model + timeout/cancel fallback + explicit logs. +- **Training drift**: mandatory A/B runs with fixed seeds and aligned configs. +- **Middleware assumptions**: enforce capability checks and default-disable incompatible middleware paths. + +## 10. Open questions for Slime maintainers + +1. Should vLLM initial path be direct endpoint only, or standardize via SlimeRouter generic mode immediately? +2. What minimum logprob fidelity is required to claim GRPO support? +3. Should OPD prompt-logprob paths be blocked for vLLM until full parity? +4. What backend quality gates are required before default recommendation? + +## 11. Decision requested + +Approve incremental implementation with: + +- contract-first adapter architecture, +- external vLLM initial support, +- explicit capability-gated behavior, +- strict SGLang non-regression requirement. 
From 3af806ee731e77d3b9030cb57b6fc758280c8aca Mon Sep 17 00:00:00 2001 From: SamitHuang <285365963@qq.com> Date: Tue, 3 Mar 2026 17:03:34 +0800 Subject: [PATCH 02/10] add plan Signed-off-by: SamitHuang <285365963@qq.com> --- goal_plan.md | 27 ++ rfc-vllm-rollout-backend-en.md | 401 ++++++++++++++++++++++++++++++ rfc-vllm-rollout-backend.md | 436 +++++++++++++++++++++++++++++++++ 3 files changed, 864 insertions(+) create mode 100644 goal_plan.md create mode 100644 rfc-vllm-rollout-backend-en.md create mode 100644 rfc-vllm-rollout-backend.md diff --git a/goal_plan.md b/goal_plan.md new file mode 100644 index 0000000000..adb39bf047 --- /dev/null +++ b/goal_plan.md @@ -0,0 +1,27 @@ +### 阶段一:打通Qwen2.5-0.5B GRPO 8卡同步/异步训练(train.py和train_async.py),GSM8K 数据集,loss/reward 收敛与 SGLang backend 基本一致,且满足确定性计算,多次重复运行Loss曲线完全一致。 + +#### 初步方案: +- 对标SGLang,Slime 在 Ray 内管理 vLLM 的完整生命周期,包括进程拉起、权重同步、推理暂停/恢复 +- 暂不使用Router,SGLang Model Gateway仅只支持SGLang Worker,SlimeRouter仅在 R3 / radix-tree caching 时需要,Qwen2.5-0.5B 非 MoE 且用 token-in/token-out +- 单vLLM实例,无router,通过vLLMClient 直连本地 vLLM 进程端口 +- 若训推不共卡,权重同步采用NCCL broadcast,对标SGLang update_weights_from_distributed (默认,优先用于支持异步训推) +- 若colocate,权重同步采用GPU IPC(vLLM update_weights_from_ipc, update_weights_from_tensor),对标SGLang update_weights_from_tensor + +#### 风险: +- slime, sglang版本依赖,和vllm 0.16的版本依赖冲突(numpy, torch, transformers, etc) +- 算力 + +First Design and RFC by 03/06 + + +### 阶段二:接入vllm-project/router,支持多实例vLLM + +- vllm router forked from SGLang Model Gateway + +### 阶段三:多节点大规模验证,MoE模型,optional:验证MTP Speculative Decoding,FP8 rollout 等高级特性 + +- Model: Qwen/Qwen3-30B-A3B or GLM4.7 +- Parallel: 16卡 or 128卡, Train mixed EP+FSDP, Rollout EP+DP +- Verify more features: + - Bf16 train, FP8 rollout + - MTP Speculative Decoding diff --git a/rfc-vllm-rollout-backend-en.md b/rfc-vllm-rollout-backend-en.md new file mode 100644 index 0000000000..dad61a0731 --- /dev/null +++ b/rfc-vllm-rollout-backend-en.md @@ -0,0 +1,401 @@ +# RFC: Supporting vLLM as a 
Rollout Backend in Slime + +- **Author**: \ +- **Status**: Phase 1 Done +- **Audience**: Slime rollout/runtime maintainers, RL training maintainers +- **Last Updated**: 2026-03-03 + +## 1. Summary + +This RFC proposes adding **vLLM** as a first-class rollout backend in Slime while keeping the existing SGLang behavior unchanged and the GRPO training pipeline unaffected. + +Core design principles: + +1. Define backend-agnostic rollout request/response contracts (`RolloutBackendRequest`/`RolloutBackendResponse`). +2. Isolate protocol differences through backend adapters (`SGLangClient`, `VLLMClient`). +3. Apply explicit capability gating for non-equivalent features (abort, routed experts, prompt logprobs). +4. Managed mode -- Slime manages the vLLM process lifecycle within Ray, on par with the SGLang path. +5. Weight synchronization leverages vLLM's native weight transfer API, automatically selecting the backend based on deployment mode: + - **Colocate mode**: CUDA IPC (`IPCWeightTransferEngine`) -- zero-copy via shared GPU memory. + - **Non-colocate mode**: NCCL broadcast (`NCCLWeightTransferEngine`) -- direct GPU transfer. + +**Current status**: Phase 1 is complete and verified. GRPO training runs successfully with Qwen2.5-0.5B + GSM8K on 4 GPUs in colocate mode. + +## 2. Motivation + +Slime currently assumes SGLang behavior in multiple places: + +- Rollout generation response format (`meta_info.finish_reason.type`, `output_token_logprobs`, optional `routed_experts`) +- Router control-plane interactions (`/workers`, `/list_workers`, `/abort_request`) +- Ray-based rollout server startup flow + +Supporting vLLM is therefore not just "swapping a URL" -- it requires compatibility layers on both the data plane and control plane. + +## 3. 
Goals and Non-Goals + +### Goals + +- Add `--rollout-backend {sglang,vllm}` +- Keep SGLang as the default, unchanged +- Keep the training-side interface stable +- Support GRPO rollout with explicit compatibility semantics and observability +- **Support colocate mode** (training and inference share GPUs, weight sync via CUDA IPC) + +### Non-Goals (Current Phase) + +- Full feature parity with all SGLang capabilities +- R3 routing replay on vLLM +- Multi-instance vLLM + router load balancing + +## 4. Architecture and Interface Changes + +### 4.1 Architecture Changes + +#### End-to-End Architecture Diagram + +```text + +--------------------+ + | TrainerLoop | + +---------+----------+ + | + +---------------------+---------------------+ + | (generate) (update_weights) + v v + +--------------------+ +------------------------------------+ + | RolloutFunction | | weight updater (auto-selected) | + | (sglang_rollout) | | colocate → UpdateWeightFromTensor | + +---------+----------+ | otherwise → ...FromDistributed | + | +---------------+--------------------+ + v | + +-------------------------------+ | + | RolloutBackendRequest | | + +---------------+---------------+ | + | v + v +----------------------------+ + +--------------------------+ | VLLMEngine (Ray actor) | + | RolloutBackendClient | | weight transfer backend: | + +------------+-------------+ | colocate → IPC | + | | otherwise → NCCL | + +-------------+-------------+ +-------------+--------------+ + | | | + v v v ++------------+ +-----------+ +---------------+ +| SGLang | | VLLM | | vLLM server | +| Client | | Client | | /update_weights| ++-----+------+ +-----+-----+ | /pause /resume | + | | | /sleep /wake_up| + v v +---------------+ + +-----------+ +-------------+ + | SGLang | | vLLM server | + | Router | | /v1/compl. 
| + +-----------+ +-------------+ + \ / + \ / + v v + +-----------------------------------+ + | RolloutBackendResponse | + | (text, token_ids, token_logprobs, | + | finish_reason, backend_raw) | + +----------------+------------------+ + | + v + +-----------------------------------+ + | SampleUpdate + Training Pipeline | + | (backend-agnostic consumption) | + +-----------------------------------+ +``` + +#### Component Responsibility Comparison + +| Component | Before RFC | After RFC | +|---|---|---| +| `RolloutFunction` | Generic rollout + SGLang protocol details | Generic rollout orchestration only | +| Backend protocol layer | Implicit in rollout logic | Explicit `RolloutBackendClient` adapter | +| `SGLangClient` | Did not exist (logic scattered) | Owns all SGLang request/response/control-plane details | +| `VLLMClient` | Did not exist | Owns vLLM `/v1/completions` request/response mapping | +| Training/sample pipeline | Indirectly consumed SGLang-format fields | Consumes unified contract fields, backend-agnostic | +| Weight sync | SGLang IPC or NCCL only | Auto-adapts: SGLang IPC / vLLM IPC / vLLM NCCL | + +#### Control-Plane Behavior Split + +| Control-plane behavior | SGLang path | vLLM path | +|---|---|---| +| Worker registration/discovery API | Supported | Not needed | +| Worker-level abort | Supported (`/abort_request`) | Degraded to timeout/cancel strategy | +| Routed experts replay | Supported | Explicitly unsupported (capability-gated) | +| Health check | Existing SGLang/SlimeRouter | `GET /health` direct connection | +| Memory management (sleep/wake) | `release/resume_memory_occupation` | `POST /sleep` / `POST /wake_up` | + +#### A) Rollout Backend Abstraction Layer + +- `RolloutBackendClient` (`slime/rollout/backends/base_client.py`): defines `generate()` + `capabilities` abstract interface +- `BackendCapabilities` dataclass: declares `supports_abort`, `supports_routed_experts`, `supports_prompt_logprobs` +- `SGLangClient` (extracted from existing 
code), `VLLMClient` (new) + +**Impact**: Rollout logic calls a unified interface and no longer directly embeds SGLang HTTP semantics. + +#### B) Rollout Execution Path + +- `sglang_rollout.generate` constructs `RolloutBackendRequest` and consumes `RolloutBackendResponse` +- Selects `SGLangClient` or `VLLMClient` based on `--rollout-backend` +- `VLLMClient` calls vLLM's OpenAI-compatible `/v1/completions` endpoint + +**Impact**: Training-side logic remains stable; backend details are encapsulated in adapters. + +#### C) RolloutManager Startup Path + +- Existing path (SGLang): unchanged +- vLLM path: `_start_vllm_rollout_servers()` creates a `VLLMEngine` Ray actor that starts and manages a local vLLM server process + +**Impact**: vLLM gets the same lifecycle management as SGLang. + +#### D) Weight Synchronization + +Weight synchronization supports two modes, automatically selected based on deployment: + +**Selection logic** (`actor.py`): +```python +update_weight_cls = UpdateWeightFromTensor if args.colocate else UpdateWeightFromDistributed +``` +Identical to the SGLang selection logic -- backend-agnostic. + +##### D.1) Colocate Mode: CUDA IPC + +The vLLM server starts with `--weight-transfer-config '{"backend": "ipc"}'`. + +Call chain: +```text +UpdateWeightFromTensor.update_weights() + → engine.pause_generation.remote() # VLLMEngine → POST /pause?mode=abort + → _send_to_colocated_vllm_engine() + → each training rank: + reduce_tensor(tensor) # create CUDA IPC handle + {gpu_uuid: ipc_handle} # keyed by physical GPU UUID + → dist.gather_object (Gloo) # collect handles from all TP ranks + → pickle + base64 encode + → engine.update_weights_from_tensor.remote() + → VLLMEngine → POST /update_weights + { "update_info": { + "names": [...], + "dtype_names": [...], + "shapes": [...], + "ipc_handles_pickled": "base64..." 
+ }} + → vLLM IPCWeightTransferEngine.receive_weights() + → each TP worker looks up its IPC handle by GPU UUID + → func(*args) reconstructs tensor → load_weights() + → engine.continue_generation.remote() # VLLMEngine → POST /resume +``` + +Key points: +- The training process and vLLM workers share the same physical GPU; CUDA IPC enables zero-copy weight transfer +- For TP>1, IPC handles from all training ranks are merged via Gloo gather; each parameter contains a mapping of all GPU UUIDs +- The training side must keep tensor references alive until `ray.get()` returns (vLLM has finished reading) + +##### D.2) Non-Colocate Mode: NCCL Broadcast + +The vLLM server starts with `--weight-transfer-config '{"backend": "nccl"}'`. + +Call chain: +```text +UpdateWeightFromDistributed.update_weights() + → engine.init_weights_update_group.remote() # VLLMEngine → POST /init_weight_transfer_engine + → vLLM NCCLWeightTransferEngine initializes StatelessProcessGroup + PyNcclCommunicator + → training side: NCCLWeightTransferEngine.trainer_init() + → establishes a matching NCCL communicator with vLLM + → PyNcclCommunicator.broadcast(tensor, src=0) # direct GPU transfer to vLLM workers + → engine.update_weights_from_distributed.remote() + → VLLMEngine → POST /update_weights + { "update_info": { "names": [...], "dtype_names": [...], "shapes": [...] }} + → vLLM NCCLWeightTransferEngine.receive_weights() +``` + +Key points: +- The training side uses vLLM's `NCCLWeightTransferEngine.trainer_init()` to initialize NCCL, compatible with vLLM's internal `StatelessProcessGroup` + `PyNcclCommunicator` +- Cannot use `torch.distributed.init_process_group` (incompatible with vLLM) + +#### VLLMEngine Endpoint Mapping + +`VLLMEngine` exposes method signatures compatible with `SGLangEngine`. 
Internal HTTP endpoint mapping: + +| Method | VLLMEngine internal HTTP call | +|---|---| +| `pause_generation()` | `POST /pause?mode=abort` | +| `flush_cache()` | no-op | +| `continue_generation()` | `POST /resume` | +| `init_weights_update_group(...)` | `POST /init_weight_transfer_engine` | +| `update_weights_from_distributed(...)` | `POST /update_weights` (NCCL mode) | +| `update_weights_from_tensor(...)` | `POST /update_weights` (IPC mode) | +| `release_memory_occupation()` | `POST /sleep?level=1&mode=abort` | +| `resume_memory_occupation()` | `POST /wake_up` | +| `health_generate()` | `GET /health` | +| `shutdown()` | `process.terminate()` | + +**Impact**: Training-side code (`UpdateWeightFromTensor`, `UpdateWeightFromDistributed`, actor code) uses a unified interface, unaware of the specific backend. + +#### E) Router Decision + +vLLM does not use a router in the current phase: +- Single vLLM instance; `VLLMClient` connects directly to the local vLLM server port +- SGLang Model Gateway only supports SGLang workers, not applicable +- SlimeRouter is only needed for R3 / radix-tree caching scenarios + +### 4.2 Interface Changes + +#### A) New CLI Arguments + +- `--rollout-backend {sglang,vllm}` -- select rollout backend +- `--vllm-base-url` -- manually specify vLLM server address (not needed when auto-managed) +- `--vllm-model` -- model path for vLLM to load (defaults to `--hf-checkpoint`) +- `--vllm-max-retries` -- max retries for generation requests +- `--vllm-enforce-eager` -- disable CUDA graph (default True) + +#### B) New Internal Unified Interface + +Added to `slime/rollout/base_types.py`: + +- `RolloutBackendRequest`: input_ids, sampling_params, return_logprob, return_routed_experts, image_data, session_id +- `RolloutBackendResponse`: text, output_token_ids, output_token_logprobs, finish_reason (`stop|length|abort`), prompt_tokens, completion_tokens, backend_raw, routed_experts + +#### C) Capability Gating + +`BackendCapabilities` dataclass declares 
capabilities per backend: + +| Capability | SGLangClient | VLLMClient | +|---|---|---| +| `supports_abort` | True | False | +| `supports_routed_experts` | True | False | +| `supports_prompt_logprobs` | True | False | + +Unsupported features are explicitly gated, logged, or fail fast at call time. + +## 5. File-Level Change List + +### New Files + +| File | Description | +|---|---| +| `slime/rollout/backends/base_client.py` | `RolloutBackendClient` abstract interface + `BackendCapabilities` | +| `slime/rollout/backends/sglang_client.py` | SGLang rollout client (extracted from existing code) | +| `slime/rollout/backends/vllm_client.py` | vLLM rollout client (`/v1/completions` adapter) | +| `slime/rollout/backends/__init__.py` | Adapter exports | +| `slime/backends/vllm_utils/vllm_engine.py` | `VLLMEngine` Ray actor (process management + weight sync) | +| `run-qwen2.5-0.5B-vllm.sh` | vLLM validation script (Qwen2.5-0.5B + GSM8K + colocate) | + +### Modified Files + +| File | Description | +|---|---| +| `slime/rollout/base_types.py` | Added `RolloutBackendRequest` / `RolloutBackendResponse` | +| `slime/rollout/sglang_rollout.py` | Uses backend adapter instead of hard-coded SGLang protocol | +| `slime/utils/arguments.py` | Added `--rollout-backend`, vLLM arguments; sets sglang alias defaults for vLLM | +| `slime/ray/rollout.py` | `_start_vllm_rollout_servers()` to start VLLMEngine | +| `slime/backends/megatron_utils/actor.py` | Weight updater selection: `colocate → UpdateWeightFromTensor` (backend-agnostic) | +| `slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py` | Added `_send_to_colocated_vllm_engine()` (CUDA IPC path) | +| `slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py` | vLLM NCCL compatibility: `NCCLWeightTransferEngine.trainer_init()` + `PyNcclCommunicator.broadcast()` | + +### Unchanged Files + +| File | Description | +|---|---| +| `train.py` | Training loop is backend-agnostic | +| SGLang-related files | 
SGLang path is completely unaffected | + +### Not Used (Differences from Initial RFC Draft) + +- **No `worker_extension.py`**: vLLM's native weight transfer API fully meets requirements; no custom WorkerExtension needed +- **No `collective_rpc`**: Uses vLLM's `/init_weight_transfer_engine` and `/update_weights` endpoints + +## 6. Capability Compatibility Matrix + +| Capability | SGLang | vLLM | Handling | +|---|---|---|---| +| Token-level response logprobs | Supported | Supported (`choice.logprobs.token_logprobs`) | Adapter normalization | +| Output token IDs | `meta_info` | `choice.token_ids` | Adapter normalization | +| Finish reason | `stop\|length\|abort` | `stop\|length` + others | Canonical mapping | +| Worker-level abort | Supported | Not supported | Timeout/cancel degradation | +| Prompt logprobs (OPD) | Supported | Partial | Capability-gated | +| Routed experts (R3) | Supported | Not supported | Explicit gate | +| Colocate (GPU sharing) | Supported (FlattenedTensorBucket IPC) | Supported (CUDA IPC handles) | Native IPC per backend | +| Memory management | `release/resume_memory_occupation` | `POST /sleep` / `POST /wake_up` | Unified interface | +| `include_stop_str_in_output` | `no_stop_trim` | `include_stop_str_in_output` | Parameter mapping | + +## 7. Phased Plan + +### Phase 1: Managed vLLM GRPO Training ✅ Done + +**Goal**: Qwen2.5-0.5B GRPO training on 4 GPUs in colocate mode with GSM8K dataset, running successfully. 
+ +Completed deliverables: +- `VLLMEngine` Ray actor: process lifecycle management, `/pause`/`/resume`/`/sleep`/`/wake_up` +- `VLLMClient` rollout adapter: `/v1/completions` request mapping, logprob/finish_reason/token_ids normalization +- `RolloutBackendRequest`/`Response` unified contract +- `SGLangClient` extraction +- `_start_vllm_rollout_servers` startup path +- Argument parsing branching (`--rollout-backend vllm`, sglang alias defaults) +- **Colocate mode**: CUDA IPC weight transfer (`_send_to_colocated_vllm_engine` → `IPCWeightTransferEngine`) +- **Non-colocate mode**: NCCL weight transfer (`NCCLWeightTransferEngine.trainer_init()` → `PyNcclCommunicator.broadcast()`) +- Validation script: `run-qwen2.5-0.5B-vllm.sh` + +Technical decisions: +- No router (single vLLM instance, direct connection) +- Colocate weight sync: CUDA IPC (training and vLLM share GPU memory, zero-copy) +- Non-colocate weight sync: vLLM native NCCL (`StatelessProcessGroup` + `PyNcclCommunicator`) +- Rollout function: reuses `sglang_rollout.py` (backend-agnostic orchestration) +- Training-side weight updater selection is identical to SGLang (`colocate → UpdateWeightFromTensor`) + +Acceptance results: +- ✅ Qwen2.5-0.5B + GSM8K + 4 GPU colocate training runs successfully +- ✅ Weight sync correctness: rollout outputs change with each training update +- ✅ Existing SGLang path unaffected + +### Phase 2: Multi-Instance vLLM + Async Training + +**Goal**: Support multi-instance vLLM rollout with [vllm-project/router](https://github.com/vllm-project/router) for load balancing, and verify async training (`train_async.py`) on larger models. 
+ +Key deliverables: +- Integrate [vllm-project/router](https://github.com/vllm-project/router) as the vLLM-side load balancer +- `start_rollout_servers` launches N `VLLMEngine` actors + one vllm-router process +- `VLLMClient` generation requests routed through vllm-router instead of direct connection +- Verify async training (`train_async.py`) correctness with the vLLM backend + +Acceptance criteria: +- Multi-instance (e.g., 2-4 vLLM workers) rollout stable over 20+ rollout steps +- Async training (`train_async.py`) with no hangs or weight sync race conditions +- Throughput improvement compared to single-instance baseline +- Larger model verification (e.g., Qwen3-4B, TP=2) + +### Phase 3 (Future): Advanced Features + +- Abort/cancel strategy refinement +- Prompt logprob support (OPD scenarios) +- Deterministic computation verification +- Performance benchmarks (vLLM vs SGLang throughput comparison) + +## 8. Validation Strategy + +### Completed (Phase 1) + +- ✅ End-to-end GRPO training: `run-qwen2.5-0.5B-vllm.sh` (Qwen2.5-0.5B, GSM8K, 4 GPUs, colocate) +- ✅ Weight sync correctness verification +- ✅ SGLang path non-regression + +### Remaining + +- Unit tests: finish reason normalization, token/logprob alignment, capability-gated behavior +- Integration tests: SGLang vs vLLM response schema comparison +- Stress tests: higher concurrency / latency pressure + +## 9. Key Implementation Challenges and Solutions + +### 9.1 vLLM NCCL Incompatibility + +**Problem**: vLLM internally uses `StatelessProcessGroup` + `PyNcclCommunicator` for NCCL communication, which is incompatible with `torch.distributed.init_process_group`. + +**Solution**: The training side uses `NCCLWeightTransferEngine.trainer_init()` to initialize NCCL, ensuring a matching communicator is established with the vLLM side. + +### 9.2 vLLM Logprobs Format Differences + +**Problem**: vLLM's `/v1/completions` logprobs format (`choice.token_ids`, `choice.logprobs.token_logprobs`) differs from SGLang's. 
+ +**Solution**: `VLLMClient` performs explicit mapping, converting vLLM's format into the unified `RolloutBackendResponse` format. diff --git a/rfc-vllm-rollout-backend.md b/rfc-vllm-rollout-backend.md new file mode 100644 index 0000000000..fc8e569793 --- /dev/null +++ b/rfc-vllm-rollout-backend.md @@ -0,0 +1,436 @@ +# RFC: 在 Slime 中支持 vLLM 作为 Rollout Backend + +## 1. 概要 + +本 RFC 提议在 Slime 中增加 **vLLM** 作为一等 rollout backend,同时保持现有 SGLang 行为不变,不影响 GRPO 训练流程。 + +核心设计思路: + +1. 定义 backend 无关的 rollout 请求/响应契约(`RolloutBackendRequest`/`RolloutBackendResponse`)。 +2. 通过 backend adapter(`SGLangClient`、`VLLMClient`)隔离协议差异。 +3. 对不等价能力(abort、routed experts、prompt logprobs)做显式 capability gating。 +4. Managed 模式 —— Slime 在 Ray 内管理 vLLM 进程生命周期,与 SGLang 路径同等级别。 +5. 权重同步利用 vLLM 原生 weight transfer API,根据部署模式自动选择后端: + - **Colocate 模式**:CUDA IPC(`IPCWeightTransferEngine`),GPU 共享内存零拷贝。 + - **Non-colocate 模式**:NCCL broadcast(`NCCLWeightTransferEngine`),GPU 直传。 + +**当前状态**:Phase 1 已完成并验证通过。Qwen2.5-0.5B + GSM8K + 4 GPU colocate 模式下 GRPO 训练正常运行。 + +## 2. 为什么需要这个 + +Slime 当前在多处假设 SGLang 行为: + +- rollout 生成响应格式(`meta_info.finish_reason.type`、`output_token_logprobs`、可选 `routed_experts`) +- router 控制面交互(`/workers`、`/list_workers`、`/abort_request`) +- Ray 内 rollout server 启动流程 + +因此支持 vLLM 不是"换个 URL",而是需要在数据面和控制面做兼容层。 + +## 3. 目标与非目标 + +### 目标 + +- 新增 `--rollout-backend {sglang,vllm}` +- SGLang 路径默认不变 +- 训练侧接口保持稳定 +- 支持 GRPO rollout,具有显式兼容性语义和可观测性 +- **支持 colocate 模式**(训练与推理共享 GPU,通过 CUDA IPC 同步权重) + +### 非目标(当前阶段) + +- 所有 SGLang 特性的完全对等 +- vLLM 上的 R3 路由回放 +- 多 vLLM 实例 + router 负载均衡 + +## 4. 
架构和接口改动 + +### 4.1 架构改动 + +#### 端到端架构图 + +```text + +--------------------+ + | TrainerLoop | + +---------+----------+ + | + +---------------------+---------------------+ + | (generate) (update_weights) + v v + +--------------------+ +------------------------------------+ + | RolloutFunction | | weight updater (自动选择) | + | (sglang_rollout) | | colocate → UpdateWeightFromTensor | + +---------+----------+ | otherwise → ...FromDistributed | + | +---------------+--------------------+ + v | + +-------------------------------+ | + | RolloutBackendRequest | | + +---------------+---------------+ | + | v + v +----------------------------+ + +--------------------------+ | VLLMEngine (Ray actor) | + | RolloutBackendClient | | weight transfer backend: | + +------------+-------------+ | colocate → IPC | + | | otherwise → NCCL | + +-------------+-------------+ +-------------+--------------+ + | | | + v v v ++------------+ +-----------+ +---------------+ +| SGLang | | VLLM | | vLLM server | +| Client | | Client | | /update_weights| ++-----+------+ +-----+-----+ | /pause /resume | + | | | /sleep /wake_up| + v v +---------------+ + +-----------+ +-------------+ + | SGLang | | vLLM server | + | Router | | /v1/compl. 
| + +-----------+ +-------------+ + \ / + \ / + v v + +-----------------------------------+ + | RolloutBackendResponse | + | (text, token_ids, token_logprobs, | + | finish_reason, backend_raw) | + +----------------+------------------+ + | + v + +-----------------------------------+ + | SampleUpdate + Training Pipeline | + | (backend 无关消费) | + +-----------------------------------+ +``` + +#### 组件职责对照 + +| 组件 | RFC 前 | RFC 后 | +|---|---|---| +| `RolloutFunction` | 包含通用 rollout + SGLang 协议细节 | 只包含通用 rollout 编排 | +| Backend 协议层 | 隐含在 rollout 逻辑里 | 显式的 `RolloutBackendClient` adapter | +| `SGLangClient` | 不存在(逻辑分散) | 拥有 SGLang 请求/响应/控制面的全部细节 | +| `VLLMClient` | 不存在 | 拥有 vLLM `/v1/completions` 请求/响应映射 | +| 训练/sample 管线 | 间接消费 SGLang 格式字段 | 消费统一契约字段,backend 无关 | +| 权重同步 | 仅 SGLang IPC 或 NCCL | 自动适配:SGLang IPC / vLLM IPC / vLLM NCCL | + +#### 控制面行为拆分 + +| 控制面行为 | SGLang 路径 | vLLM 路径 | +|---|---|---| +| Worker 注册/发现 API | 支持 | 不需要 | +| Worker 级 abort | 支持(`/abort_request`) | 降级为超时/取消策略 | +| Routed experts 回放 | 支持 | 显式不支持(capability-gated) | +| 健康检查 | 现有 SGLang/SlimeRouter | `GET /health` 直连 | +| 内存管理 (sleep/wake) | `release/resume_memory_occupation` | `POST /sleep` / `POST /wake_up` | + +#### A) Rollout backend 抽象层 + +- `RolloutBackendClient`(`slime/rollout/backends/base_client.py`):定义 `generate()` + `capabilities` 抽象接口 +- `BackendCapabilities` dataclass:声明 `supports_abort`、`supports_routed_experts`、`supports_prompt_logprobs` +- `SGLangClient`(从现有代码提取)、`VLLMClient`(新建) + +**影响**:rollout 逻辑调用统一接口,不再直接嵌入 SGLang HTTP 语义。 + +#### B) Rollout 执行路径 + +- `sglang_rollout.generate` 构造 `RolloutBackendRequest`,消费 `RolloutBackendResponse` +- 根据 `--rollout-backend` 选择 `SGLangClient` 或 `VLLMClient` +- `VLLMClient` 调用 vLLM OpenAI-compatible `/v1/completions` 端点 + +**影响**:训练侧逻辑保持稳定,backend 细节移入 adapter。 + +#### C) RolloutManager 启动路径 + +- 现有路径(SGLang):保持不变 +- vLLM 路径:`_start_vllm_rollout_servers()` 创建 `VLLMEngine` Ray actor,启动管理本地 vLLM server 进程 + +**影响**:vLLM 获得与 SGLang 相同的生命周期管理。 + +#### 
D) 权重同步 + +权重同步支持两种模式,根据部署方式自动选择: + +**模式选择逻辑**(`actor.py`): +```python +update_weight_cls = UpdateWeightFromTensor if args.colocate else UpdateWeightFromDistributed +``` +与 SGLang 使用完全相同的选择逻辑,backend 无关。 + +##### D.1) Colocate 模式:CUDA IPC + +vLLM server 启动时配置 `--weight-transfer-config '{"backend": "ipc"}'`。 + +调用链: +```text +UpdateWeightFromTensor.update_weights() + → engine.pause_generation.remote() # VLLMEngine → POST /pause?mode=abort + → _send_to_colocated_vllm_engine() + → 每个训练 rank: + reduce_tensor(tensor) # 创建 CUDA IPC handle + {gpu_uuid: ipc_handle} # 以物理 GPU UUID 为 key + → dist.gather_object (Gloo) # 收集所有 TP rank 的 handles + → pickle + base64 编码 + → engine.update_weights_from_tensor.remote() + → VLLMEngine → POST /update_weights + { "update_info": { + "names": [...], + "dtype_names": [...], + "shapes": [...], + "ipc_handles_pickled": "base64..." + }} + → vLLM IPCWeightTransferEngine.receive_weights() + → 每个 TP worker 根据自己的 GPU UUID 查找 IPC handle + → func(*args) 重建 tensor → load_weights() + → engine.continue_generation.remote() # VLLMEngine → POST /resume +``` + +关键点: +- 训练进程和 vLLM worker 共享同一物理 GPU,CUDA IPC 实现零拷贝权重传输 +- TP>1 时,各 training rank 的 IPC handles 通过 Gloo gather 合并,每个参数包含所有 GPU UUID 的映射 +- 训练侧必须保持 tensor 引用直到 `ray.get()` 返回(vLLM 读取完毕) + +##### D.2) Non-colocate 模式:NCCL broadcast + +vLLM server 启动时配置 `--weight-transfer-config '{"backend": "nccl"}'`。 + +调用链: +```text +UpdateWeightFromDistributed.update_weights() + → engine.init_weights_update_group.remote() # VLLMEngine → POST /init_weight_transfer_engine + → vLLM NCCLWeightTransferEngine 初始化 StatelessProcessGroup + PyNcclCommunicator + → 训练侧: NCCLWeightTransferEngine.trainer_init() + → 与 vLLM 建立匹配的 NCCL 通信组 + → PyNcclCommunicator.broadcast(tensor, src=0) # GPU 直传到 vLLM worker + → engine.update_weights_from_distributed.remote() + → VLLMEngine → POST /update_weights + { "update_info": { "names": [...], "dtype_names": [...], "shapes": [...] 
}} + → vLLM NCCLWeightTransferEngine.receive_weights() +``` + +关键点: +- 训练侧使用 vLLM 的 `NCCLWeightTransferEngine.trainer_init()` 初始化 NCCL,与 vLLM 内部的 `StatelessProcessGroup` + `PyNcclCommunicator` 兼容 +- 不能使用 `torch.distributed.init_process_group`(vLLM 不兼容) + +#### VLLMEngine 端点映射 + +`VLLMEngine` 暴露与 `SGLangEngine` 兼容的方法签名。内部 HTTP 端点映射: + +| 方法 | VLLMEngine 内部 HTTP 调用 | +|---|---| +| `pause_generation()` | `POST /pause?mode=abort` | +| `flush_cache()` | no-op | +| `continue_generation()` | `POST /resume` | +| `init_weights_update_group(...)` | `POST /init_weight_transfer_engine` | +| `update_weights_from_distributed(...)` | `POST /update_weights` (NCCL 模式) | +| `update_weights_from_tensor(...)` | `POST /update_weights` (IPC 模式) | +| `release_memory_occupation()` | `POST /sleep?level=1&mode=abort` | +| `resume_memory_occupation()` | `POST /wake_up` | +| `health_generate()` | `GET /health` | +| `shutdown()` | `process.terminate()` | + +**影响**:训练侧代码(`UpdateWeightFromTensor`、`UpdateWeightFromDistributed`、actor 代码)使用统一接口,不感知具体 backend。 + +#### E) Router 决策 + +vLLM 当前阶段不使用 router: +- 单 vLLM 实例,`VLLMClient` 直连本地 vLLM server 端口 +- SGLang Model Gateway 只支持 SGLang worker,不适用 +- SlimeRouter 仅在 R3 / radix-tree caching 时需要 + +### 4.2 接口改动 + +#### A) 新增 CLI 接口 + +- `--rollout-backend {sglang,vllm}` —— 选择 rollout backend +- `--vllm-base-url` —— 手动指定 vLLM server 地址(自动管理时无需设置) +- `--vllm-model` —— vLLM 加载的模型路径(默认同 `--hf-checkpoint`) +- `--vllm-max-retries` —— 生成请求最大重试次数 +- `--vllm-enforce-eager` —— 是否禁用 CUDA graph(默认 True) + +#### B) 新增内部统一接口 + +`slime/rollout/base_types.py` 中新增: + +- `RolloutBackendRequest`:input_ids、sampling_params、return_logprob、return_routed_experts、image_data、session_id +- `RolloutBackendResponse`:text、output_token_ids、output_token_logprobs、finish_reason(`stop|length|abort`)、prompt_tokens、completion_tokens、backend_raw、routed_experts + +#### C) 能力门控 + +`BackendCapabilities` dataclass 声明各 backend 支持的能力: + +| 能力 | SGLangClient | VLLMClient | +|---|---|---| +| 
`supports_abort` | True | False | +| `supports_routed_experts` | True | False | +| `supports_prompt_logprobs` | True | False | + +不支持的特性在调用时显式 gate、记录日志、或 fail fast。 + +## 5. 文件级改动清单 + +### 新增文件 + +| 文件 | 说明 | +|---|---| +| `slime/rollout/backends/base_client.py` | `RolloutBackendClient` 抽象接口 + `BackendCapabilities` | +| `slime/rollout/backends/sglang_client.py` | SGLang rollout client(从现有代码提取) | +| `slime/rollout/backends/vllm_client.py` | vLLM rollout client(`/v1/completions` 适配) | +| `slime/rollout/backends/__init__.py` | adapter 导出 | +| `slime/backends/vllm_utils/vllm_engine.py` | `VLLMEngine` Ray actor(进程管理 + 权重同步) | +| `run-qwen2.5-0.5B-vllm.sh` | vLLM 验证脚本(Qwen2.5-0.5B + GSM8K + colocate) | + +### 修改文件 + +| 文件 | 说明 | +|---|---| +| `slime/rollout/base_types.py` | 新增 `RolloutBackendRequest` / `RolloutBackendResponse` | +| `slime/rollout/sglang_rollout.py` | 使用 backend adapter 代替硬编码 SGLang 协议 | +| `slime/utils/arguments.py` | 新增 `--rollout-backend`、vLLM 参数;vLLM 时设置 sglang 别名默认值 | +| `slime/ray/rollout.py` | `_start_vllm_rollout_servers()` 启动 VLLMEngine | +| `slime/backends/megatron_utils/actor.py` | 权重更新器选择:`colocate → UpdateWeightFromTensor`(backend 无关) | +| `slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py` | 新增 `_send_to_colocated_vllm_engine()`(CUDA IPC 路径) | +| `slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py` | vLLM NCCL 兼容:`NCCLWeightTransferEngine.trainer_init()` + `PyNcclCommunicator.broadcast()` | + +### 不变文件 + +| 文件 | 说明 | +|---|---| +| `train.py` | 训练循环不感知 backend | +| SGLang 相关文件 | SGLang 路径完全不受影响 | + +### 未使用(与 RFC 初稿的差异) + +- **无 `worker_extension.py`**:vLLM 原生 weight transfer API 完全满足需求,无需自定义 WorkerExtension +- **无 `collective_rpc`**:使用 vLLM 的 `/init_weight_transfer_engine` 和 `/update_weights` 端点 + +## 6. 
能力兼容矩阵 + +| 能力 | SGLang | vLLM | 处理方式 | +|---|---|---|---| +| Token 级响应 logprobs | 支持 | 支持(`choice.logprobs.token_logprobs`) | adapter 归一化 | +| Output token IDs | `meta_info` | `choice.token_ids` | adapter 归一化 | +| Finish reason | `stop\|length\|abort` | `stop\|length` + 其他 | canonical 映射 | +| Worker 级 abort | 支持 | 不支持 | 超时/取消降级 | +| Prompt logprobs(OPD) | 支持 | 部分 | capability-gated | +| Routed experts(R3) | 支持 | 不支持 | 显式 gate | +| Colocate(GPU 共享) | 支持(FlattenedTensorBucket IPC) | 支持(CUDA IPC handles) | 各自原生 IPC | +| 内存管理 | `release/resume_memory_occupation` | `POST /sleep` / `POST /wake_up` | 接口统一 | +| `include_stop_str_in_output` | `no_stop_trim` | `include_stop_str_in_output` | 参数映射 | + +## 7. 分阶段计划 + +### 第一阶段:Managed vLLM GRPO 训练 ✅ Done + +**目标**:Qwen2.5-0.5B GRPO 4 卡 colocate 训练,GSM8K 数据集,训练正常运行。 + +已完成交付物: +- `VLLMEngine` Ray actor:进程生命周期管理、`/pause`/`/resume`/`/sleep`/`/wake_up` +- `VLLMClient` rollout adapter:`/v1/completions` 请求映射、logprob/finish_reason/token_ids 归一化 +- `RolloutBackendRequest`/`Response` 统一契约 +- `SGLangClient` 提取 +- `_start_vllm_rollout_servers` 启动路径 +- 参数解析分流(`--rollout-backend vllm`,sglang 别名默认值) +- **Colocate 模式**:CUDA IPC 权重传输(`_send_to_colocated_vllm_engine` → `IPCWeightTransferEngine`) +- **Non-colocate 模式**:NCCL 权重传输(`NCCLWeightTransferEngine.trainer_init()` → `PyNcclCommunicator.broadcast()`) +- 验证脚本:`run-qwen2.5-0.5B-vllm.sh` + +技术决策: +- 不使用 router(单 vLLM 实例,直连) +- Colocate 权重同步:CUDA IPC(训练和 vLLM 共享 GPU 内存,零拷贝) +- Non-colocate 权重同步:vLLM 原生 NCCL(`StatelessProcessGroup` + `PyNcclCommunicator`) +- Rollout 函数:复用 `sglang_rollout.py`(backend 无关编排) +- 训练侧权重更新器选择与 SGLang 一致(`colocate → UpdateWeightFromTensor`) + +验收结果: +- ✅ Qwen2.5-0.5B + GSM8K + 4 GPU colocate 训练运行通过 +- ✅ 权重同步正确性:每轮 rollout 输出随训练更新变化 +- ✅ 现有 SGLang 路径不受影响 + +### 第二阶段:多 vLLM 实例 + 异步训推 + +**目标**:支持多 vLLM 实例 rollout,使用 [vllm-project/router](https://github.com/vllm-project/router) 做负载均衡,以异步训推(`train_async.py`)在更大模型上验证。 + +关键交付物: +- 集成 
[vllm-project/router](https://github.com/vllm-project/router) 作为 vLLM 侧负载均衡器 +- `start_rollout_servers` 拉起 N 个 `VLLMEngine` actor + 一个 vllm-router 进程 +- `VLLMClient` 生成请求从直连改为走 vllm-router +- 验证异步训推(`train_async.py`)在 vLLM backend 下的正确性 + +验收标准: +- 多实例(如 2-4 个 vLLM worker)rollout 在 20+ rollout step 下稳定 +- 异步训推(`train_async.py`)无 hang、无权重同步竞态 +- 相比单实例基线有吞吐提升 +- 更大模型验证(如 Qwen3-4B,TP=2) + +### 第三阶段(未来):高级特性 + +- Abort/cancel 策略完善 +- Prompt logprob 支持(OPD 场景) +- 确定性计算验证 +- 性能基准测试(vLLM vs SGLang 吞吐量对比) + +## 8. 验证策略 + +### 已完成(Phase 1) + +- ✅ 端到端 GRPO 训练:`run-qwen2.5-0.5B-vllm.sh`(Qwen2.5-0.5B, GSM8K, 4 GPU, colocate) +- ✅ 权重同步正确性验证 +- ✅ SGLang 路径非回归 + +### 待完成 + +- 单元测试:finish reason 归一化、token/logprob 对齐、capability-gated 行为 +- 集成测试:SGLang vs vLLM 响应 schema 对比 +- 压测:更高并发 / 延迟压力 + +## 9. 实现中遇到的关键问题及解决 + +### 9.1 vLLM NCCL 不兼容 + +**问题**:vLLM 内部使用 `StatelessProcessGroup` + `PyNcclCommunicator` 管理 NCCL 通信,与 `torch.distributed.init_process_group` 不兼容。 + +**解决**:训练侧使用 `NCCLWeightTransferEngine.trainer_init()` 初始化 NCCL,确保与 vLLM 端建立匹配的通信组。 + +### 9.2 vLLM logprobs 格式差异 + +**问题**:vLLM `/v1/completions` 的 logprobs 格式(`choice.token_ids`、`choice.logprobs.token_logprobs`)与 SGLang 不同。 + +**解决**:`VLLMClient` 中做显式映射,将 vLLM 格式转换为 `RolloutBackendResponse` 统一格式。对于缺失 `token_ids` 的情况,回退到 tokenizer。 + +### 9.3 Colocate 模式下的 IPC 格式差异 + +**问题**:SGLang 使用 `FlattenedTensorBucket` + `MultiprocessingSerializer` 做 IPC,vLLM 使用独立的 CUDA IPC handles(`reduce_tensor`)。两者格式不兼容。 + +**解决**:新增 `_send_to_colocated_vllm_engine()` 函数,直接使用 `torch.multiprocessing.reductions.reduce_tensor` 创建 CUDA IPC handles,以 GPU UUID 为 key。TP>1 时通过 Gloo gather 合并各 rank 的 handles,然后 pickle + base64 编码后通过 `VLLMEngine` 转发到 vLLM 的 `/update_weights` 端点。 + +### 9.4 缺失的 sglang 参数别名 + +**问题**:vLLM backend 跳过 `sglang_validate_args()`,导致 `sglang_dp_size` 等别名未设置,后续代码 `AttributeError`。 + +**解决**:在 `arguments.py` 中为 vLLM backend 添加条件分支,设置 `sglang_dp_size`、`sglang_pp_size`、`sglang_ep_size`、`sglang_tp_size` 等默认值。 + +## 10. 
风险与缓解 + +- **Logprob 语义不匹配**:严格 adapter 检查 + 归一化 + 测试 +- **Abort 不匹配**:能力模型 + 超时/取消降级 + 显式日志 +- **训练漂移**:强制 A/B 对照运行,固定 seed 和配置 +- **IPC handle 生命周期**:训练侧通过 `kept_alive` 列表 + `ray.get()` 阻塞确保 tensor 在 vLLM 读取完成前不被回收 + +## 11. 向 Slime 维护者提出的开放问题 + +1. 是否需要在 Phase 2 中将 vLLM 多实例 + router 作为默认部署模式? +2. 声称支持 GRPO 所需的最低 logprob 保真度是什么? +3. 是否应在完全对等前阻止 vLLM 的 OPD prompt-logprob 路径? +4. 推荐为默认 backend 之前需要满足什么质量门槛? +5. vLLM colocate IPC 路径是否需要支持混合模式(部分 colocate + 部分 distributed)? + +## 12. 请求决策 + +Phase 1 已实现并验证: + +- ✅ Contract-first adapter 架构 +- ✅ Managed vLLM 模式(与 SGLang 同等生命周期管理) +- ✅ 双模式权重同步(colocate: CUDA IPC / non-colocate: NCCL) +- ✅ 显式 capability-gated 行为 +- ✅ 严格 SGLang 非回归 + +请批准进入 Phase 2:多 vLLM 实例 + router + 异步训推。 + From 48fbde3c0ba7d1f1d68a427a99ca4baf3d4843ff Mon Sep 17 00:00:00 2001 From: SamitHuang <285365963@qq.com> Date: Tue, 3 Mar 2026 17:06:01 +0800 Subject: [PATCH 03/10] update Signed-off-by: SamitHuang <285365963@qq.com> --- goal_plan.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/goal_plan.md b/goal_plan.md index adb39bf047..d77618195b 100644 --- a/goal_plan.md +++ b/goal_plan.md @@ -1,17 +1,23 @@ ### 阶段一:打通Qwen2.5-0.5B GRPO 8卡同步/异步训练(train.py和train_async.py),GSM8K 数据集,loss/reward 收敛与 SGLang backend 基本一致,且满足确定性计算,多次重复运行Loss曲线完全一致。 +First Design and RFC by 03/06 + #### 初步方案: - 对标SGLang,Slime 在 Ray 内管理 vLLM 的完整生命周期,包括进程拉起、权重同步、推理暂停/恢复 - 暂不使用Router,SGLang Model Gateway仅只支持SGLang Worker,SlimeRouter仅在 R3 / radix-tree caching 时需要,Qwen2.5-0.5B 非 MoE 且用 token-in/token-out - 单vLLM实例,无router,通过vLLMClient 直连本地 vLLM 进程端口 -- 若训推不共卡,权重同步采用NCCL broadcast,对标SGLang update_weights_from_distributed (默认,优先用于支持异步训推) -- 若colocate,权重同步采用GPU IPC(vLLM update_weights_from_ipc, update_weights_from_tensor),对标SGLang update_weights_from_tensor +- 先支持和验证colocate,权重同步采用GPU IPC(vLLM update_weights_from_ipc, update_weights_from_tensor),对标SGLang update_weights_from_tensor,以验证Reproductivity +- 再支持训推不共卡,权重同步采用NCCL broadcast,对标SGLang 
update_weights_from_distributed (默认) #### 风险: - slime, sglang版本依赖,和vllm 0.16的版本依赖冲突(numpy, torch, transformers, etc) +- slime代码较挫,可靠性差,强依赖preset docker - 算力 -First Design and RFC by 03/06 + +#### Reference + +https://thudm.github.io/slime/advanced/reproducibility.html ### 阶段二:接入vllm-project/router,支持多实例vLLM From f8ceed67937f14bb7a2b95fedba0cf744077c317 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 4 Mar 2026 06:18:35 +0000 Subject: [PATCH 04/10] qwen2.5 0.5b non-colocate (first attempt ok, but nccl error later) Signed-off-by: samithuang <285365963@qq.com> --- ...wen2.5-0.5B-reproducibility-noncolocate.sh | 137 +++++++++++ run-qwen2.5-0.5B-vllm.sh | 137 +++++++++++ slime/backends/megatron_utils/actor.py | 3 +- .../update_weight_from_distributed.py | 73 ++++-- slime/backends/vllm_utils/__init__.py | 3 + slime/backends/vllm_utils/vllm_engine.py | 214 ++++++++++++++++++ slime/ray/rollout.py | 47 ++++ slime/rollout/backends/__init__.py | 10 + slime/rollout/backends/base_client.py | 31 +++ slime/rollout/backends/sglang_client.py | 82 +++++++ slime/rollout/backends/vllm_client.py | 91 ++++++++ slime/rollout/base_types.py | 26 +++ slime/rollout/sglang_rollout.py | 152 ++++++------- slime/utils/arguments.py | 29 ++- 14 files changed, 939 insertions(+), 96 deletions(-) create mode 100644 run-qwen2.5-0.5B-reproducibility-noncolocate.sh create mode 100644 run-qwen2.5-0.5B-vllm.sh create mode 100644 slime/backends/vllm_utils/__init__.py create mode 100644 slime/backends/vllm_utils/vllm_engine.py create mode 100644 slime/rollout/backends/__init__.py create mode 100644 slime/rollout/backends/base_client.py create mode 100644 slime/rollout/backends/sglang_client.py create mode 100644 slime/rollout/backends/vllm_client.py diff --git a/run-qwen2.5-0.5B-reproducibility-noncolocate.sh b/run-qwen2.5-0.5B-reproducibility-noncolocate.sh new file mode 100644 index 0000000000..5499949340 --- /dev/null +++ b/run-qwen2.5-0.5B-reproducibility-noncolocate.sh @@ -0,0 
+1,137 @@ +#!/bin/bash +# Non-colocate version of run-qwen2.5-0.5B-reproducibility.sh +# 2 GPUs: 1 for training, 1 for SGLang rollout + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/scripts/models/qwen2.5-0.5B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ + --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ +) + +ROLLOUT_ARGS=( + --prompt-data /root/gsm8k/train.parquet + --input-key messages + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 100 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 1024 + --rollout-temperature 1 + + --global-batch-size 256 +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data gsm8k /root/gsm8k/test.parquet + --n-samples-per-eval-prompt 1 + --eval-max-response-len 1024 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-entity samithuang + --wandb-project slime-rl + --wandb-group qwen2.5-0.5B-gsm8k-noncolocate +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + + --sglang-enable-deterministic-inference + --sglang-attention-backend flashinfer + + 
--deterministic-mode +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +ray start --head --node-ip-address 127.0.0.1 --num-gpus 2 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 1 \ + --num-gpus-per-node 2 \ + --rollout-num-gpus 1 \ + --calculate-per-token-loss \ + --use-slime-router \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh new file mode 100644 index 0000000000..bcc6807c69 --- /dev/null +++ b/run-qwen2.5-0.5B-vllm.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# vLLM rollout backend validation script (Phase 1) +# Based on run-qwen2.5-0.5B-reproducibility.sh + +# for rerun the task +pkill -9 vllm +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + + +set -ex + +export PYTHONBUFFERED=16 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/scripts/models/qwen2.5-0.5B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/ + --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ +) + +ROLLOUT_ARGS=( + --prompt-data /root/gsm8k/train.parquet + --input-key messages + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type math + --num-rollout 100 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 1024 + --rollout-temperature 1 + + 
--global-batch-size 256 +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data gsm8k /root/gsm8k/test.parquet + --n-samples-per-eval-prompt 1 + --eval-max-response-len 1024 + --eval-top-k 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --kl-coef 0.00 + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-host https://wandb.ai/ + --wandb-entity samithuang + --wandb-project slime-rl + --wandb-group qwen2.5-0.5B-gsm8k-vllm +) + +VLLM_ARGS=( + --rollout-backend vllm + --rollout-num-gpus-per-engine 1 + --sglang-server-concurrency 512 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + --deterministic-mode +) + +ray start --head --node-ip-address 127.0.0.1 --num-gpus 2 --disable-usage-stats + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NCCL_ALGO": "Ring", + "NCCL_P2P_DISABLE": "1", + "NCCL_DEBUG": "INFO", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", + "CUBLAS_WORKSPACE_CONFIG": ":4096:8" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 1 \ + --num-gpus-per-node 2 \ + --rollout-num-gpus 1 \ + --calculate-per-token-loss \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} 
\ + ${VLLM_ARGS[@]} \ + ${MISC_ARGS[@]} diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 9abb2f96ad..1f45e5fc94 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -128,7 +128,8 @@ def init( if self.args.vocab_size is None: self.args.vocab_size = self.tokenizer.vocab_size - update_weight_cls = UpdateWeightFromTensor if self.args.colocate else UpdateWeightFromDistributed + use_tensor_update = self.args.colocate and getattr(self.args, "rollout_backend", "sglang") != "vllm" + update_weight_cls = UpdateWeightFromTensor if use_tensor_update else UpdateWeightFromDistributed self.weight_updater = update_weight_cls( self.args, self.model, diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 7b5b7817f1..44649f87f4 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -1,3 +1,4 @@ +import os import socket import time from argparse import Namespace @@ -241,6 +242,7 @@ def _update_bucket_weights_from_distributed( self.weight_version, self.rollout_engines, converted_named_tensors, + use_vllm=_is_vllm_backend(self.args), ) ray.get(refs) @@ -249,6 +251,10 @@ def _update_bucket_weights_from_distributed( pbar.update(1) +def _is_vllm_backend(args: Namespace) -> bool: + return getattr(args, "rollout_backend", "sglang") == "vllm" + + def connect_rollout_engines_from_distributed( args: Namespace, group_name: str, @@ -261,6 +267,10 @@ def connect_rollout_engines_from_distributed( ``engine_gpu_counts`` gives the number of GPUs per engine. When engines have heterogeneous TP sizes (e.g. prefill TP=2, decode TP=4), each engine occupies a different number of ranks in the NCCL group. 
+ + For vLLM backend, uses vLLM's StatelessProcessGroup + PyNcclCommunicator + instead of torch.distributed, because vLLM's weight transfer engine uses + its own NCCL initialization protocol. """ if engine_gpu_counts is None: engine_gpu_counts = [args.rollout_num_gpus_per_engine] * len(rollout_engines) @@ -287,13 +297,39 @@ def connect_rollout_engines_from_distributed( ) for i, engine in enumerate(rollout_engines) ] - model_update_groups = init_process_group( - backend="nccl", - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - ) + + if _is_vllm_backend(args): + # vLLM uses StatelessProcessGroup + PyNcclCommunicator for weight transfer. + # The training side must use the same mechanism for NCCL compatibility. + # + # Disable P2P transport: in colocate mode the trainer and vLLM server + # share the same physical GPU but have different CUDA_VISIBLE_DEVICES, + # which causes NCCL P2P (cudaIpc*) to fail with "invalid argument". + # SHM transport works correctly in this scenario. + from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine + + old_p2p = os.environ.get("NCCL_P2P_DISABLE") + os.environ["NCCL_P2P_DISABLE"] = "1" + try: + model_update_groups = NCCLWeightTransferEngine.trainer_init({ + "master_address": master_address, + "master_port": master_port, + "world_size": world_size, + }) + finally: + if old_p2p is None: + os.environ.pop("NCCL_P2P_DISABLE", None) + else: + os.environ["NCCL_P2P_DISABLE"] = old_p2p + else: + model_update_groups = init_process_group( + backend="nccl", + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + ) + ray.get(refs) return model_update_groups @@ -303,19 +339,24 @@ def disconnect_rollout_engines_from_distributed(args, group_name, model_update_g Destroy NCCL on training and engines. 
""" refs = [engine.destroy_weights_update_group.remote(group_name) for engine in rollout_engines] - dist.destroy_process_group(model_update_groups) + if _is_vllm_backend(args): + model_update_groups = None + else: + dist.destroy_process_group(model_update_groups) ray.get(refs) def update_weights_from_distributed( group_name: str, - group: dist.ProcessGroup, + group, weight_version: int, rollout_engines: Sequence[ActorHandle], converted_named_tensors: Sequence[tuple[str, torch.Tensor]], + use_vllm: bool = False, ) -> list[ObjectRef]: """ Send metadata (Ray), broadcast tensors (NCCL rank 0 → engines). + For vLLM, uses PyNcclCommunicator.broadcast instead of dist.broadcast. """ refs = [ engine.update_weights_from_distributed.remote( @@ -328,11 +369,15 @@ def update_weights_from_distributed( for engine in rollout_engines ] - handles = [] - for _, param in converted_named_tensors: - handles.append(dist.broadcast(param.data, 0, group=group, async_op=True)) - for handle in handles: - handle.wait() + if use_vllm: + for _, param in converted_named_tensors: + group.broadcast(param.data, src=0, stream=torch.cuda.current_stream()) + else: + handles = [] + for _, param in converted_named_tensors: + handles.append(dist.broadcast(param.data, 0, group=group, async_op=True)) + for handle in handles: + handle.wait() return refs diff --git a/slime/backends/vllm_utils/__init__.py b/slime/backends/vllm_utils/__init__.py new file mode 100644 index 0000000000..ea4e7311c3 --- /dev/null +++ b/slime/backends/vllm_utils/__init__.py @@ -0,0 +1,3 @@ +from slime.backends.vllm_utils.vllm_engine import VLLMEngine + +__all__ = ["VLLMEngine"] diff --git a/slime/backends/vllm_utils/vllm_engine.py b/slime/backends/vllm_utils/vllm_engine.py new file mode 100644 index 0000000000..12283de21e --- /dev/null +++ b/slime/backends/vllm_utils/vllm_engine.py @@ -0,0 +1,214 @@ +"""VLLMEngine: Ray actor that launches and manages a vLLM server.""" + +import logging +import os +import subprocess +import tempfile 
+import time + +import requests + +from slime.ray.ray_actor import RayActor +from slime.utils.http_utils import get_host_info +from slime.utils.misc import get_free_port + +logger = logging.getLogger(__name__) + + +class VLLMEngine(RayActor): + """Ray actor that runs vLLM server with same interface as SGLangEngine for weight sync.""" + + def __init__(self, args, rank: int, base_gpu_id: int | None = None, gpu_ids: list[int] | None = None, **kwargs): + self.args = args + self.rank = rank + self.base_gpu_id = base_gpu_id or 0 + self.gpu_ids = gpu_ids or [self.base_gpu_id] + self.server_host = None + self.server_port = None + self.process = None + self._log_file = None + + def init(self, port=None, host=None, **kwargs): + self.server_host = host or get_host_info()[1] + self.server_port = port or get_free_port(15000) + + model = getattr(self.args, "vllm_model", None) or self.args.hf_checkpoint + tp = self.args.rollout_num_gpus_per_engine + gpu_ids = self.gpu_ids[:tp] + cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cvd: + visible = [int(x) for x in cvd.split(",") if x.strip()] + dev_str = ",".join(str(visible[gid]) if gid < len(visible) else str(gid) for gid in gpu_ids) + else: + dev_str = ",".join(str(g) for g in gpu_ids) + + cmd = [ + "vllm", "serve", model, + "--tensor-parallel-size", str(tp), + "--port", str(self.server_port), + "--host", "0.0.0.0", + "--weight-transfer-config", '{"backend": "nccl"}', + ] + if getattr(self.args, "offload_rollout", False): + cmd.append("--enable-sleep-mode") + if getattr(self.args, "vllm_enforce_eager", True): + cmd.append("--enforce-eager") + if getattr(self.args, "fp16", False): + cmd.extend(["--dtype", "float16"]) + + env = os.environ.copy() + env["VLLM_SERVER_DEV_MODE"] = "1" + env["CUDA_VISIBLE_DEVICES"] = dev_str + env.setdefault("NCCL_DEBUG", "INFO") + env.setdefault("NCCL_DEBUG_SUBSYS", "ALL") + env["NCCL_P2P_DISABLE"] = "1" + + self._log_file = tempfile.NamedTemporaryFile( + prefix="vllm_engine_", suffix=".log", 
delete=False, mode="w" + ) + logger.info("Launching vLLM: cmd=%s, CUDA_VISIBLE_DEVICES=%s, log=%s", + " ".join(cmd), dev_str, self._log_file.name) + self.process = subprocess.Popen( + cmd, + env=env, + stdout=self._log_file, + stderr=subprocess.STDOUT, + ) + self._wait_healthy() + + def _wait_healthy(self, timeout=300): + base = f"http://{self.server_host}:{self.server_port}" + start = time.time() + while time.time() - start < timeout: + try: + r = requests.get(f"{base}/health", timeout=5) + if r.status_code == 200: + logger.info("vLLM server healthy at %s:%s", self.server_host, self.server_port) + return + except Exception: + pass + if self.process and self.process.poll() is not None: + log_tail = self._read_log_tail() + raise RuntimeError(f"vLLM process exited with code {self.process.returncode}.\n{log_tail}") + time.sleep(2) + log_tail = self._read_log_tail() + raise TimeoutError(f"vLLM server failed to become healthy within {timeout}s.\n{log_tail}") + + def _read_log_tail(self, n=200): + if not self._log_file: + return "" + try: + self._log_file.flush() + with open(self._log_file.name) as f: + lines = f.readlines() + return "".join(lines[-n:]) + except Exception: + return "" + + def _post(self, path: str, json_data: dict | None = None, params: dict | None = None): + url = f"http://{self.server_host}:{self.server_port}{path}" + kwargs = {"timeout": 120} + if json_data is not None: + kwargs["json"] = json_data + if params is not None: + kwargs["params"] = params + r = requests.post(url, **kwargs) + if not r.ok: + body = r.text[:2000] if r.text else "(empty)" + log_tail = self._read_log_tail(50) + logger.error( + "vLLM %s returned %s: %s\n--- vLLM log tail ---\n%s", + path, r.status_code, body, log_tail, + ) + r.raise_for_status() + try: + return r.json() + except Exception: + return None + + def health_generate(self, timeout: float = 5.0) -> bool: + try: + r = requests.get( + f"http://{self.server_host}:{self.server_port}/health", + timeout=timeout, + ) + 
r.raise_for_status() + return True + except requests.RequestException: + return False + + def pause_generation(self): + self._post("/pause", params={"mode": "abort"}) + + def flush_cache(self): + pass + + def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name=None, backend=None): + logger.info( + "Initializing NCCL weight transfer: master=%s:%s, rank_offset=%d, world_size=%d", + master_address, master_port, rank_offset, world_size, + ) + self._post("/init_weight_transfer_engine", json_data={ + "init_info": { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + } + }) + + def update_weights_from_distributed(self, names, dtypes, shapes, group_name=None, flush_cache=False, weight_version=None): + dtype_names = [str(d).replace("torch.", "") for d in dtypes] + shape_lists = [list(s) for s in shapes] + self._post("/update_weights", json_data={ + "update_info": { + "names": names, + "dtype_names": dtype_names, + "shapes": shape_lists, + "packed": False, + } + }) + + def continue_generation(self): + self._post("/resume") + + def destroy_weights_update_group(self, group_name): + pass + + def release_memory_occupation(self): + try: + self._post("/sleep", params={"level": "1", "mode": "abort"}) + except requests.RequestException as e: + logger.warning("vLLM sleep failed (need --enable-sleep-mode?): %s", e) + + def resume_memory_occupation(self, tags: list[str] | None = None): + try: + params = {} + if tags: + params["tags"] = tags + self._post("/wake_up", params=params) + except requests.RequestException as e: + logger.warning("vLLM wake_up failed: %s", e) + + def get_weight_version(self): + return None + + def check_weights(self, action: str): + pass + + def post_process_weights(self, **kwargs): + pass + + def shutdown(self): + if self.process: + self.process.terminate() + try: + self.process.wait(timeout=10) + except subprocess.TimeoutExpired: + 
self.process.kill() + self.process = None + if self._log_file: + try: + self._log_file.close() + except Exception: + pass diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 8f919c1172..ba18fb9c47 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -16,6 +16,7 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from slime.backends.sglang_utils.sglang_engine import SGLangEngine +from slime.backends.vllm_utils.vllm_engine import VLLMEngine from slime.rollout.base_types import call_rollout_fn from slime.utils import logging_utils from slime.utils.health_monitor import RolloutHealthMonitor @@ -1020,6 +1021,49 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool return router_ip, router_port +def _start_vllm_rollout_servers(args, pg) -> dict[str, RolloutServer]: + """Start vLLM rollout server (single instance, no router).""" + pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg + tp = args.rollout_num_gpus_per_engine + gpu_ids = [int(reordered_gpu_ids[i]) for i in range(tp)] + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg_obj, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[0], + ) + + env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} + + VLLMRayActor = ray.remote(VLLMEngine) + engine = VLLMRayActor.options( + num_cpus=0.2, + num_gpus=0.2, + scheduling_strategy=scheduling_strategy, + runtime_env={"env_vars": env_vars}, + ).remote(args, rank=0, base_gpu_id=gpu_ids[0], gpu_ids=gpu_ids) + + host, port = ray.get(engine._get_current_node_ip_and_free_port.remote(start_port=15000)) + ray.get(engine.init.remote(port=port, host=host)) + args.vllm_base_url = f"http://{host}:{port}" + args.sglang_router_ip = host + args.sglang_router_port = port + + group = EngineGroup( + args=args, + pg=pg, + all_engines=[engine], + 
num_gpus_per_engine=args.rollout_num_gpus_per_engine, + num_new_engines=1, + worker_type="regular", + rank_offset=0, + gpu_offset=0, + sglang_overrides={}, + router_ip=host, + router_port=port, + ) + return {"default": RolloutServer(engine_groups=[group], router_ip=host, router_port=port, model_name="default")} + + def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: """Start rollout servers: one per model, each with its own router. @@ -1033,6 +1077,9 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: Note: ``init_http_client`` should be called separately before this, as the HTTP client is shared across all servers. """ + if getattr(args, "rollout_backend", "sglang") == "vllm": + return _start_vllm_rollout_servers(args, pg) + config = _resolve_sglang_config(args) servers: dict[str, RolloutServer] = {} diff --git a/slime/rollout/backends/__init__.py b/slime/rollout/backends/__init__.py new file mode 100644 index 0000000000..49c4a86eef --- /dev/null +++ b/slime/rollout/backends/__init__.py @@ -0,0 +1,10 @@ +from slime.rollout.backends.base_client import BackendCapabilities, RolloutBackendClient +from slime.rollout.backends.sglang_client import SGLangClient +from slime.rollout.backends.vllm_client import VLLMClient + +__all__ = [ + "BackendCapabilities", + "RolloutBackendClient", + "SGLangClient", + "VLLMClient", +] diff --git a/slime/rollout/backends/base_client.py b/slime/rollout/backends/base_client.py new file mode 100644 index 0000000000..1730d6cbac --- /dev/null +++ b/slime/rollout/backends/base_client.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from slime.rollout.base_types import RolloutBackendRequest, RolloutBackendResponse + + +@dataclass +class BackendCapabilities: + supports_abort: bool + supports_routed_experts: bool + supports_prompt_logprobs: bool + + +class RolloutBackendClient(ABC): + @property + @abstractmethod + def capabilities(self) -> BackendCapabilities: + ... 
+ + @abstractmethod + async def generate( + self, + request: RolloutBackendRequest, + base_url: str, + headers: dict | None = None, + ) -> RolloutBackendResponse: + ... + + async def abort(self) -> list[str]: + """Return worker URLs for abort. Empty for backends without worker-level abort.""" + return [] diff --git a/slime/rollout/backends/sglang_client.py b/slime/rollout/backends/sglang_client.py new file mode 100644 index 0000000000..505070c37d --- /dev/null +++ b/slime/rollout/backends/sglang_client.py @@ -0,0 +1,82 @@ +import logging + +import numpy as np +import pybase64 +import sglang_router +from packaging.version import parse + +from slime.rollout.backends.base_client import BackendCapabilities, RolloutBackendClient +from slime.rollout.base_types import RolloutBackendRequest, RolloutBackendResponse +from slime.utils.http_utils import get, post + +logger = logging.getLogger(__name__) + + +class SGLangClient(RolloutBackendClient): + def __init__(self, args): + self.args = args + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_abort=True, + supports_routed_experts=bool(getattr(self.args, "use_rollout_routing_replay", False)), + supports_prompt_logprobs=True, + ) + + async def generate( + self, + request: RolloutBackendRequest, + base_url: str, + headers: dict | None = None, + ) -> RolloutBackendResponse: + payload = { + "input_ids": request.input_ids, + "sampling_params": request.sampling_params, + "return_logprob": request.return_logprob, + "return_routed_experts": request.return_routed_experts, + } + if request.image_data: + payload["image_data"] = request.image_data + + url = f"{base_url.rstrip('/')}/generate" + output = await post(url, payload, headers=headers) + + meta = output.get("meta_info", {}) + logprobs = meta.get("output_token_logprobs", []) + output_token_ids = [item[1] for item in logprobs] + output_token_logprobs = [item[0] for item in logprobs] + + finish_reason = meta.get("finish_reason", 
{}).get("type", "stop") + routed_experts = None + if "routed_experts" in meta and self.capabilities.supports_routed_experts: + num_layers = getattr(self.args, "num_layers", 0) + moe_topk = getattr(self.args, "moe_router_topk", 1) + if num_layers and moe_topk: + routed_experts = np.frombuffer( + pybase64.b64decode(meta["routed_experts"].encode("ascii")), + dtype=np.int32, + ).reshape( + len(request.input_ids) + len(output_token_ids) - 1, + num_layers, + moe_topk, + ) + + return RolloutBackendResponse( + text=output.get("text", ""), + output_token_ids=output_token_ids, + output_token_logprobs=output_token_logprobs, + finish_reason=finish_reason, + prompt_tokens=meta.get("prompt_tokens", len(request.input_ids)), + completion_tokens=meta.get("completion_tokens", len(output_token_ids)), + backend_raw=output, + routed_experts=routed_experts, + ) + + async def abort(self) -> list[str]: + base = f"http://{self.args.sglang_router_ip}:{self.args.sglang_router_port}" + if parse(sglang_router.__version__) <= parse("0.2.1") or getattr(self.args, "use_slime_router", False): + r = await get(f"{base}/list_workers") + return r.get("urls", []) + r = await get(f"{base}/workers") + return [w["url"] for w in r.get("workers", [])] diff --git a/slime/rollout/backends/vllm_client.py b/slime/rollout/backends/vllm_client.py new file mode 100644 index 0000000000..537569499d --- /dev/null +++ b/slime/rollout/backends/vllm_client.py @@ -0,0 +1,91 @@ +import logging + +from slime.rollout.backends.base_client import BackendCapabilities, RolloutBackendClient +from slime.rollout.base_types import RolloutBackendRequest, RolloutBackendResponse +from slime.utils.http_utils import post + +logger = logging.getLogger(__name__) + +_FINISH_REASON_MAP = { + "stop": "stop", + "length": "length", + "abort": "abort", + "end_turn": "stop", + "max_tokens": "length", +} + + +class VLLMClient(RolloutBackendClient): + def __init__(self, args): + self.args = args + self._max_retries = getattr(args, 
"vllm_max_retries", 3) + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_abort=False, + supports_routed_experts=False, + supports_prompt_logprobs=False, + ) + + async def generate( + self, + request: RolloutBackendRequest, + base_url: str, + headers: dict | None = None, + ) -> RolloutBackendResponse: + sp = request.sampling_params + payload = { + "prompt": request.input_ids, + "max_tokens": sp.get("max_new_tokens", 1024), + "temperature": sp.get("temperature", 1.0), + "top_p": sp.get("top_p", 1.0), + "top_k": sp.get("top_k", -1), + "stop": sp.get("stop"), + "stop_token_ids": sp.get("stop_token_ids"), + "skip_special_tokens": sp.get("skip_special_tokens", True), + "logprobs": 1 if request.return_logprob else None, + "include_stop_str_in_output": sp.get("no_stop_trim", False), + "return_token_ids": True, + } + payload = {k: v for k, v in payload.items() if v is not None} + + url = f"{base_url.rstrip('/')}/v1/completions" + output = await post(url, payload, headers=headers) + + choice = output.get("choices", [{}])[0] + text = choice.get("text", "") + finish = choice.get("finish_reason", "stop") + finish_reason = _FINISH_REASON_MAP.get(finish, "stop") + + # vLLM /v1/completions response format: + # choice.token_ids: list[int] (output token IDs) + # choice.logprobs.token_logprobs: list[float|None] + # choice.logprobs.tokens: list[str] (token text, not IDs) + output_token_ids = choice.get("token_ids") or [] + logprobs_obj = choice.get("logprobs") or {} + raw_logprobs = logprobs_obj.get("token_logprobs") or [] + output_token_logprobs = [float(lp) if lp is not None else 0.0 for lp in raw_logprobs] + + # If token_ids not in response, fall back to tokenizer + if not output_token_ids and text: + logger.warning("vLLM response missing token_ids, falling back to tokenizer") + from slime.utils.processing_utils import load_tokenizer + tokenizer = load_tokenizer(self.args.hf_checkpoint, trust_remote_code=True) + output_token_ids 
= tokenizer.encode(text, add_special_tokens=False) + + # Ensure logprobs list matches token count + if len(output_token_logprobs) < len(output_token_ids): + output_token_logprobs.extend([0.0] * (len(output_token_ids) - len(output_token_logprobs))) + + usage = output.get("usage", {}) + return RolloutBackendResponse( + text=text, + output_token_ids=output_token_ids, + output_token_logprobs=output_token_logprobs, + finish_reason=finish_reason, + prompt_tokens=usage.get("prompt_tokens", len(request.input_ids)), + completion_tokens=usage.get("completion_tokens", len(output_token_ids)), + backend_raw=output, + routed_experts=None, + ) diff --git a/slime/rollout/base_types.py b/slime/rollout/base_types.py index f0eb0d96ce..75c140778a 100644 --- a/slime/rollout/base_types.py +++ b/slime/rollout/base_types.py @@ -4,6 +4,32 @@ from slime.utils.types import Sample +@dataclass +class RolloutBackendRequest: + """Backend-agnostic rollout request.""" + + input_ids: list[int] + sampling_params: dict[str, Any] + return_logprob: bool = True + return_routed_experts: bool = False + image_data: list[str] | None = None + session_id: str | None = None + + +@dataclass +class RolloutBackendResponse: + """Backend-agnostic rollout response.""" + + text: str + output_token_ids: list[int] + output_token_logprobs: list[float] + finish_reason: str # "stop" | "length" | "abort" + prompt_tokens: int + completion_tokens: int + backend_raw: dict + routed_experts: Any = None + + @dataclass class RolloutFnTrainOutput: samples: list[list[Sample]] diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 042fdab315..0b59be4477 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -9,17 +9,14 @@ from typing import Any import numpy as np -import pybase64 -import sglang_router -from packaging.version import parse from tqdm import tqdm -from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput +from slime.rollout.base_types 
import RolloutBackendRequest, RolloutBackendResponse, RolloutFnEvalOutput, RolloutFnTrainOutput from slime.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter from slime.utils.async_utils import run from slime.utils.data import Dataset from slime.utils.eval_config import EvalDatasetConfig -from slime.utils.http_utils import get, post +from slime.utils.http_utils import post from slime.utils.misc import SingletonMeta, load_function from slime.utils.processing_utils import ( build_processor_kwargs, @@ -33,6 +30,37 @@ __all__ = ["generate_rollout"] + +def _get_backend_client(args): + backend = getattr(args, "rollout_backend", "sglang") + if backend == "vllm": + from .backends.vllm_client import VLLMClient + return VLLMClient(args) + from .backends.sglang_client import SGLangClient + return SGLangClient(args) + + +def _apply_backend_response(sample, resp: RolloutBackendResponse, args): + """Apply RolloutBackendResponse to sample.""" + sample.tokens = sample.tokens + resp.output_token_ids + sample.response_length += len(resp.output_token_ids) + sample.response += resp.text + if sample.loss_mask is not None: + assert args.partial_rollout and args.mask_offpolicy_in_partial_rollout + sample.loss_mask += [1] * len(resp.output_token_ids) + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += resp.output_token_logprobs + if resp.routed_experts is not None: + sample.rollout_routed_experts = resp.routed_experts + meta = { + "finish_reason": {"type": resp.finish_reason}, + "prompt_tokens": resp.prompt_tokens, + "completion_tokens": resp.completion_tokens, + **{k: v for k, v in resp.backend_raw.get("meta_info", {}).items() if k not in ("finish_reason",)}, + } + sample.update_from_meta_info(args, meta) + logger = logging.getLogger(__name__) @@ -106,13 +134,11 @@ def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) 
-> Sample: - """Generate using traditional SGLang router with token-based workflow""" + """Generate using backend client (SGLang or vLLM).""" if args.ci_test: assert isinstance(sample.prompt, str) state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - assert ( sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED ), f"Sample status is {sample.status}" @@ -137,70 +163,41 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A sample.status = Sample.Status.TRUNCATED return sample - # Prepare payload for sglang server - payload = { - "sampling_params": sampling_params, - "return_logprob": True, - } + input_ids = sample.tokens if len(sample.response) > 0 else prompt_ids + if not sample.tokens: + sample.tokens = list(input_ids) - if args.use_rollout_routing_replay: - payload["return_routed_experts"] = True - - if sample.multimodal_inputs and sample.multimodal_inputs["images"]: - image_data = sample.multimodal_inputs["images"] - payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - - # Use existing tokens for multi-turn or tokenize the new prompt - if len(sample.response) > 0: - payload["input_ids"] = sample.tokens - else: - payload["input_ids"] = prompt_ids - if not sample.tokens: # Initialize sample.tokens for the first turn - sample.tokens = prompt_ids - - # Use session_id for consistent hashing routing if router uses consistent_hashing policy - headers = None - if args.sglang_router_policy == "consistent_hashing" and sample.session_id: - headers = {"X-SMG-Routing-Key": sample.session_id} - - output = await post(url, payload, headers=headers) - - if args.use_slime_router and "RadixTreeMiddleware" in args.slime_router_middleware_paths: + use_radix = args.use_slime_router and "RadixTreeMiddleware" in getattr(args, "slime_router_middleware_paths", []) + if use_radix: + # RadixTree path: SGLang-specific, keep direct post from 
slime.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + payload = { + "input_ids": input_ids, + "sampling_params": sampling_params, + "return_logprob": True, + "return_routed_experts": args.use_rollout_routing_replay, + } + if sample.multimodal_inputs and sample.multimodal_inputs.get("images"): + payload["image_data"] = [encode_image_for_rollout_engine(img) for img in sample.multimodal_inputs["images"]] + headers = {"X-SMG-Routing-Key": sample.session_id} if getattr(args, "sglang_router_policy") == "consistent_hashing" and sample.session_id else None + output = await post(url, payload, headers=headers) sample = await postprocess_sample_with_radix_tree(args, sample, output) else: - if "output_token_logprobs" in output["meta_info"]: - new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - else: - new_response_tokens, new_response_log_probs = [], [] - - # Update sample with tokens directly - avoiding re-tokenization - sample.tokens = sample.tokens + new_response_tokens - sample.response_length += len(new_response_tokens) - sample.response += output["text"] - - # When partial rollout and masking off policy is enabled, update the loss mask - if sample.loss_mask is not None: - assert args.partial_rollout and args.mask_offpolicy_in_partial_rollout - sample.loss_mask += [1] * len(new_response_tokens) - - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += new_response_log_probs - - if "routed_experts" in output["meta_info"]: - sample.rollout_routed_experts = np.frombuffer( - pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), - dtype=np.int32, - ).reshape( - len(sample.tokens) - 1, - args.num_layers, - args.moe_router_topk, + backend = _get_backend_client(args) + 
base_url = getattr(args, "vllm_base_url", None) or f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + headers = {"X-SMG-Routing-Key": sample.session_id} if getattr(args, "sglang_router_policy", None) == "consistent_hashing" and sample.session_id else None + req = RolloutBackendRequest( + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=True, + return_routed_experts=args.use_rollout_routing_replay, + image_data=[encode_image_for_rollout_engine(img) for img in sample.multimodal_inputs["images"]] if (sample.multimodal_inputs and sample.multimodal_inputs.get("images")) else None, + session_id=sample.session_id, ) - - sample.update_from_meta_info(args, output["meta_info"]) + resp = await backend.generate(req, base_url, headers=headers) + _apply_backend_response(sample, resp, args) return sample @@ -306,24 +303,19 @@ async def generate_and_rm_group( async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: aborted_samples = [] - state = GenerateState(args) assert not state.aborted state.aborted = True - if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_slime_router: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") - urls = response["urls"] - else: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") - urls = [worker["url"] for worker in response["workers"]] - - logger.info(f"Abort request for {urls}") - abort_tasks = [post(f"{url}/abort_request", {"abort_all": True}) for url in urls] - abort_results = await asyncio.gather(*abort_tasks, return_exceptions=True) - for url, result in zip(urls, abort_results, strict=False): - if isinstance(result, Exception): - logger.warning(f"Failed to abort worker at {url}: {result}") + backend = _get_backend_client(args) + urls = await backend.abort() + if urls: + logger.info(f"Abort request for {urls}") + abort_tasks = [post(f"{url}/abort_request", {"abort_all": True}) for url in urls] + 
abort_results = await asyncio.gather(*abort_tasks, return_exceptions=True) + for url, result in zip(urls, abort_results, strict=False): + if isinstance(result, Exception): + logger.warning(f"Failed to abort worker at {url}: {result}") # make sure all the pending tasks are finished count = 0 diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index fb2d305a63..76e681a87f 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -238,6 +238,25 @@ def add_rollout_arguments(parser): "Also, sometimes this will help alleviate the bug that transformers cannot find certain model." ), ) + parser.add_argument( + "--rollout-backend", + type=str, + choices=["sglang", "vllm"], + default="sglang", + help="Rollout inference backend: sglang (default) or vllm.", + ) + parser.add_argument( + "--vllm-base-url", + type=str, + default=None, + help="vLLM server URL (set automatically when using managed vLLM).", + ) + parser.add_argument( + "--vllm-max-retries", + type=int, + default=3, + help="Max HTTP retries for vLLM requests.", + ) parser.add_argument( "--rollout-function-path", type=str, @@ -1484,8 +1503,16 @@ def parse_args(add_custom_arguments=None): if pre.train_backend == "megatron" and not args.debug_rollout_only: megatron_validate_args(args) - if not args.debug_train_only: + if not args.debug_train_only and getattr(args, "rollout_backend", "sglang") == "sglang": sglang_validate_args(args) + elif getattr(args, "rollout_backend", "sglang") == "vllm": + # Set sglang aliases that the rest of the codebase expects + args.sglang_dp_size = getattr(args, "sglang_data_parallel_size", 1) or 1 + args.sglang_pp_size = getattr(args, "sglang_pipeline_parallel_size", 1) or 1 + args.sglang_ep_size = getattr(args, "sglang_expert_parallel_size", 1) or 1 + args.sglang_tp_size = args.rollout_num_gpus_per_engine + if not hasattr(args, "sglang_speculative_algorithm"): + args.sglang_speculative_algorithm = None return args From 2caa4a067220d22ca24b782d642d63128c50ed60 Mon 
 Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 4 Mar 2026 06:27:17 +0000 Subject: [PATCH 05/10] add convert script --- convert_qwen2.5_ckpt.sh | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 convert_qwen2.5_ckpt.sh diff --git a/convert_qwen2.5_ckpt.sh b/convert_qwen2.5_ckpt.sh new file mode 100644 index 0000000000..fae03587dc --- /dev/null +++ b/convert_qwen2.5_ckpt.sh @@ -0,0 +1,5 @@ +source scripts/models/qwen2.5-0.5B.sh +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen2.5-0.5B-Instruct \ + --save /root/Qwen2.5-0.5B-Instruct_torch_dist/ From 8caa8bae5f36615a238a6f3f1189544f691cd16b Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 4 Mar 2026 06:39:11 +0000 Subject: [PATCH 06/10] add setup doc --- setup_for_vllm.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 setup_for_vllm.md diff --git a/setup_for_vllm.md b/setup_for_vllm.md new file mode 100644 index 0000000000..ccdc8fd593 --- /dev/null +++ b/setup_for_vllm.md @@ -0,0 +1,21 @@ +``` +docker pull slimerl/slime:latest +``` + +``` +docker run -itd --gpus all --ipc=host --shm-size=128g --net=host --privileged=true --restart=always \ +--ulimit memlock=-1 --ulimit stack=67108864 \ +--ulimit nofile=65536:65536 \ +--name DNAME \ +-it slimerl/slime:latest /bin/bash + +``` +docker exec -it --user root DNAME bash +``` + +``` +pip install vllm==0.16 + +# for compatibility +pip install numpy==1.26.4 +``` From 25ee00543d07cacad5fb8e57ecaee28315e9d6c7 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Thu, 5 Mar 2026 03:45:03 +0000 Subject: [PATCH 07/10] fix nccl error by NcclBridge subprocess --- run-qwen2.5-0.5B-vllm.sh | 1 + .../update_weight_from_distributed.py | 335 +++++++++++++++--- slime/backends/vllm_utils/vllm_engine.py | 23 +- slime/ray/actor_group.py | 6 + slime/utils/arguments.py | 6 + 5 files changed, 308 insertions(+), 63 deletions(-) 
diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh index bcc6807c69..2746098478 100644 --- a/run-qwen2.5-0.5B-vllm.sh +++ b/run-qwen2.5-0.5B-vllm.sh @@ -113,6 +113,7 @@ ray job submit --address="http://127.0.0.1:8265" \ "PYTHONPATH": "/root/Megatron-LM", "CUDA_DEVICE_MAX_CONNECTIONS": "1", "NCCL_ALGO": "Ring", + "NCCL_IB_DISABLE": "1", "NCCL_P2P_DISABLE": "1", "NCCL_DEBUG": "INFO", "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 44649f87f4..94fcecc203 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -1,6 +1,9 @@ +import logging +import multiprocessing as mp import os import socket import time +import traceback from argparse import Namespace from collections.abc import Callable, Mapping, Sequence @@ -17,6 +20,155 @@ from ..megatron_to_hf import convert_to_hf from .common import all_gather_param, named_params_and_buffers +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# NcclBridge: isolate vLLM's PyNcclCommunicator in a subprocess so that it +# never coexists with torch.distributed NCCL groups in the Megatron trainer. +# +# vLLM's weight transfer uses raw NCCL (PyNcclCommunicator) which conflicts +# with torch.distributed's NCCL backend when both exist in the same process +# (see https://github.com/vllm-project/vllm/issues/5477). sglang avoids +# this because it uses torch.distributed process groups for weight sync. +# --------------------------------------------------------------------------- + + +def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, cvd, env_snapshot): + """Subprocess entry-point: creates PyNcclCommunicator and serves requests. 
+ + Protocol over *conn* (multiprocessing.Connection): + parent → child: + {"op": "broadcast", "tensors": [cpu_tensor, ...]} + {"op": "send_packed", "named_tensors": [(name, cpu_tensor), ...]} + None → shutdown + child → parent: + "ready" (after init) + "ok" (after each op) + "error: " + """ + try: + os.environ.update(env_snapshot) + if cvd: + os.environ["CUDA_VISIBLE_DEVICES"] = cvd + + import torch + torch.cuda.set_device(device) + + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create( + host=master_address, port=master_port, rank=0, world_size=world_size, + ) + comm = PyNcclCommunicator(pg, device=device) + + conn.send("ready") + + while True: + cmd = conn.recv() + if cmd is None: + break + + op = cmd["op"] + if op == "broadcast": + for cpu_t in cmd["tensors"]: + gpu_t = cpu_t.cuda(device) + comm.broadcast(gpu_t, src=0, stream=torch.cuda.current_stream()) + del gpu_t + torch.cuda.synchronize() + conn.send("ok") + + elif op == "send_packed": + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, + ) + + named_tensors = cmd["named_tensors"] + + def _gpu_iter(): + for name, cpu_t in named_tensors: + yield (name, cpu_t.cuda(device).contiguous()) + + NCCLWeightTransferEngine.trainer_send_weights( + iterator=_gpu_iter(), + group=comm, + packed=True, + ) + torch.cuda.synchronize() + conn.send("ok") + + except Exception as e: + try: + conn.send(f"error: {e}") + except Exception: + pass + traceback.print_exc() + + +class _NcclBridge: + """Runs vLLM's PyNcclCommunicator in a separate subprocess. + + This prevents NCCL communicator conflicts with torch.distributed groups + that already exist in the Megatron trainer process. Communication between + the trainer and the bridge uses a multiprocessing Pipe; GPU tensors are + staged through CPU (pinned memory when possible) for the transfer. 
+ """ + + def __init__(self, master_address: str, master_port: int, world_size: int, device: int): + ctx = mp.get_context("spawn") + self._parent_conn, child_conn = ctx.Pipe() + + # Pass the full environment so the subprocess inherits all NCCL, CUDA, + # and networking settings from the trainer (set by Ray runtime_env). + env_snapshot = dict(os.environ) + cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "") + + self._process = ctx.Process( + target=_nccl_bridge_worker, + args=(child_conn, master_address, master_port, world_size, device, cvd, env_snapshot), + daemon=True, + ) + self._process.start() + + msg = self._parent_conn.recv() + if isinstance(msg, str) and msg.startswith("error:"): + raise RuntimeError(f"NcclBridge init failed: {msg}") + if msg != "ready": + raise RuntimeError(f"NcclBridge init unexpected response: {msg}") + logger.info("NcclBridge ready (pid=%d, device=%d)", self._process.pid, device) + + def broadcast_tensors(self, tensors: list[torch.Tensor]) -> None: + """Broadcast a list of tensors (one-by-one) via the bridge subprocess.""" + cpu_tensors = [t.cpu().contiguous() for t in tensors] + self._parent_conn.send({"op": "broadcast", "tensors": cpu_tensors}) + self._wait_ok("broadcast_tensors") + + def send_weights_packed(self, named_tensors: list[tuple[str, torch.Tensor]]) -> None: + """Send weights using vLLM's packed broadcast protocol.""" + cpu_pairs = [] + for name, t in named_tensors: + data = t.data if hasattr(t, "data") else t + cpu_pairs.append((name, data.cpu().contiguous())) + self._parent_conn.send({"op": "send_packed", "named_tensors": cpu_pairs}) + self._wait_ok("send_weights_packed") + + def _wait_ok(self, label: str, timeout: float = 600.0) -> None: + if not self._parent_conn.poll(timeout): + raise TimeoutError(f"NcclBridge {label} timed out after {timeout}s") + msg = self._parent_conn.recv() + if msg != "ok": + raise RuntimeError(f"NcclBridge {label} failed: {msg}") + + def shutdown(self) -> None: + try: + self._parent_conn.send(None) 
+ self._process.join(timeout=30) + except Exception: + pass + if self._process.is_alive(): + self._process.terminate() + class UpdateWeightFromDistributed: """ @@ -99,34 +251,52 @@ def update_weights(self) -> None: ) dist.barrier(group=get_gloo_group()) - buffer_size = 0 - converted_named_tensors = [] - # non expert params - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." in name: - continue - buffer_size = self._update_weight_from_distributed( - name, param, converted_named_tensors, buffer_size, pbar=pbar - ) + use_vllm_packed = self._use_vllm_packed() + if use_vllm_packed and self._is_pp_src_rank: + logger.info("Using vLLM packed weight sync (one-shot metadata + trainer_send_weights)") + if use_vllm_packed: + # vLLM packed path: gather all non-expert params, one-shot update (aligned with vLLM New Weight Syncing) + converted_named_tensors = [] + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." in name: + continue + param = all_gather_param(name, param) + if self._is_pp_src_rank: + converted_named_tensors += convert_to_hf( + self.args, self.model_name, name, param, self.quantization_config + ) + if converted_named_tensors and self._is_pp_src_rank: + self._update_weights_vllm_packed(converted_named_tensors) + else: + buffer_size = 0 + converted_named_tensors = [] + pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None + + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." 
in name: + continue + buffer_size = self._update_weight_from_distributed( + name, param, converted_named_tensors, buffer_size, pbar=pbar + ) - if converted_named_tensors: - self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) + if converted_named_tensors: + self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) dist.barrier(group=get_gloo_group()) - buffer_size = 0 - named_tensors = [] - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." not in name: - continue - buffer_size = self._update_expert_weight_from_distributed( - name, param, named_tensors, buffer_size, pbar=pbar - ) + if not use_vllm_packed: + buffer_size = 0 + named_tensors = [] + pbar = tqdm(desc=f"[{self._group_name}] Update weights (experts)", total=0) if self._is_pp_src_rank else None + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." not in name: + continue + buffer_size = self._update_expert_weight_from_distributed( + name, param, named_tensors, buffer_size, pbar=pbar + ) - if named_tensors: - self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) + if named_tensors: + self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: @@ -164,6 +334,41 @@ def _update_weight_from_distributed( buffer_size += param_size return buffer_size + def _use_vllm_packed(self) -> bool: + """Use vLLM packed weight transfer (one-shot metadata + trainer_send_weights).""" + if not _is_vllm_backend(self.args): + return False + if not getattr(self.args, "vllm_weight_sync_packed", True): + return False + # MoE models need expert path, skip packed + if any(".experts." 
in name for name, _ in named_params_and_buffers(self.args, self.model)): + return False + # compressed-tensors needs pre/post process + if self.quantization_config and self.quantization_config.get("quant_method") == "compressed-tensors": + return False + return True + + def _update_weights_vllm_packed( + self, converted_named_tensors: list[tuple[str, torch.Tensor]] + ) -> None: + """Single-shot vLLM weight update using packed broadcast (aligned with vLLM New Weight Syncing).""" + while not ray.get(self.rollout_engine_lock.acquire.remote()): + time.sleep(0.1) + + try: + refs = update_weights_from_distributed( + self._group_name, + self._model_update_groups, + self.weight_version, + self.rollout_engines, + converted_named_tensors, + use_vllm=True, + packed=True, + ) + ray.get(refs) + finally: + ray.get(self.rollout_engine_lock.release.remote()) + def _update_expert_weight_from_distributed( self, name: str, @@ -268,9 +473,10 @@ def connect_rollout_engines_from_distributed( have heterogeneous TP sizes (e.g. prefill TP=2, decode TP=4), each engine occupies a different number of ranks in the NCCL group. - For vLLM backend, uses vLLM's StatelessProcessGroup + PyNcclCommunicator - instead of torch.distributed, because vLLM's weight transfer engine uses - its own NCCL initialization protocol. + For vLLM backend, the trainer-side NCCL communicator is created inside a + separate subprocess (_NcclBridge) to avoid conflicts between vLLM's raw + NCCL (PyNcclCommunicator) and the torch.distributed NCCL groups that + Megatron already holds in this process. """ if engine_gpu_counts is None: engine_gpu_counts = [args.rollout_num_gpus_per_engine] * len(rollout_engines) @@ -286,6 +492,7 @@ def connect_rollout_engines_from_distributed( for c in engine_gpu_counts: cumulative.append(cumulative[-1] + c) + # Fire engine init remotes first (non-blocking Ray calls). 
refs = [ engine.init_weights_update_group.remote( master_address, @@ -298,29 +505,28 @@ def connect_rollout_engines_from_distributed( for i, engine in enumerate(rollout_engines) ] + torch.cuda.synchronize() + torch.cuda.empty_cache() + if _is_vllm_backend(args): - # vLLM uses StatelessProcessGroup + PyNcclCommunicator for weight transfer. - # The training side must use the same mechanism for NCCL compatibility. - # - # Disable P2P transport: in colocate mode the trainer and vLLM server - # share the same physical GPU but have different CUDA_VISIBLE_DEVICES, - # which causes NCCL P2P (cudaIpc*) to fail with "invalid argument". - # SHM transport works correctly in this scenario. - from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine - - old_p2p = os.environ.get("NCCL_P2P_DISABLE") - os.environ["NCCL_P2P_DISABLE"] = "1" - try: - model_update_groups = NCCLWeightTransferEngine.trainer_init({ - "master_address": master_address, - "master_port": master_port, - "world_size": world_size, - }) - finally: - if old_p2p is None: - os.environ.pop("NCCL_P2P_DISABLE", None) - else: - os.environ["NCCL_P2P_DISABLE"] = old_p2p + # vLLM uses StatelessProcessGroup + PyNcclCommunicator (raw NCCL). + # Creating PyNcclCommunicator in the Megatron trainer process would + # conflict with torch.distributed's NCCL groups (issue #5477). + # Instead, we spawn a bridge subprocess that owns the PyNcclCommunicator, + # keeping the trainer process free of raw NCCL communicators. 
+ device = torch.cuda.current_device() + logger.info( + "vLLM weight transfer via NcclBridge: addr=%s port=%d " + "world_size=%d device=%d CVD=%s", + master_address, master_port, world_size, device, + os.environ.get("CUDA_VISIBLE_DEVICES", ""), + ) + model_update_groups = _NcclBridge( + master_address=master_address, + master_port=master_port, + world_size=world_size, + device=device, + ) else: model_update_groups = init_process_group( backend="nccl", @@ -340,6 +546,8 @@ def disconnect_rollout_engines_from_distributed(args, group_name, model_update_g """ refs = [engine.destroy_weights_update_group.remote(group_name) for engine in rollout_engines] if _is_vllm_backend(args): + if isinstance(model_update_groups, _NcclBridge): + model_update_groups.shutdown() model_update_groups = None else: dist.destroy_process_group(model_update_groups) @@ -353,25 +561,34 @@ def update_weights_from_distributed( rollout_engines: Sequence[ActorHandle], converted_named_tensors: Sequence[tuple[str, torch.Tensor]], use_vllm: bool = False, + packed: bool = False, ) -> list[ObjectRef]: """ Send metadata (Ray), broadcast tensors (NCCL rank 0 → engines). - For vLLM, uses PyNcclCommunicator.broadcast instead of dist.broadcast. + + For vLLM the *group* is an ``_NcclBridge`` instance (subprocess) so that + raw NCCL never runs inside the Megatron trainer process. + For sglang the *group* is a ``torch.distributed.ProcessGroup``. 
""" + kwargs = { + "names": [name for name, _ in converted_named_tensors], + "dtypes": [param.dtype for _, param in converted_named_tensors], + "shapes": [param.shape for _, param in converted_named_tensors], + "group_name": group_name, + "weight_version": str(weight_version), + } + if use_vllm: + kwargs["packed"] = packed + refs = [ - engine.update_weights_from_distributed.remote( - names=[name for name, _ in converted_named_tensors], - dtypes=[param.dtype for _, param in converted_named_tensors], - shapes=[param.shape for _, param in converted_named_tensors], - group_name=group_name, - weight_version=str(weight_version), - ) + engine.update_weights_from_distributed.remote(**kwargs) for engine in rollout_engines ] - if use_vllm: - for _, param in converted_named_tensors: - group.broadcast(param.data, src=0, stream=torch.cuda.current_stream()) + if use_vllm and packed: + group.send_weights_packed(converted_named_tensors) + elif use_vllm: + group.broadcast_tensors([param.data for _, param in converted_named_tensors]) else: handles = [] for _, param in converted_named_tensors: diff --git a/slime/backends/vllm_utils/vllm_engine.py b/slime/backends/vllm_utils/vllm_engine.py index 12283de21e..22ff897ee5 100644 --- a/slime/backends/vllm_utils/vllm_engine.py +++ b/slime/backends/vllm_utils/vllm_engine.py @@ -62,6 +62,7 @@ def init(self, port=None, host=None, **kwargs): env.setdefault("NCCL_DEBUG", "INFO") env.setdefault("NCCL_DEBUG_SUBSYS", "ALL") env["NCCL_P2P_DISABLE"] = "1" + env.setdefault("NCCL_IB_DISABLE", "1") self._log_file = tempfile.NamedTemporaryFile( prefix="vllm_engine_", suffix=".log", delete=False, mode="w" @@ -145,8 +146,11 @@ def flush_cache(self): def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name=None, backend=None): logger.info( - "Initializing NCCL weight transfer: master=%s:%s, rank_offset=%d, world_size=%d", + "Initializing NCCL weight transfer: master=%s:%s, rank_offset=%d, " + "world_size=%d, 
vllm_url=http://%s:%s, vllm_log=%s", master_address, master_port, rank_offset, world_size, + self.server_host, self.server_port, + self._log_file.name if self._log_file else "", ) self._post("/init_weight_transfer_engine", json_data={ "init_info": { @@ -156,8 +160,19 @@ def init_weights_update_group(self, master_address, master_port, rank_offset, wo "world_size": world_size, } }) - - def update_weights_from_distributed(self, names, dtypes, shapes, group_name=None, flush_cache=False, weight_version=None): + log_tail = self._read_log_tail(30) + logger.info("vLLM log after init_weight_transfer_engine:\n%s", log_tail) + + def update_weights_from_distributed( + self, + names, + dtypes, + shapes, + group_name=None, + flush_cache=False, + weight_version=None, + packed: bool = True, + ): dtype_names = [str(d).replace("torch.", "") for d in dtypes] shape_lists = [list(s) for s in shapes] self._post("/update_weights", json_data={ @@ -165,7 +180,7 @@ def update_weights_from_distributed(self, names, dtypes, shapes, group_name=None "names": names, "dtype_names": dtype_names, "shapes": shape_lists, - "packed": False, + "packed": packed, } }) diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index 5542835313..1150e9ed35 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -50,12 +50,18 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): assert pg is not None pg, reordered_bundle_indices, _reordered_gpu_ids = pg + # Restrict CUDA_VISIBLE_DEVICES to only this group's GPUs so that + # NCCL / PyTorch do not allocate memory on rollout GPUs. + trainer_gpu_ids = [_reordered_gpu_ids[rank] for rank in range(world_size)] + trainer_cvd = ",".join(str(g) for g in trainer_gpu_ids) + env_vars = { # because sglang will always set NCCL_CUMEM_ENABLE to 0 # we need also set it to 0 to prevent nccl error. 
"NCCL_CUMEM_ENABLE": os.environ.get("NCCL_CUMEM_ENABLE", "0"), "NVTE_FP8_BLOCK_SCALING_FP32_SCALES": os.environ.get("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "1"), **{name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST}, + "CUDA_VISIBLE_DEVICES": trainer_cvd, **self.args.train_env_vars, } diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 76e681a87f..7ef8e70b71 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -257,6 +257,12 @@ def add_rollout_arguments(parser): default=3, help="Max HTTP retries for vLLM requests.", ) + parser.add_argument( + "--vllm-weight-sync-packed", + action=argparse.BooleanOptionalAction, + default=True, + help="Use vLLM packed weight transfer for non-colocate (default: True). Disable for per-bucket mode.", + ) parser.add_argument( "--rollout-function-path", type=str, From ab7eb0bc11412fb1ff232c1f9a4560f8f1e86df0 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Thu, 5 Mar 2026 08:31:05 +0000 Subject: [PATCH 08/10] eliminate gpu to cpu weight transfer Signed-off-by: samithuang <285365963@qq.com> --- run-qwen2.5-0.5B-vllm.sh | 3 +- .../update_weight_from_distributed.py | 40 +++++++++---------- slime/backends/vllm_utils/vllm_engine.py | 3 ++ slime/rollout/backends/vllm_client.py | 2 + 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh index 2746098478..208ec60c91 100644 --- a/run-qwen2.5-0.5B-vllm.sh +++ b/run-qwen2.5-0.5B-vllm.sh @@ -26,6 +26,7 @@ CKPT_ARGS=( --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/ ) +# num-rollout:100 ROLLOUT_ARGS=( --prompt-data /root/gsm8k/train.parquet --input-key messages @@ -33,7 +34,7 @@ ROLLOUT_ARGS=( --apply-chat-template --rollout-shuffle --rm-type math - --num-rollout 100 + --num-rollout 500 --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 1024 diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py 
b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 94fcecc203..03d480801f 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -1,5 +1,5 @@ import logging -import multiprocessing as mp +import torch.multiprocessing as mp import os import socket import time @@ -37,10 +37,13 @@ def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, cvd, env_snapshot): """Subprocess entry-point: creates PyNcclCommunicator and serves requests. + GPU tensors are shared from the parent via CUDA IPC (torch.multiprocessing + handles this transparently). No GPU→CPU→GPU copies are needed. + Protocol over *conn* (multiprocessing.Connection): parent → child: - {"op": "broadcast", "tensors": [cpu_tensor, ...]} - {"op": "send_packed", "named_tensors": [(name, cpu_tensor), ...]} + {"op": "broadcast", "tensors": [gpu_tensor, ...]} + {"op": "send_packed", "named_tensors": [(name, gpu_tensor), ...]} None → shutdown child → parent: "ready" (after init) @@ -53,6 +56,7 @@ def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, c os.environ["CUDA_VISIBLE_DEVICES"] = cvd import torch + import torch.multiprocessing # ensure CUDA IPC reducers are registered torch.cuda.set_device(device) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -72,10 +76,8 @@ def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, c op = cmd["op"] if op == "broadcast": - for cpu_t in cmd["tensors"]: - gpu_t = cpu_t.cuda(device) - comm.broadcast(gpu_t, src=0, stream=torch.cuda.current_stream()) - del gpu_t + for t in cmd["tensors"]: + comm.broadcast(t, src=0, stream=torch.cuda.current_stream()) torch.cuda.synchronize() conn.send("ok") @@ -84,14 +86,8 @@ def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, c NCCLWeightTransferEngine, ) - named_tensors = 
cmd["named_tensors"] - - def _gpu_iter(): - for name, cpu_t in named_tensors: - yield (name, cpu_t.cuda(device).contiguous()) - NCCLWeightTransferEngine.trainer_send_weights( - iterator=_gpu_iter(), + iterator=iter(cmd["named_tensors"]), group=comm, packed=True, ) @@ -110,9 +106,9 @@ class _NcclBridge: """Runs vLLM's PyNcclCommunicator in a separate subprocess. This prevents NCCL communicator conflicts with torch.distributed groups - that already exist in the Megatron trainer process. Communication between - the trainer and the bridge uses a multiprocessing Pipe; GPU tensors are - staged through CPU (pinned memory when possible) for the transfer. + that already exist in the Megatron trainer process. GPU tensors are shared + with the subprocess via CUDA IPC (handled transparently by + torch.multiprocessing), avoiding any GPU→CPU→GPU copies. """ def __init__(self, master_address: str, master_port: int, world_size: int, device: int): @@ -140,17 +136,17 @@ def __init__(self, master_address: str, master_port: int, world_size: int, devic def broadcast_tensors(self, tensors: list[torch.Tensor]) -> None: """Broadcast a list of tensors (one-by-one) via the bridge subprocess.""" - cpu_tensors = [t.cpu().contiguous() for t in tensors] - self._parent_conn.send({"op": "broadcast", "tensors": cpu_tensors}) + gpu_tensors = [t.contiguous() for t in tensors] + self._parent_conn.send({"op": "broadcast", "tensors": gpu_tensors}) self._wait_ok("broadcast_tensors") def send_weights_packed(self, named_tensors: list[tuple[str, torch.Tensor]]) -> None: """Send weights using vLLM's packed broadcast protocol.""" - cpu_pairs = [] + gpu_pairs = [] for name, t in named_tensors: data = t.data if hasattr(t, "data") else t - cpu_pairs.append((name, data.cpu().contiguous())) - self._parent_conn.send({"op": "send_packed", "named_tensors": cpu_pairs}) + gpu_pairs.append((name, data.contiguous())) + self._parent_conn.send({"op": "send_packed", "named_tensors": gpu_pairs}) 
self._wait_ok("send_weights_packed") def _wait_ok(self, label: str, timeout: float = 600.0) -> None: diff --git a/slime/backends/vllm_utils/vllm_engine.py b/slime/backends/vllm_utils/vllm_engine.py index 22ff897ee5..bf645cd0bd 100644 --- a/slime/backends/vllm_utils/vllm_engine.py +++ b/slime/backends/vllm_utils/vllm_engine.py @@ -42,12 +42,15 @@ def init(self, port=None, host=None, **kwargs): else: dev_str = ",".join(str(g) for g in gpu_ids) + seed = getattr(self.args, "seed", 1234) + self.rank cmd = [ "vllm", "serve", model, "--tensor-parallel-size", str(tp), "--port", str(self.server_port), "--host", "0.0.0.0", "--weight-transfer-config", '{"backend": "nccl"}', + "--seed", str(seed), + "--trust-remote-code", ] if getattr(self.args, "offload_rollout", False): cmd.append("--enable-sleep-mode") diff --git a/slime/rollout/backends/vllm_client.py b/slime/rollout/backends/vllm_client.py index 537569499d..32be01d66c 100644 --- a/slime/rollout/backends/vllm_client.py +++ b/slime/rollout/backends/vllm_client.py @@ -44,9 +44,11 @@ async def generate( "stop": sp.get("stop"), "stop_token_ids": sp.get("stop_token_ids"), "skip_special_tokens": sp.get("skip_special_tokens", True), + "spaces_between_special_tokens": sp.get("spaces_between_special_tokens", False), "logprobs": 1 if request.return_logprob else None, "include_stop_str_in_output": sp.get("no_stop_trim", False), "return_token_ids": True, + "seed": sp.get("sampling_seed"), } payload = {k: v for k, v in payload.items() if v is not None} From 546d2ad734a28e887457ce9e124398df150ded86 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Fri, 6 Mar 2026 00:19:19 +0800 Subject: [PATCH 09/10] Revise weight synchronization strategy in goal plan Reorder weight synchronization support for colocate and non-colocate scenarios in the goal plan. 
--- goal_plan.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/goal_plan.md b/goal_plan.md index d77618195b..f7b8827c8b 100644 --- a/goal_plan.md +++ b/goal_plan.md @@ -6,8 +6,8 @@ First Design and RFC by 03/06 - 对标SGLang,Slime 在 Ray 内管理 vLLM 的完整生命周期,包括进程拉起、权重同步、推理暂停/恢复 - 暂不使用Router,SGLang Model Gateway仅只支持SGLang Worker,SlimeRouter仅在 R3 / radix-tree caching 时需要,Qwen2.5-0.5B 非 MoE 且用 token-in/token-out - 单vLLM实例,无router,通过vLLMClient 直连本地 vLLM 进程端口 -- 先支持和验证colocate,权重同步采用GPU IPC(vLLM update_weights_from_ipc, update_weights_from_tensor),对标SGLang update_weights_from_tensor,以验证Reproductivity -- 再支持训推不共卡,权重同步采用NCCL broadcast,对标SGLang update_weights_from_distributed (默认) +- 先支持训推不共卡(non-colocate),权重同步采用NCCL broadcast,对标SGLang update_weights_from_distributed (默认) +- 再支持和验证colocate,权重同步采用GPU IPC(vLLM update_weights_from_ipc, update_weights_from_tensor),对标SGLang update_weights_from_tensor,以验证Reproductivity。**IPC 依赖vllm 0.17** #### 风险: - slime, sglang版本依赖,和vllm 0.16的版本依赖冲突(numpy, torch, transformers, etc) From d480da0da4ef1228b4f47b3bd5de88dafff6cffc Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Mon, 23 Mar 2026 15:55:22 +0700 Subject: [PATCH 10/10] Router for vllm (#5) * Draft router design Signed-off-by: knlnguyen1802 * Add vllm router Signed-off-by: knlnguyen1802 * Add router to script Signed-off-by: knlnguyen1802 * Fix gpu memory utilization Signed-off-by: knlnguyen1802 * Fix output token ids Signed-off-by: knlnguyen1802 * Add more nccl flag Signed-off-by: knlnguyen1802 * Fix bug Signed-off-by: knlnguyen1802 --------- Signed-off-by: knlnguyen1802 --- docs/en/vllm/ROUTER_DESIGN.md | 462 ++++++++++++++++++ run-qwen2.5-0.5B-vllm.sh | 4 + slime/backends/vllm_utils/__init__.py | 3 +- slime/backends/vllm_utils/vllm_engine.py | 130 ++++- .../vllm_utils/vllm_translation_sidecar.py | 454 +++++++++++++++++ slime/ray/rollout.py | 91 +++- slime/rollout/backends/vllm_client.py | 92 +++- slime/utils/arguments.py | 7 + 8 files changed, 1212 insertions(+), 
31 deletions(-) create mode 100644 docs/en/vllm/ROUTER_DESIGN.md create mode 100644 slime/backends/vllm_utils/vllm_translation_sidecar.py diff --git a/docs/en/vllm/ROUTER_DESIGN.md b/docs/en/vllm/ROUTER_DESIGN.md new file mode 100644 index 0000000000..c18a29c997 --- /dev/null +++ b/docs/en/vllm/ROUTER_DESIGN.md @@ -0,0 +1,462 @@ +# RFC: Replace SGLang Backend with vLLM — Router Integration + +--- + +## Summary + +Replace the SGLang inference backend behind **SlimeRouter** with **vLLM** while keeping the existing router and middleware stack completely unchanged. +This RFC covers **only the router layer** — what APIs the vLLM backend must expose, how the existing SlimeRouter is reused, and what translation is needed between the two formats. + +**Key design decision:** Reuse vLLM's built-in [OpenAI-compatible API server](https://docs.vllm.ai/en/stable/serving/openai_compatible_server/) (`vllm serve`) + + +--- + +## 1. Target Architecture + +``` + Rollout Workers SlimeRouter (NO CHANGE) vLLM Engines (NEW) + ────────────── ──────────────────────── ────────────────── + ┌──────────────────────┐ + POST /generate ──────────────────▶│ RadixTreeMiddleware │ + │ • prefix cache │ + │ • retry on abort │ + │ • token/logprob cache│ + └──────────┬───────────┘ + │ + ┌──────────▼───────────┐ + │ SlimeRouter.proxy() │ ┌─────────────────────┐ + │ • least-connections │────────▶│ vLLM Translation │ + │ load balancer │ │ Sidecar (per engine) │ + │ • health check loop │ │ │ + └──────────────────────┘ │ POST /generate │ + │ ↓ translate │ + │ POST /v1/completions │ + │ ↓ translate back │ + │ → SGLang-format JSON │ + └─────────┬───────────┘ + │ + ┌─────────▼───────────┐ + │ vLLM Server │ + │ (vllm serve) │ + │ • /v1/completions │ + │ • /health │ + │ • /sleep, /wake_up │ + │ • /pause, /resume │ + │ • /update_weights │ + └─────────────────────┘ +``` + +### What stays the same + +| Component | Change | Reason | +|---|---|---| +| `SlimeRouter` ([router.py](slime/router/router.py)) | **None** | 
Engine-agnostic HTTP proxy; only reads JSON responses | +| `RadixTreeMiddleware` ([radix_tree_middleware.py](slime/router/middleware_hub/radix_tree_middleware.py)) | **None** | Operates on request/response JSON; has no engine-specific code | +| `StringRadixTrie` ([radix_tree.py](slime/router/middleware_hub/radix_tree.py)) | **None** | Pure data structure, no engine coupling | +| Middleware loading (`--slime-router-middleware-paths`) | **None** | Dynamic import via `load_function()` | + +### What is new + +| Component | Description | +|---|---| +| `vllm_translation_sidecar.py` | Lightweight FastAPI process co-located with each vLLM engine. Receives SGLang-format `/generate` requests, translates to vLLM's `/v1/completions`, translates responses back. Also proxies lifecycle endpoints (`/abort_request`, `/health_generate`, etc.). | +| `vllm_engine.py` | Ray actor that manages the vLLM server process lifecycle (via `vllm serve`), the translation sidecar, weight updates, and registration with the router. | + +--- + +## 2. Reusing SlimeRouter — Zero Modification + +The SlimeRouter communicates with backends through **five interaction points**. All are already engine-agnostic: + +### 2.1 Worker Registration + +**Flow:** Engine starts → engine calls `POST /add_worker?url=http://{host}:{port}` → router adds to pool. + +``` +Router state after registration: + worker_request_counts["http://10.0.0.1:10090"] = 0 + worker_failure_counts["http://10.0.0.1:10090"] = 0 +``` + +**vLLM action:** The `VLLMEngine` Ray actor calls this endpoint after verifying the vLLM server + translation sidecar are healthy. The registered URL points to the **sidecar**, not the raw vLLM server. No router change needed. + +### 2.2 Request Proxying + +**Flow:** `POST /generate` → middleware pipeline → `SlimeRouter.proxy()` → `httpx` forwards to backend (sidecar). + +The router selects a backend via **least-connections** (`_use_url()`), forwards the raw request body as-is, and returns the response as-is. 
It never inspects or transforms the request/response payload. + +**vLLM action:** The sidecar receives the forwarded request, translates it to `/v1/completions`, calls the co-located vLLM server, translates the response back to SGLang format, and returns it. + +### 2.3 Health Check + +**Flow:** Background loop calls `GET {worker_url}/health` every N seconds. + +- 200 → healthy, reset failure count +- Non-200 or timeout → increment failure count +- Failures ≥ threshold (default 3) → quarantine worker permanently + +**vLLM action:** The sidecar's `/health` proxies to vLLM's built-in `/health` endpoint (returns 200 when ready). Compatible out of the box. + +### 2.4 Worker Listing + +**Flow:** `GET /list_workers` → returns `{"urls": [...]}` + +Used by the rollout to discover engines for direct abort calls. No engine involvement. + +### 2.5 Retrieve from Text (Radix Tree) + +**Flow:** `POST /retrieve_from_text` → router looks up the radix tree cache → returns tokens/logprobs. + +Fully router-internal. Never reaches the engine. + +--- + +## 3. API Contract — What the Translation Sidecar Must Expose + +The translation sidecar sits between SlimeRouter and the vLLM server. It receives SGLang-format requests and returns SGLang-format responses. + +### 3.1 `POST /generate` — Generation + +This is the primary endpoint. The sidecar translates between Slime's format and vLLM's `/v1/completions`. 
+ +#### Incoming Request (from router) + +```json +{ + "input_ids": [128000, 2610, 553, 264, 11190, 18328, 13], + "input_tokens": [128000, 2610, 553, 264, 11190, 18328, 13], + "sampling_params": { + "temperature": 0.7, + "top_p": 0.9, + "top_k": -1, + "max_new_tokens": 1024, + "stop": ["<|endoftext|>"], + "stop_token_ids": [128001], + "skip_special_tokens": false, + "no_stop_trim": true, + "spaces_between_special_tokens": false + }, + "return_logprob": true, + "stream": false +} +``` + +#### Translated Request (to vLLM `/v1/completions`) + +```json +{ + "model": "", + "prompt": [128000, 2610, 553, 264, 11190, 18328, 13], + "max_tokens": 1024, + "temperature": 0.7, + "top_p": 0.9, + "top_k": -1, + "stop": ["<|endoftext|>"], + "stop_token_ids": [128001], + "skip_special_tokens": false, + "include_stop_str_in_output": true, + "spaces_between_special_tokens": false, + "logprobs": 1, + "stream": false, + "extra_body": { + "return_token_ids": true + } +} +``` + +**Key translations:** +- `input_ids` → `prompt` (vLLM accepts `list[int]` as pre-tokenized prompt) +- `max_new_tokens` → `max_tokens` +- `no_stop_trim: true` → `include_stop_str_in_output: true` +- `return_logprob: true` → `logprobs: 1` + `extra_body.return_token_ids: true` + +#### vLLM Response (from `/v1/completions`) + +```json +{ + "id": "cmpl-abc123", + "choices": [{ + "text": "I'll help you with that. The answer is 42.", + "logprobs": { + "token_logprobs": [-0.152, -0.089, -0.203], + "tokens": ["I", "'ll", " help"] + }, + "token_ids": [40, 3358, 1520], + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10 + } +} +``` + +#### Translated Response (returned to router) + +```json +{ + "text": "I'll help you with that. 
The answer is 42.", + "output_ids": [40, 3358, 1520], + "meta_info": { + "output_token_logprobs": [ + [-0.152, 40], + [-0.089, 3358], + [-0.203, 1520] + ], + "finish_reason": { + "type": "stop" + }, + "weight_version": 3, + "prompt_tokens": 7, + "cached_tokens": 0 + } +} +``` + +##### Field-by-field contract + +| Field | Type | Required | Consumer | Description | +|---|---|---|---|---| +| `text` | `str` | **Yes** | Rollout, Middleware | Generated text (output only, not including prompt) | +| `output_ids` | `list[int]` | **Yes** | Middleware | Generated token IDs. Middleware checks existence as a gate for caching. | +| `meta_info.output_token_logprobs` | `list[[float, int]]` | **Yes** (if `return_logprob`) | Rollout, Middleware | Each element is `[logprob, token_id]`. Used for RL policy ratio calculation. | +| `meta_info.finish_reason` | `{"type": str}` | **Yes** | Rollout, Middleware | Must be `{"type": "stop"}`, `{"type": "length"}`, or `{"type": "abort"}`. **Not** a plain string. | +| `meta_info.weight_version` | `int` | **Yes** | Middleware, Rollout | Current model weight version. Tracked by the sidecar (incremented on each weight update). | +| `meta_info.prompt_tokens` | `int` | Nice-to-have | Rollout (stats) | From `usage.prompt_tokens`. | +| `meta_info.cached_tokens` | `int` | Nice-to-have | Rollout (stats) | vLLM doesn't expose this directly; default to `0`. | + +### 3.2 `GET /health` — Health Check + +``` +GET /health +→ Sidecar proxies to vLLM's GET /health +→ 200 OK (engine ready) +→ 503 or timeout (engine not ready / overloaded) +``` + +vLLM already provides this endpoint. **Passthrough — no translation needed.** + +### 3.3 `POST /abort_request` — Cancel Generation + +``` +POST /abort_request +Body: {"abort_all": true} +→ 200 OK +``` + +Called **directly** by the rollout to each engine (bypasses the router). The rollout discovers engine URLs via `GET /list_workers`, then sends abort to each. 
+ +**vLLM approach:** vLLM uses **HTTP connection close** for abort (via its `@with_cancellation` decorator). When a client disconnects, the in-flight request is automatically cancelled. + +**Implementation options:** +1. **Track active connections.** The sidecar maintains a set of active `httpx` connections to the vLLM server. On `POST /abort_request`, close all of them — triggering vLLM's cancellation. +2. **Use vLLM's `/pause` endpoint.** Call `POST /pause` to block new requests, then `POST /resume` after the RL training step completes. This is semantically closer to how Slime uses abort (clearing the decks between training generations). + +> **Note:** vLLM has `POST /abort_requests` only in disaggregated mode. For standard mode, HTTP disconnect is the canonical abort mechanism. + +### 3.4 `GET /health_generate` — Startup Readiness Probe + +``` +GET /health_generate +→ 200 OK (model loaded, engine ready for generation) +``` + +Called by `VLLMEngine.init()` during startup to block until the engine is fully ready. The sidecar implements this by calling vLLM's `GET /health` and optionally performing a dummy `/v1/completions` call with `max_tokens=1` to verify end-to-end readiness. + +### 3.5 Sampling Params Translation + +The request uses SGLang-format parameter names. 
The sidecar translates to vLLM's `/v1/completions` format: + +| SGLang field (in request) | vLLM `/v1/completions` field | Notes | +|---|---|---| +| `input_ids` | `prompt` | Direct — vLLM accepts `list[int]` as pre-tokenized prompt | +| `temperature` | `temperature` | Direct | +| `top_p` | `top_p` | Direct | +| `top_k` | `top_k` | Both use `-1` for disabled | +| `max_new_tokens` | `max_tokens` | **Name change** | +| `stop` | `stop` | Direct (list of strings) | +| `stop_token_ids` | `stop_token_ids` | Direct | +| `skip_special_tokens` | `skip_special_tokens` | Direct | +| `no_stop_trim` | `include_stop_str_in_output` | **Same semantics, different name** | +| `spaces_between_special_tokens` | `spaces_between_special_tokens` | Direct | +| `return_logprob` | `logprobs` (set to `1`) | Also add `extra_body.return_token_ids = true` | +| `sampling_seed` | `seed` | Optional | +| — | `model` | Must be set to the model name served by vLLM | + +### 3.6 Response Translation Pseudocode + +```python +def translate_vllm_response(vllm_resp: dict, weight_version: int) -> dict: + """Translate vLLM /v1/completions response to SGLang format.""" + choice = vllm_resp["choices"][0] + usage = vllm_resp.get("usage", {}) + + # Build output_token_logprobs: zip logprobs with token IDs + output_token_logprobs = None + if choice.get("logprobs") and choice.get("token_ids"): + output_token_logprobs = [ + [logprob, token_id] + for logprob, token_id in zip( + choice["logprobs"]["token_logprobs"], + choice["token_ids"] + ) + ] + + # Translate finish_reason: plain string → {"type": str} + raw_reason = choice.get("finish_reason") + finish_reason = {"type": raw_reason if raw_reason else "abort"} + + return { + "text": choice["text"], + "output_ids": choice.get("token_ids", []), + "meta_info": { + "output_token_logprobs": output_token_logprobs, + "finish_reason": finish_reason, + "weight_version": weight_version, + "prompt_tokens": usage.get("prompt_tokens", 0), + "cached_tokens": 0, + } + } +``` + +### 
3.7 `finish_reason` Translation Table + +| vLLM returns | Translate to | Notes | +|---|---|---| +| `"stop"` | `{"type": "stop"}` | Normal completion | +| `"length"` | `{"type": "length"}` | Hit `max_tokens` | +| `None` (aborted/incomplete) | `{"type": "abort"}` | Triggers middleware retry logic (sleep 30s, up to 5 retries) | + +--- + +## 4. Server Launch Configuration + +The `VLLMEngine` Ray actor should launch vLLM as follows: + +```bash +# Environment +export VLLM_SERVER_DEV_MODE=1 + +# Launch vLLM server +vllm serve \ + --host 0.0.0.0 \ + --port \ + --tensor-parallel-size \ + --enable-sleep-mode \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --disable-log-requests +``` + +The translation sidecar runs on a separate port (``) and is the URL registered with the router via `POST /add_worker?url=http://{host}:{sidecar_port}`. + +``` + Router + │ + ▼ + ┌─────────────────────────┐ + │ Translation Sidecar │ ◄── registered with router + │ port: sidecar_port │ + │ │ + │ /generate ──translate──▶│──┐ + │ /health ──passthrough──▶│ │ + │ /abort_request │ │ + │ /health_generate │ │ + └─────────────────────────┘ │ + │ + ┌─────────────────────────┐ │ + │ vLLM Server │◄─┘ + │ port: engine_port │ + │ │ + │ /v1/completions │ + │ /health │ + │ /sleep, /wake_up │ + │ /pause, /resume │ + │ /update_weights │ + │ /init_weight_transfer │ + └─────────────────────────┘ +``` + +--- + +## 5. 
Abort Strategy — Detailed Design
+
+vLLM's abort mechanism differs fundamentally from SGLang's:
+
+| Aspect | SGLang | vLLM |
+|---|---|---|
+| Abort granularity | Per-request via `POST /abort_request` with `rid` | Per-connection via HTTP disconnect |
+| Bulk abort | `{"abort_all": true}` | No built-in equivalent |
+| Mechanism | Engine tracks `request_id`, explicit `abort()` | `@with_cancellation` decorator; request cancelled when client disconnects |
+| Between-generation abort | Abort + restart | `POST /pause` → training → `POST /resume` |
+
+### Recommended implementation
+
+For the Slime RL use case, the rollout calls `abort_all` between generation rounds (to clear the engine before the next batch). The best vLLM equivalent is:
+
+```python
+# In the translation sidecar (`client` is the sidecar's shared httpx.AsyncClient;
+# note that module-level `httpx.post` is synchronous and must not be awaited)
+@app.post("/abort_request")
+async def abort_request(request: Request):
+    body = await request.json()
+    if body.get("abort_all"):
+        # Option 1: Close all tracked httpx connections → triggers vLLM cancellation
+        for conn in active_connections:
+            await conn.aclose()
+        active_connections.clear()
+
+        # Option 2: Use pause/resume (cleaner)
+        await client.post(f"{vllm_url}/pause")
+        await client.post(f"{vllm_url}/resume")
+
+    return {"status": "ok"}
+```
+
+---
+
+## 6. Endpoints Summary — Gap Analysis
+
+### Engine-side endpoints (vLLM built-in vs. 
needs implementation) + +| Endpoint | SGLang | vLLM Built-in | Action | +|---|---|---|---| +| `POST /v1/completions` | — | ✅ | **Reuse** — target for translation | +| `GET /health` | ✅ | ✅ | **Reuse** as-is (passthrough) | +| `POST /pause` | — | ✅ (dev mode) | **Reuse** for abort/weight-update | +| `POST /resume` | — | ✅ (dev mode) | **Reuse** for abort/weight-update | +| `POST /sleep` | — | ✅ (dev mode) | **Reuse** for weight updates | +| `POST /wake_up` | — | ✅ (dev mode) | **Reuse** for weight updates | +| `POST /collective_rpc` | — | ✅ (dev mode) | **Reuse** for weight reload | +| `GET /is_sleeping` | — | ✅ (dev mode) | **Reuse** for state checks | +| `POST /init_weight_transfer_engine` | — | ✅ (dev mode) | **Reuse** for NCCL setup | +| `POST /update_weights` | — | ✅ (dev mode) | **Reuse** for NCCL weight apply | +| `GET /get_world_size` | — | ✅ (dev mode) | **Reuse** for TP world size | + +### Translation sidecar endpoints (to implement) + +| Endpoint | Description | Complexity | +|---|---|---| +| `POST /generate` | Translate SGLang → `/v1/completions` → SGLang | **Medium** — main logic | +| `GET /health` | Proxy to vLLM `/health` | **Trivial** | +| `GET /health_generate` | Health + optional dummy completion | **Low** | +| `POST /abort_request` | Close connections or pause/resume | **Low** | +| `GET /flush_cache` | `POST /sleep?level=1` + `POST /wake_up?tags=kv_cache` | **Low** | +| `GET /get_weight_version` | Return sidecar-tracked version counter | **Trivial** | + +### Router endpoints (no change needed) + +| Endpoint | Action | +|---|---| +| `POST /add_worker` | No change | +| `GET /list_workers` | No change | +| `POST /retrieve_from_text` | No change | +| Catch-all proxy | No change | + +--- + + + + diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh index 208ec60c91..a315590808 100644 --- a/run-qwen2.5-0.5B-vllm.sh +++ b/run-qwen2.5-0.5B-vllm.sh @@ -95,6 +95,8 @@ VLLM_ARGS=( --rollout-backend vllm --rollout-num-gpus-per-engine 1 
--sglang-server-concurrency 512 + --use-slime-router + --slime-router-middleware-paths slime.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware ) MISC_ARGS=( @@ -116,6 +118,8 @@ ray job submit --address="http://127.0.0.1:8265" \ "NCCL_ALGO": "Ring", "NCCL_IB_DISABLE": "1", "NCCL_P2P_DISABLE": "1", + "NCCL_SHM_DISABLE": "1", + "NCCL_NET_GDR_LEVEL": "0", "NCCL_DEBUG": "INFO", "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0", "CUBLAS_WORKSPACE_CONFIG": ":4096:8" diff --git a/slime/backends/vllm_utils/__init__.py b/slime/backends/vllm_utils/__init__.py index ea4e7311c3..41b5f9909c 100644 --- a/slime/backends/vllm_utils/__init__.py +++ b/slime/backends/vllm_utils/__init__.py @@ -1,3 +1,4 @@ from slime.backends.vllm_utils.vllm_engine import VLLMEngine +from slime.backends.vllm_utils.vllm_translation_sidecar import TranslationSidecar, run_sidecar -__all__ = ["VLLMEngine"] +__all__ = ["VLLMEngine", "TranslationSidecar", "run_sidecar"] diff --git a/slime/backends/vllm_utils/vllm_engine.py b/slime/backends/vllm_utils/vllm_engine.py index bf645cd0bd..713793b07c 100644 --- a/slime/backends/vllm_utils/vllm_engine.py +++ b/slime/backends/vllm_utils/vllm_engine.py @@ -1,6 +1,7 @@ -"""VLLMEngine: Ray actor that launches and manages a vLLM server.""" +"""VLLMEngine: Ray actor that launches and manages a vLLM server + translation sidecar.""" import logging +import multiprocessing import os import subprocess import tempfile @@ -25,14 +26,32 @@ def __init__(self, args, rank: int, base_gpu_id: int | None = None, gpu_ids: lis self.gpu_ids = gpu_ids or [self.base_gpu_id] self.server_host = None self.server_port = None + self.sidecar_port = None self.process = None + self.sidecar_process = None self._log_file = None + self._sidecar_log_file = None + self._weight_version: int = 0 - def init(self, port=None, host=None, **kwargs): + @property + def sidecar_url(self) -> str: + """URL of the translation sidecar (registered with the router).""" + return 
f"http://{self.server_host}:{self.sidecar_port}" + + @property + def vllm_url(self) -> str: + """URL of the raw vLLM server.""" + return f"http://{self.server_host}:{self.server_port}" + + def init(self, port=None, host=None, router_ip=None, router_port=None, **kwargs): self.server_host = host or get_host_info()[1] self.server_port = port or get_free_port(15000) + self.sidecar_port = get_free_port(self.server_port + 100) + self.router_ip = router_ip or getattr(self.args, "sglang_router_ip", None) + self.router_port = router_port or getattr(self.args, "sglang_router_port", None) model = getattr(self.args, "vllm_model", None) or self.args.hf_checkpoint + self._model_name = model tp = self.args.rollout_num_gpus_per_engine gpu_ids = self.gpu_ids[:tp] cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "") @@ -52,6 +71,8 @@ def init(self, port=None, host=None, **kwargs): "--seed", str(seed), "--trust-remote-code", ] + gpu_mem_util = getattr(self.args, "vllm_gpu_memory_utilization", 0.4) + cmd.extend(["--gpu-memory-utilization", str(gpu_mem_util)]) if getattr(self.args, "offload_rollout", False): cmd.append("--enable-sleep-mode") if getattr(self.args, "vllm_enforce_eager", True): @@ -80,6 +101,14 @@ def init(self, port=None, host=None, **kwargs): ) self._wait_healthy() + # Launch the translation sidecar + self._launch_sidecar() + self._wait_sidecar_healthy() + + # Register the sidecar URL with the router + if self.router_ip and self.router_port: + self._register_with_router() + def _wait_healthy(self, timeout=300): base = f"http://{self.server_host}:{self.server_port}" start = time.time() @@ -98,6 +127,78 @@ def _wait_healthy(self, timeout=300): log_tail = self._read_log_tail() raise TimeoutError(f"vLLM server failed to become healthy within {timeout}s.\n{log_tail}") + def _launch_sidecar(self): + """Launch the translation sidecar as a subprocess.""" + from slime.backends.vllm_utils.vllm_translation_sidecar import run_sidecar + + self._sidecar_log_file = 
tempfile.NamedTemporaryFile( + prefix="vllm_sidecar_", suffix=".log", delete=False, mode="w" + ) + + def _target(): + run_sidecar( + vllm_host="127.0.0.1", + vllm_port=self.server_port, + sidecar_host="0.0.0.0", + sidecar_port=self.sidecar_port, + model_name=self._model_name, + log_level="info", + ) + + self.sidecar_process = multiprocessing.Process(target=_target, daemon=True) + self.sidecar_process.start() + logger.info( + "Launched translation sidecar on port %s (vLLM → %s:%s), log=%s", + self.sidecar_port, + self.server_host, + self.server_port, + self._sidecar_log_file.name, + ) + + def _wait_sidecar_healthy(self, timeout: float = 60.0): + """Block until the sidecar /health endpoint responds 200.""" + url = f"{self.sidecar_url}/health" + start = time.time() + while time.time() - start < timeout: + try: + r = requests.get(url, timeout=5) + if r.status_code == 200: + logger.info("Translation sidecar healthy at %s", self.sidecar_url) + return + except Exception: + pass + if self.sidecar_process and not self.sidecar_process.is_alive(): + raise RuntimeError( + f"Sidecar process exited with code {self.sidecar_process.exitcode}" + ) + time.sleep(1) + raise TimeoutError(f"Sidecar failed to become healthy within {timeout}s") + + def _register_with_router(self): + """Register the sidecar URL with the SlimeRouter.""" + router_url = f"http://{self.router_ip}:{self.router_port}" + response = requests.post( + f"{router_url}/add_worker", + params={"url": self.sidecar_url}, + ) + response.raise_for_status() + logger.info( + "Registered sidecar %s with router at %s", + self.sidecar_url, + router_url, + ) + + def _bump_weight_version(self, version: int | None = None): + """Notify the sidecar to increment (or set) its weight version counter.""" + url = f"{self.sidecar_url}/set_weight_version" + payload = {"weight_version": version} if version is not None else {} + try: + r = requests.post(url, json=payload, timeout=10) + r.raise_for_status() + self._weight_version = 
r.json().get("weight_version", self._weight_version) + except Exception as exc: + logger.warning("Failed to bump sidecar weight version: %s", exc) + def _read_log_tail(self, n=200): if not self._log_file: return "" @@ -186,6 +287,8 @@ def update_weights_from_distributed( "packed": packed, } }) + # Notify the sidecar about the new weight version + self._bump_weight_version(weight_version) def continue_generation(self): self._post("/resume") @@ -209,7 +312,14 @@ def resume_memory_occupation(self, tags: list[str] | None = None): logger.warning("vLLM wake_up failed: %s", e) def get_weight_version(self): - return None + if self.sidecar_port: + try: + r = requests.get(f"{self.sidecar_url}/get_weight_version", timeout=5) + r.raise_for_status() + return r.json().get("weight_version", self._weight_version) + except Exception: + pass + return self._weight_version def check_weights(self, action: str): pass @@ -218,6 +328,20 @@ def post_process_weights(self, **kwargs): pass def shutdown(self): + # Shutdown translation sidecar first + if self.sidecar_process and self.sidecar_process.is_alive(): + self.sidecar_process.terminate() + self.sidecar_process.join(timeout=10) + if self.sidecar_process.is_alive(): + self.sidecar_process.kill() + self.sidecar_process = None + if self._sidecar_log_file: + try: + self._sidecar_log_file.close() + except Exception: + pass + + # Shutdown vLLM server if self.process: self.process.terminate() try: diff --git a/slime/backends/vllm_utils/vllm_translation_sidecar.py b/slime/backends/vllm_utils/vllm_translation_sidecar.py new file mode 100644 index 0000000000..73998d8bf4 --- /dev/null +++ b/slime/backends/vllm_utils/vllm_translation_sidecar.py @@ -0,0 +1,454 @@ +""" +vLLM Translation Sidecar +======================== + +Lightweight FastAPI process co-located with each vLLM engine. +Receives SGLang-format ``/generate`` requests from the SlimeRouter, +translates them to vLLM ``/v1/completions``, and translates responses back. 
+ +Also proxies lifecycle endpoints: + /health, /health_generate, /abort_request, /flush_cache, /get_weight_version + +See docs/en/vllm/ROUTER_DESIGN.md for the full specification. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import signal +import sys +from contextlib import asynccontextmanager +from typing import Any + +import httpx +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Sampling-param translation tables +# --------------------------------------------------------------------------- + +# SGLang name → vLLM /v1/completions name (only entries that differ) +_PARAM_RENAME = { + "max_new_tokens": "max_tokens", + "no_stop_trim": "include_stop_str_in_output", + "sampling_seed": "seed", +} + +# Parameters passed through unchanged +_PARAM_DIRECT = frozenset( + { + "temperature", + "top_p", + "top_k", + "stop", + "stop_token_ids", + "skip_special_tokens", + "spaces_between_special_tokens", + } +) + +# finish_reason vLLM → SGLang-style {"type": ...} +_FINISH_REASON_MAP = { + "stop": "stop", + "length": "length", + None: "abort", +} + + +# --------------------------------------------------------------------------- +# Request / response translation helpers +# --------------------------------------------------------------------------- + + +def translate_generate_request( + body: dict[str, Any], + model_name: str, +) -> dict[str, Any]: + """Translate an SGLang-format /generate request → vLLM /v1/completions payload.""" + + sp: dict = body.get("sampling_params", {}) + + vllm_payload: dict[str, Any] = { + "model": model_name, + # vLLM accepts list[int] as a pre-tokenized prompt + "prompt": body.get("input_ids") or body.get("input_tokens", []), + "stream": False, + } + + # --- direct-copy params --- + for key in _PARAM_DIRECT: + if key in sp: + 
vllm_payload[key] = sp[key] + + # --- renamed params --- + for src, dst in _PARAM_RENAME.items(): + if src in sp: + vllm_payload[dst] = sp[src] + + # --- logprob handling --- + if body.get("return_logprob", False): + vllm_payload["logprobs"] = 1 + # request token IDs alongside logprobs + # NOTE: must be a top-level param; "extra_body" is an OpenAI SDK + # client concept and is ignored by vLLM's raw HTTP API. + vllm_payload["return_token_ids"] = True + + return vllm_payload + + +def translate_vllm_response( + vllm_resp: dict[str, Any], + weight_version: int, +) -> dict[str, Any]: + """Translate a vLLM /v1/completions response → SGLang-format JSON.""" + + choice: dict = vllm_resp.get("choices", [{}])[0] + usage: dict = vllm_resp.get("usage", {}) + + # --- output token IDs --- + output_ids: list[int] = choice.get("token_ids", []) + + # --- logprobs: zip(logprob, token_id) --- + output_token_logprobs: list[list[float | int]] = [] + logprobs_obj = choice.get("logprobs") + if logprobs_obj and output_ids: + raw_lp: list[float | None] = logprobs_obj.get("token_logprobs", []) + output_token_logprobs = [ + [float(lp) if lp is not None else 0.0, tid] + for lp, tid in zip(raw_lp, output_ids) + ] + + # --- finish reason --- + raw_reason = choice.get("finish_reason") + mapped = _FINISH_REASON_MAP.get(raw_reason, raw_reason or "abort") + finish_reason = {"type": mapped} + + meta_info: dict[str, Any] = { + "finish_reason": finish_reason, + "weight_version": weight_version, + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", len(output_ids)), + "cached_tokens": 0, + } + # Only include output_token_logprobs when we have valid paired data; + # a None value causes RadixTreeMiddleware to silently fail when iterating. 
+ if output_token_logprobs: + meta_info["output_token_logprobs"] = output_token_logprobs + + return { + "text": choice.get("text", ""), + "output_ids": output_ids, + "meta_info": meta_info, + } + + +# --------------------------------------------------------------------------- +# Sidecar application +# --------------------------------------------------------------------------- + + +class TranslationSidecar: + """Manages state and provides the FastAPI app for the translation sidecar.""" + + def __init__( + self, + vllm_base_url: str, + model_name: str, + *, + timeout: float = 600.0, + max_connections: int = 256, + ): + self.vllm_base_url = vllm_base_url.rstrip("/") + self.model_name = model_name + self._weight_version: int = 0 + self._active_connections: set[httpx.Response] = set() + self._lock = asyncio.Lock() + + self._client: httpx.AsyncClient | None = None + self._timeout = timeout + self._max_connections = max_connections + + self.app = self._build_app() + + # ---- lifecycle ----------------------------------------------------------- + + async def startup(self): + self._client = httpx.AsyncClient( + limits=httpx.Limits( + max_connections=self._max_connections, + max_keepalive_connections=self._max_connections, + ), + timeout=httpx.Timeout(self._timeout), + ) + + async def shutdown(self): + if self._client: + await self._client.aclose() + self._client = None + + # ---- app factory --------------------------------------------------------- + + def _build_app(self) -> FastAPI: + + @asynccontextmanager + async def lifespan(app: FastAPI): + await self.startup() + yield + await self.shutdown() + + app = FastAPI(title="vLLM Translation Sidecar", lifespan=lifespan) + + app.post("/generate")(self.generate) + app.get("/health")(self.health) + app.get("/health_generate")(self.health_generate) + app.post("/abort_request")(self.abort_request) + app.get("/flush_cache")(self.flush_cache) + app.get("/get_weight_version")(self.get_weight_version) + 
app.post("/set_weight_version")(self.set_weight_version) + + return app + + # ---- endpoints ----------------------------------------------------------- + + async def generate(self, request: Request): + """Translate SGLang /generate → vLLM /v1/completions → SGLang response.""" + + body = await request.json() + vllm_payload = translate_generate_request(body, self.model_name) + + url = f"{self.vllm_base_url}/v1/completions" + + resp: httpx.Response | None = None + try: + async with self._lock: + # We don't actually hold the lock during the request, + # just use it to safely add to the tracking set. + pass + + resp = await self._client.post(url, json=vllm_payload) + self._active_connections.add(resp) + + resp.raise_for_status() + vllm_data = resp.json() + except httpx.HTTPStatusError as exc: + logger.error( + "vLLM /v1/completions returned %s: %s", + exc.response.status_code, + exc.response.text[:2000], + ) + # Return an abort-style response so the middleware retries + return JSONResponse( + content={ + "text": "", + "output_ids": [], + "meta_info": { + "finish_reason": {"type": "abort"}, + "weight_version": self._weight_version, + "prompt_tokens": 0, + "completion_tokens": 0, + "cached_tokens": 0, + }, + }, + status_code=200, + ) + except Exception as exc: + logger.error("Error calling vLLM: %s", exc, exc_info=True) + return JSONResponse( + content={ + "text": "", + "output_ids": [], + "meta_info": { + "finish_reason": {"type": "abort"}, + "weight_version": self._weight_version, + "prompt_tokens": 0, + "completion_tokens": 0, + "cached_tokens": 0, + }, + }, + status_code=200, + ) + finally: + if resp is not None: + self._active_connections.discard(resp) + await resp.aclose() + + translated = translate_vllm_response(vllm_data, self._weight_version) + return JSONResponse(content=translated) + + async def health(self): + """Proxy to vLLM's built-in /health endpoint.""" + try: + resp = await self._client.get(f"{self.vllm_base_url}/health", timeout=5.0) + return 
JSONResponse(content={"status": "ok"}, status_code=resp.status_code) + except Exception: + return JSONResponse(content={"status": "unhealthy"}, status_code=503) + + async def health_generate(self): + """ + Startup readiness probe. + + Checks vLLM /health and optionally fires a dummy /v1/completions + with max_tokens=1 to verify end-to-end readiness. + """ + try: + resp = await self._client.get(f"{self.vllm_base_url}/health", timeout=10.0) + if resp.status_code != 200: + return JSONResponse(content={"status": "not_ready"}, status_code=503) + + # Lightweight smoke test: single-token completion + dummy_payload = { + "model": self.model_name, + "prompt": "hi", + "max_tokens": 1, + "stream": False, + } + resp2 = await self._client.post( + f"{self.vllm_base_url}/v1/completions", + json=dummy_payload, + timeout=30.0, + ) + if resp2.status_code == 200: + return JSONResponse(content={"status": "ready"}, status_code=200) + else: + return JSONResponse(content={"status": "not_ready"}, status_code=503) + except Exception as exc: + logger.debug("health_generate check failed: %s", exc) + return JSONResponse(content={"status": "not_ready"}, status_code=503) + + async def abort_request(self, request: Request): + """ + Handle abort requests. + + vLLM uses HTTP disconnect for cancellation. We close all active + connections to the vLLM backend, which triggers vLLM's + ``@with_cancellation`` decorator to abort in-flight requests. + + Alternatively, use pause/resume for a cleaner between-generation abort. 
+ """ + body = await request.json() + + if body.get("abort_all", False): + # Strategy: pause + resume clears the pipeline + try: + await self._client.post( + f"{self.vllm_base_url}/pause", + params={"mode": "abort"}, + timeout=30.0, + ) + await self._client.post( + f"{self.vllm_base_url}/resume", + timeout=30.0, + ) + except Exception as exc: + logger.warning("pause/resume abort failed, falling back to connection close: %s", exc) + # Fallback: close all tracked connections + conns = list(self._active_connections) + self._active_connections.clear() + for conn in conns: + try: + await conn.aclose() + except Exception: + pass + + return JSONResponse(content={"status": "ok"}) + + async def flush_cache(self): + """ + Flush the KV cache. + + vLLM equivalent: sleep(level=1) + wake_up(tags=kv_cache). + """ + try: + await self._client.post( + f"{self.vllm_base_url}/sleep", + params={"level": "1", "mode": "abort"}, + timeout=30.0, + ) + await self._client.post( + f"{self.vllm_base_url}/wake_up", + params={"tags": "kv_cache"}, + timeout=30.0, + ) + return JSONResponse(content={"status": "ok"}) + except Exception as exc: + logger.warning("flush_cache failed: %s", exc) + return JSONResponse(content={"status": "ok"}) + + async def get_weight_version(self): + """Return the sidecar-tracked weight version counter.""" + return JSONResponse(content={"weight_version": self._weight_version}) + + async def set_weight_version(self, request: Request): + """Increment or set the weight version (called by VLLMEngine after weight update).""" + body = await request.json() + if "weight_version" in body: + self._weight_version = int(body["weight_version"]) + else: + self._weight_version += 1 + return JSONResponse(content={"weight_version": self._weight_version}) + + +# --------------------------------------------------------------------------- +# Standalone entry point +# --------------------------------------------------------------------------- + + +def run_sidecar( + vllm_host: str = 
"127.0.0.1", + vllm_port: int = 8000, + sidecar_host: str = "0.0.0.0", + sidecar_port: int = 8100, + model_name: str = "default", + timeout: float = 600.0, + max_connections: int = 256, + log_level: str = "info", +): + """Launch the translation sidecar as a standalone uvicorn process.""" + + vllm_base_url = f"http://{vllm_host}:{vllm_port}" + sidecar = TranslationSidecar( + vllm_base_url=vllm_base_url, + model_name=model_name, + timeout=timeout, + max_connections=max_connections, + ) + uvicorn.run( + sidecar.app, + host=sidecar_host, + port=sidecar_port, + log_level=log_level, + ) + + +def main(): + parser = argparse.ArgumentParser(description="vLLM Translation Sidecar") + parser.add_argument("--vllm-host", type=str, default="127.0.0.1") + parser.add_argument("--vllm-port", type=int, default=8000) + parser.add_argument("--sidecar-host", type=str, default="0.0.0.0") + parser.add_argument("--sidecar-port", type=int, default=8100) + parser.add_argument("--model-name", type=str, default="default") + parser.add_argument("--timeout", type=float, default=600.0) + parser.add_argument("--max-connections", type=int, default=256) + parser.add_argument("--log-level", type=str, default="info") + args = parser.parse_args() + + run_sidecar( + vllm_host=args.vllm_host, + vllm_port=args.vllm_port, + sidecar_host=args.sidecar_host, + sidecar_port=args.sidecar_port, + model_name=args.model_name, + timeout=args.timeout, + max_connections=args.max_connections, + log_level=args.log_level, + ) + + +if __name__ == "__main__": + main() diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ba18fb9c47..9d1a786ae0 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -1022,46 +1022,87 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool def _start_vllm_rollout_servers(args, pg) -> dict[str, RolloutServer]: - """Start vLLM rollout server (single instance, no router).""" + """Start vLLM rollout server(s) with optional SlimeRouter. 
+ + When ``args.use_slime_router`` is True, a SlimeRouter is launched first + and each vLLM engine registers its translation sidecar with the router. + Otherwise, a single engine is started and used directly (no router). + """ pg_obj, reordered_bundle_indices, reordered_gpu_ids = pg tp = args.rollout_num_gpus_per_engine - gpu_ids = [int(reordered_gpu_ids[i]) for i in range(tp)] - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=pg_obj, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=reordered_bundle_indices[0], - ) + num_gpu_per_engine_local = min(tp, args.num_gpus_per_node) + num_engines = max(1, args.rollout_num_gpus // num_gpu_per_engine_local) - env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} + use_router = getattr(args, "use_slime_router", False) + + # --- Launch the SlimeRouter if requested --- + if use_router: + router_ip, router_port = _start_router(args, has_pd_disaggregation=False) + args.sglang_router_ip = router_ip + args.sglang_router_port = router_port + else: + router_ip = None + router_port = None + env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} VLLMRayActor = ray.remote(VLLMEngine) - engine = VLLMRayActor.options( - num_cpus=0.2, - num_gpus=0.2, - scheduling_strategy=scheduling_strategy, - runtime_env={"env_vars": env_vars}, - ).remote(args, rank=0, base_gpu_id=gpu_ids[0], gpu_ids=gpu_ids) - - host, port = ray.get(engine._get_current_node_ip_and_free_port.remote(start_port=15000)) - ray.get(engine.init.remote(port=port, host=host)) - args.vllm_base_url = f"http://{host}:{port}" - args.sglang_router_ip = host - args.sglang_router_port = port + + engines = [] + init_handles = [] + for i in range(num_engines): + gpu_index = i * num_gpu_per_engine_local + gpu_ids = [int(reordered_gpu_ids[gpu_index + j]) for j in range(num_gpu_per_engine_local)] + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg_obj, + 
placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[gpu_index], + ) + + engine = VLLMRayActor.options( + num_cpus=0.2, + num_gpus=0.2, + scheduling_strategy=scheduling_strategy, + runtime_env={"env_vars": env_vars}, + ).remote(args, rank=i, base_gpu_id=gpu_ids[0], gpu_ids=gpu_ids) + + host, port = ray.get(engine._get_current_node_ip_and_free_port.remote(start_port=15000 + i * 10)) + + init_handles.append( + engine.init.remote( + port=port, + host=host, + router_ip=router_ip, + router_port=router_port, + ) + ) + engines.append(engine) + + # Use the first engine's host for args if no router + if i == 0 and not use_router: + args.vllm_base_url = f"http://{host}:{port}" + args.sglang_router_ip = host + args.sglang_router_port = port + + # Wait for all engines to be healthy + registered with router + ray.get(init_handles) + + first_host = args.sglang_router_ip + first_port = args.sglang_router_port group = EngineGroup( args=args, pg=pg, - all_engines=[engine], + all_engines=engines, num_gpus_per_engine=args.rollout_num_gpus_per_engine, - num_new_engines=1, + num_new_engines=num_engines, worker_type="regular", rank_offset=0, gpu_offset=0, sglang_overrides={}, - router_ip=host, - router_port=port, + router_ip=first_host, + router_port=first_port, ) - return {"default": RolloutServer(engine_groups=[group], router_ip=host, router_port=port, model_name="default")} + return {"default": RolloutServer(engine_groups=[group], router_ip=first_host, router_port=first_port, model_name="default")} def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: diff --git a/slime/rollout/backends/vllm_client.py b/slime/rollout/backends/vllm_client.py index 32be01d66c..d5e2ed6bdf 100644 --- a/slime/rollout/backends/vllm_client.py +++ b/slime/rollout/backends/vllm_client.py @@ -2,7 +2,7 @@ from slime.rollout.backends.base_client import BackendCapabilities, RolloutBackendClient from slime.rollout.base_types import RolloutBackendRequest, 
RolloutBackendResponse -from slime.utils.http_utils import post +from slime.utils.http_utils import get, post logger = logging.getLogger(__name__) @@ -16,14 +16,27 @@ class VLLMClient(RolloutBackendClient): + """Rollout backend client for vLLM. + + Supports two modes: + + 1. **Direct mode** (default): sends requests directly to vLLM's + ``/v1/completions`` endpoint and parses the native vLLM response. + 2. **Router mode** (``use_slime_router=True``): sends SGLang-format + requests to ``/generate`` through the SlimeRouter, which forwards + to the translation sidecar. The sidecar handles the vLLM translation + and returns SGLang-format responses. + """ + def __init__(self, args): self.args = args self._max_retries = getattr(args, "vllm_max_retries", 3) + self._use_router = getattr(args, "use_slime_router", False) @property def capabilities(self) -> BackendCapabilities: return BackendCapabilities( - supports_abort=False, + supports_abort=self._use_router, # abort is supported through the sidecar supports_routed_experts=False, supports_prompt_logprobs=False, ) @@ -33,6 +46,69 @@ async def generate( request: RolloutBackendRequest, base_url: str, headers: dict | None = None, + ) -> RolloutBackendResponse: + if self._use_router: + return await self._generate_via_router(request, base_url, headers) + return await self._generate_direct(request, base_url, headers) + + # ------------------------------------------------------------------ + # Router mode: SGLang-format /generate → sidecar → vLLM + # ------------------------------------------------------------------ + + async def _generate_via_router( + self, + request: RolloutBackendRequest, + base_url: str, + headers: dict | None = None, + ) -> RolloutBackendResponse: + """Send SGLang-format request through the SlimeRouter → sidecar pipeline.""" + + payload = { + "input_ids": request.input_ids, + "sampling_params": request.sampling_params, + "return_logprob": request.return_logprob, + "stream": False, + } + if 
request.image_data: + payload["image_data"] = request.image_data + + url = f"{base_url.rstrip('/')}/generate" + output = await post(url, payload, headers=headers) + + # Parse SGLang-format response (produced by translation sidecar) + meta = output.get("meta_info", {}) + logprobs_data = meta.get("output_token_logprobs", []) + + # output_token_logprobs is list of [logprob, token_id] + output_token_ids = [item[1] for item in logprobs_data] if logprobs_data else [] + output_token_logprobs = [item[0] for item in logprobs_data] if logprobs_data else [] + + # Fall back to output_ids if logprobs not available + if not output_token_ids: + output_token_ids = output.get("output_ids") or [] + + finish_reason = meta.get("finish_reason", {}).get("type", "stop") + + return RolloutBackendResponse( + text=output.get("text", ""), + output_token_ids=output_token_ids, + output_token_logprobs=output_token_logprobs, + finish_reason=finish_reason, + prompt_tokens=meta.get("prompt_tokens", len(request.input_ids)), + completion_tokens=len(output_token_ids), + backend_raw=output, + routed_experts=None, + ) + + # ------------------------------------------------------------------ + # Direct mode: vLLM /v1/completions (no router) + # ------------------------------------------------------------------ + + async def _generate_direct( + self, + request: RolloutBackendRequest, + base_url: str, + headers: dict | None = None, ) -> RolloutBackendResponse: sp = request.sampling_params payload = { @@ -91,3 +167,15 @@ async def generate( backend_raw=output, routed_experts=None, ) + + # ------------------------------------------------------------------ + # Abort (router mode only) + # ------------------------------------------------------------------ + + async def abort(self) -> list[str]: + """Return worker URLs for abort. 
Only works in router mode.""" + if not self._use_router: + return [] + base = f"http://{self.args.sglang_router_ip}:{self.args.sglang_router_port}" + r = await get(f"{base}/list_workers") + return r.get("urls", []) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 7ef8e70b71..4bb5d3506d 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -263,6 +263,13 @@ def add_rollout_arguments(parser): default=True, help="Use vLLM packed weight transfer for non-colocate (default: True). Disable for per-bucket mode.", ) + parser.add_argument( + "--vllm-gpu-memory-utilization", + type=float, + default=0.4, + help="Fraction of GPU memory for vLLM KV cache (default: 0.4). " + "Lower this if the training model leaves insufficient free memory.", + ) parser.add_argument( "--rollout-function-path", type=str,