diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index ab10502eec..22e3fd814d 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -171,7 +171,7 @@ OPTIMIZER_ARGS=( These are the parameters required by sglang. Here, `--rollout-num-gpus-per-engine` basically corresponds to sglang's `tp_size`. Other sglang parameters are passed to slime by adding a `--sglang-` prefix. To fully leverage sglang's large EP inference capabilities, we have added configurations like ep64, dp\_attention dp8, and deepep mode auto. -The final `--sglang-server-concurrency` is a parameter specific to slime. It is used to prevent the sglang server's concurrent requests from becoming too large and crashing the HTTP server. The default is 512. However, since we now have one server for 8 nodes, we have adjusted it to 1024 to ensure that each dp rank can have a concurrency of 128. +The final `--rollout-server-concurrency` is a parameter specific to slime. It is used to prevent the sglang server's concurrent requests from becoming too large and crashing the HTTP server. The default is 512. However, since we now have one server for 8 nodes, we have adjusted it to 1024 to ensure that each dp rank can have a concurrency of 128. ```bash SGLANG_ARGS=( @@ -190,7 +190,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank have 128 concurrency - --sglang-server-concurrency 1024 + --rollout-server-concurrency 1024 ) ``` diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 56711878f6..688b30b03c 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -280,10 +280,10 @@ In this case, 2 GPUs will be allocated for training, and 6 GPUs will be allocate ⚠️ If the concurrency on each sglang server is too high, it may exceed sglang's default CUDA graph concurrency limit (the default maximum is 160), which will affect inference speed. You can adjust this in the following two ways: -1. 
Use `--sglang-server-concurrency` to limit the maximum number of concurrent requests sent to a single sglang server. For example: +1. Use `--rollout-server-concurrency` to limit the maximum number of concurrent requests sent to a single sglang server. For example: ```bash - --sglang-server-concurrency 160 + --rollout-server-concurrency 160 ``` 2. Use `--sglang-cuda-graph-bs` (which corresponds to sglang's native `--cuda-graph-bs` argument) to increase the number of CUDA graphs initialized by sglang. For example: diff --git a/docs/zh/examples/deepseek-r1.md b/docs/zh/examples/deepseek-r1.md index 368653844d..703a7730e4 100644 --- a/docs/zh/examples/deepseek-r1.md +++ b/docs/zh/examples/deepseek-r1.md @@ -171,7 +171,7 @@ OPTIMIZER_ARGS=( sglang 所需的参数,这里 `--rollout-num-gpus-per-engine` 基本对应 sglang 的 `tp_size`,除此之外的 sglang 参数均通过添加 `--sglang-` 的前缀来传给 slime。为了充分利用 sglang 的大 EP 推理能力,我们加上了 ep64、dp_attention dp8、deepep mode auto 等配置。 -最后的 `--sglang-server-concurrency` 是 slime 的特有参数,是为了方式同时发给 sglang server 的并发太大打爆 http server,默认为 512。但是我们现在是 8 机一个 server,为了保证每个 dp rank 能有 128 的并发,我们调整为 1024。 +最后的 `--rollout-server-concurrency` 是 slime 的特有参数,是为了方式同时发给 sglang server 的并发太大打爆 http server,默认为 512。但是我们现在是 8 机一个 server,为了保证每个 dp rank 能有 128 的并发,我们调整为 1024。 ```bash SGLANG_ARGS=( @@ -190,7 +190,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --rollout-server-concurrency 1024 ) ``` diff --git a/docs/zh/examples/qwen3-4B.md b/docs/zh/examples/qwen3-4B.md index 2afc70bfbe..1f3adc9c51 100644 --- a/docs/zh/examples/qwen3-4B.md +++ b/docs/zh/examples/qwen3-4B.md @@ -280,10 +280,10 @@ ray job submit ... \ ⚠️ 在进行训推分离的时候,每个 sglang server 上的并发度太大,超过了 sglang 默认的 cuda graph 的并发度(默认最大 160),影响推理速度。可以用以下 2 种方式进行调整: -1. 通过 `--sglang-server-concurrency` 限制发给一个 sglang server 的最大并发量,例如: +1. 
通过 `--rollout-server-concurrency` 限制发给一个 sglang server 的最大并发量,例如: ```bash - --sglang-server-concurrency 160 + --rollout-server-concurrency 160 ``` 2. 使用 `--sglang-cuda-graph-bs`,即 sglang 原生的 `--cuda-graph-bs`, 增大 sglang 初始化的 cuda graph 数量,例如: diff --git a/examples/fully_async/fully_async_rollout.py b/examples/fully_async/fully_async_rollout.py index 7208365c18..1df005ed02 100644 --- a/examples/fully_async/fully_async_rollout.py +++ b/examples/fully_async/fully_async_rollout.py @@ -20,7 +20,7 @@ def get_global_worker(args, data_buffer): with _worker_lock: if _global_worker is None or not _global_worker.worker_thread.is_alive(): print("Creating new global async worker...") - _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.sglang_server_concurrency) + _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.rollout_server_concurrency) _global_worker.start() return _global_worker diff --git a/examples/tau-bench/run_qwen3_4B.sh b/examples/tau-bench/run_qwen3_4B.sh index a821734012..fa36af5526 100644 --- a/examples/tau-bench/run_qwen3_4B.sh +++ b/examples/tau-bench/run_qwen3_4B.sh @@ -100,7 +100,7 @@ SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 --sglang-mem-fraction-static 0.7 # If gemini API reports concurrency limit error, set this parameter to reduce the concurrency - # --sglang-server-concurrency 32 + # --rollout-server-concurrency 32 ) MISC_ARGS=( diff --git a/rfc-rollout-backend-separation-plan.md b/rfc-rollout-backend-separation-plan.md new file mode 100644 index 0000000000..ef715c2a6d --- /dev/null +++ b/rfc-rollout-backend-separation-plan.md @@ -0,0 +1,271 @@ +# RFC: Rollout Separation Plan (EngineGroup Generalization + Executor Cleanup) + +## 1. Summary + +This RFC proposes backend separation with minimal churn in runtime orchestration. + +Scope is four items: + +1. 
Refactor [slime/ray/rollout.py](slime/ray/rollout.py) by generalizing `EngineGroup.start_engines()` and abstracting engine/server creation (no new runtime manager class hierarchy). +2. Refactor [slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) in-place: extract the one SGLang-specific code path (RadixTree) into a strategy hook, rename SGLang-prefixed args to generic names. No class hierarchy, no new files. +3. Refactor [slime/utils/arguments.py](slime/utils/arguments.py) into shared args + backend arg groups/finalizers. +4. Decouple [slime/backends/fsdp_utils/update_weight_utils.py](slime/backends/fsdp_utils/update_weight_utils.py) from SGLang internals so FSDP weight sync works with both SGLang and vLLM engines. + +## 2. Already Done (Reuse, Do Not Rewrite) + +- Unified rollout contracts in [slime/rollout/base_types.py](slime/rollout/base_types.py). +- Backend client abstraction in [slime/rollout/backends/base_client.py](slime/rollout/backends/base_client.py). +- Backend adapters in [slime/rollout/backends/sglang_client.py](slime/rollout/backends/sglang_client.py) and [slime/rollout/backends/vllm_client.py](slime/rollout/backends/vllm_client.py). +- Managed vLLM engine actor in [slime/backends/vllm_utils/vllm_engine.py](slime/backends/vllm_utils/vllm_engine.py). +- vLLM translation sidecar in [slime/backends/vllm_utils/vllm_translation_sidecar.py](slime/backends/vllm_utils/vllm_translation_sidecar.py). + +## 3. Problem + +- [slime/ray/rollout.py](slime/ray/rollout.py) mixes shared and backend-specific engine creation paths. +- [slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) is ~95% backend-agnostic but has one inlined SGLang-specific code path (RadixTree, ~14 lines) and uses SGLang-prefixed arg names for generic rollout concepts. +- [slime/utils/arguments.py](slime/utils/arguments.py) still has SGLang alias behavior in vLLM path. 
+- [slime/backends/fsdp_utils/update_weight_utils.py](slime/backends/fsdp_utils/update_weight_utils.py) hard-imports SGLang internals (`FlattenedTensorBucket`, `MultiprocessingSerializer`, `monkey_patch_torch_reductions`) and calls SGLang-specific engine RPC names (`update_weights_from_tensor`, `update_weights_from_distributed`), making FSDP weight sync unusable with vLLM engines. + +## 4. Goals and Non-Goals + +### Goals + +- Keep runtime refactor minimal and localized to `EngineGroup` + creation abstraction. +- Isolate the one SGLang-specific executor code path behind a strategy hook; keep functions as functions. +- Rename SGLang-prefixed arg names to generic rollout names to eliminate naming coupling. +- Reduce backend leakage in argument finalization. +- Preserve current external behavior, call sites, and import paths. + +### Non-Goals + +- No rewrite of `SGLangEngine` or `VLLMEngine` internals. +- No algorithmic changes to GRPO/PPO. +- No mandatory feature parity for unsupported backend capabilities. + +## 5. Design + +### 5.1 Runtime: Generalize `EngineGroup` (No New Runtime Manager Classes) + +Keep [slime/ray/rollout.py](slime/ray/rollout.py) as the orchestration entry file. + +Refactor focus: + +1. Generalize `EngineGroup.start_engines()` to call backend-aware creation hooks. +2. Abstract engine creation and rollout-server assembly helpers. +3. Keep existing startup function API (`start_rollout_servers`) and return shape. + +Proposed helper abstraction points: + +- `create_engine_actor_cls(args, worker_type)` + - returns `ray.remote(SGLangEngine)` or `ray.remote(VLLMEngine)`. +- `create_engine_remote(args, actor_cls, scheduling_strategy, ...)` + - encapsulates `.options(...).remote(...)` with backend-specific init kwargs. +- `build_rollout_server(...)` + - standardizes `RolloutServer` construction from engine groups. + +`EngineGroup` and `RolloutServer` remain the main shared dataclasses. 
+ +### 5.2 Executor: Isolate Backend Logic In-Place (No Class Hierarchy) + +#### Current state analysis + +[slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) (577 lines) is **already ~95% backend-agnostic**: + +| Function / Class | Lines | Backend-specific? | Notes | +|---|---|---|---| +| `_get_backend_client()` | 6 | Factory only | Delegates to existing `RolloutBackendClient` subclasses | +| `_apply_backend_response()` | 20 | No | Uses `RolloutBackendResponse` contract | +| `GenerateState` | 37 | **Naming only** | References `sglang_server_concurrency`, `sglang_dp_size`, `sglang_enable_deterministic_inference` — all are generic rollout concepts with SGLang-prefixed names | +| `generate()` | 58 | **14 lines** | RadixTree middleware path (L170-183) is 100% SGLang-specific; the else branch (L185-195) already uses `RolloutBackendClient` | +| `generate_and_rm()` | 60 | No | Shared orchestration (semaphore, custom func, reward) | +| `generate_and_rm_group()` | 37 | No | Group parallelism + deterministic seeds | +| `abort()` | 38 | No | Already uses `backend.abort()` | +| `generate_rollout_async()` | 71 | No | Main loop, filtering, metrics | +| `eval_rollout()` / `eval_rollout_single_dataset()` | 118 | **Naming only** | `sglang_enable_deterministic_inference` reference | +| `generate_rollout()` | 41 | No | Sync entry point | + +**Conclusion**: the actual backend logic that needs isolation is **one code path** (~14 lines) inside `generate()`. Everything else is either already abstracted through `RolloutBackendClient` or is a naming-only coupling (SGLang-prefixed arg names for generic concepts). + +#### Approach + +1. **Extract the RadixTree path into a strategy hook** that `generate()` calls conditionally. +2. **Rename SGLang-prefixed args** to generic names (coordinated with Phase 3 args refactor). +3. **Keep functions as functions** — they compose well and callers (`train.py`, OPD, multi-agent) import them directly. 
+ +#### Concrete changes + +**Step 1 — Extract RadixTree strategy from `generate()`** + +Current `generate()` has an `if use_radix: ... else: backend.generate(...)` branch. +Refactor into: + +```python +# slime/rollout/sglang_rollout.py — generate() simplified + +async def generate(args, sample, sampling_params): + ... + input_ids = ... # shared prompt encoding (unchanged) + + strategy = _get_generate_strategy(args) + resp = await strategy(args, sample, input_ids, sampling_params) + _apply_backend_response(sample, resp, args) + return sample +``` + +Two strategies: + +```python +# Still in sglang_rollout.py (no new file needed) + +def _get_generate_strategy(args): + """Return the generate coroutine to use.""" + if _is_radix_tree_enabled(args): + return _generate_radix_tree # SGLang-only path + return _generate_via_backend_client # Generic path (SGLang or vLLM) + +def _is_radix_tree_enabled(args) -> bool: + return ( + args.use_slime_router + and "RadixTreeMiddleware" in getattr(args, "slime_router_middleware_paths", []) + ) + +async def _generate_radix_tree(args, sample, input_ids, sampling_params) -> RolloutBackendResponse: + """SGLang RadixTree middleware path — returns normalized response.""" + from slime.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + url = f"http://{args.rollout_router_ip}:{args.rollout_router_port}/generate" + payload = { ... } # existing payload construction + output = await post(url, payload, headers=headers) + sample = await postprocess_sample_with_radix_tree(args, sample, output) + return _extract_response_from_sample(sample) # normalize to RolloutBackendResponse + +async def _generate_via_backend_client(args, sample, input_ids, sampling_params) -> RolloutBackendResponse: + """Generic backend client path — works for SGLang and vLLM.""" + backend = _get_backend_client(args) + base_url = f"http://{args.rollout_router_ip}:{args.rollout_router_port}" + req = RolloutBackendRequest(...) 
return await backend.generate(req, base_url, headers=headers) +``` + +**Step 2 — Rename SGLang-prefixed args to generic names** + +| Current name | New name | Reason | +|---|---|---| +| `sglang_server_concurrency` | `rollout_server_concurrency` | Controls request parallelism for any backend | +| `sglang_dp_size` | `rollout_dp_size` | Data-parallel sharding, not SGLang-specific | +| `sglang_router_ip` / `sglang_router_port` | `rollout_router_ip` / `rollout_router_port` | Router endpoint, backend-agnostic | +| `sglang_router_policy` | `rollout_router_policy` | Routing strategy | +| `sglang_enable_deterministic_inference` | `rollout_deterministic_inference` | Seed-based determinism | +| `vllm_base_url` | (remove) | Folded into `rollout_router_ip:port`, no special case | + +Legacy aliases kept in [slime/utils/arguments.py](slime/utils/arguments.py) for one release cycle (coordinated with Phase 3). + +**Step 3 — No new files** + +The file stays as [slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) during this phase. +Optionally rename to `slime/rollout/rollout.py` in Phase 4 cleanup, since the file is backend-agnostic after the refactor. + +### 5.3 Arguments/Config Refactor + +In [slime/utils/arguments.py](slime/utils/arguments.py), split into: + +1. Shared rollout args. +2. SGLang backend args/validation. +3. vLLM backend args/validation. + +Add backend finalizers: + +- `finalize_sglang_args(args)` +- `finalize_vllm_args(args)` + +Move SGLang alias fallback out of shared finalize flow. 
+ +### 5.4 Weight Sync: Decouple FSDP `update_weight_utils.py` from SGLang Internals + +#### Current state analysis + +[slime/backends/fsdp_utils/update_weight_utils.py](slime/backends/fsdp_utils/update_weight_utils.py) (287 lines) has two concrete classes: + +| Class | Weight-push method | SGLang coupling | +|---|---|---| +| `UpdateWeightFromTensor` | IPC via Gloo gather → `engine.update_weights_from_tensor.remote()` | Imports `FlattenedTensorBucket`, `MultiprocessingSerializer`, `monkey_patch_torch_reductions` directly from `sglang.srt.*` | +| `UpdateWeightFromDistributed` | NCCL broadcast → `engine.update_weights_from_distributed.remote()` | Calls `engine.init_weights_update_group.remote()` — SGLang engine API | + +The abstract base `UpdateWeight` itself is clean (only PyTorch + Ray). + +The Megatron side already solved this: [slime/backends/megatron_utils/sglang.py](slime/backends/megatron_utils/sglang.py) centralizes all SGLang imports into one shim. The FSDP side duplicates these imports inline. + +#### Coupling points + +1. **SGLang utility imports (L13-26)** — `monkey_patch_torch_reductions`, `MultiprocessingSerializer`, `FlattenedTensorBucket` are imported directly from `sglang.srt.*` with try/except version fallbacks. +2. **`UpdateWeightFromTensor.update_bucket_weights()`** — uses `FlattenedTensorBucket` to flatten tensors, `MultiprocessingSerializer` to serialize, then calls `engine.update_weights_from_tensor.remote()`. +3. **`UpdateWeightFromDistributed.connect_rollout_engines()`** — calls `engine.init_weights_update_group.remote()` which is an SGLang engine method. +4. **`UpdateWeightFromDistributed.update_bucket_weights()`** — calls `engine.update_weights_from_distributed.remote()` which is an SGLang engine method. + +All four points assume the engine actor exposes SGLang's RPC interface. vLLM engines expose different method names. 
+ +#### Approach + +**Step 1 — Centralize SGLang imports via a shim (same pattern as Megatron)** + +Reuse or mirror the existing [slime/backends/megatron_utils/sglang.py](slime/backends/megatron_utils/sglang.py) pattern: + +```python +# slime/backends/fsdp_utils/sglang_compat.py (new, ~15 lines) +try: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions +except ImportError: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + +from sglang.srt.utils import MultiprocessingSerializer + +try: + from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket +except ImportError: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket +``` + +Then `update_weight_utils.py` imports from `sglang_compat` — one import line instead of five, and only loaded when SGLang is the active backend. + +**Step 2 — Abstract engine RPC calls behind a protocol** + +The engine actors (`SGLangEngine`, `VLLMEngine`) already expose weight-sync methods but with different names/signatures. 
Introduce a lightweight protocol or adapter: + +```python +# In update_weight_utils.py or a small helper + +def _call_engine_update_tensor(engine, backend: str, **kwargs): + """Dispatch IPC weight update to the correct engine method.""" + if backend == "vllm": + return engine.update_weights_from_tensor.remote(**kwargs) # vLLM uses same name via NcclBridge + return engine.update_weights_from_tensor.remote(**kwargs) # SGLang native + +def _call_engine_update_distributed(engine, backend: str, **kwargs): + if backend == "vllm": + return engine.update_weights_from_distributed.remote(**kwargs) + return engine.update_weights_from_distributed.remote(**kwargs) + +def _call_engine_init_weight_group(engine, backend: str, **kwargs): + if backend == "vllm": + return engine.init_weights_update_group.remote(**kwargs) + return engine.init_weights_update_group.remote(**kwargs) +``` + +> Note: Currently `VLLMEngine` already mirrors these method names (it wraps them via `NcclBridge`), so the dispatch functions may initially be identical. The value is making the indirection explicit so future method-name divergence is handled in one place. + +**Step 3 — Lazy-import SGLang utilities only when backend is SGLang** + +Move the `FlattenedTensorBucket` / `MultiprocessingSerializer` imports inside `UpdateWeightFromTensor.update_bucket_weights()` behind a lazy import, so the module can be loaded in a vLLM-only environment without SGLang installed. + +#### What changes + +| File | Change | +|---|---| +| `slime/backends/fsdp_utils/sglang_compat.py` | New shim file (~15 lines) centralizing SGLang imports | +| `slime/backends/fsdp_utils/update_weight_utils.py` | Replace 5 inline SGLang imports with one `from .sglang_compat import ...`; add backend-aware engine dispatch helpers | + +#### What does NOT change + +- `UpdateWeight` abstract base class — already clean. +- `UpdateWeightFromDistributed` NCCL logic — the broadcast itself is pure PyTorch; only the engine RPC dispatch gets a thin wrapper. 
+- Megatron-side weight sync — already has its own shim, not touched. + diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh index a315590808..b352fa37d0 100644 --- a/run-qwen2.5-0.5B-vllm.sh +++ b/run-qwen2.5-0.5B-vllm.sh @@ -94,7 +94,7 @@ WANDB_ARGS=( VLLM_ARGS=( --rollout-backend vllm --rollout-num-gpus-per-engine 1 - --sglang-server-concurrency 512 + --rollout-server-concurrency 512 --use-slime-router --slime-router-middleware-paths slime.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware ) diff --git a/scripts/low_precision/run-kimi-k2-Thinking-int4.sh b/scripts/low_precision/run-kimi-k2-Thinking-int4.sh index c41ea3df82..16bc65eaee 100644 --- a/scripts/low_precision/run-kimi-k2-Thinking-int4.sh +++ b/scripts/low_precision/run-kimi-k2-Thinking-int4.sh @@ -134,7 +134,7 @@ SGLANG_ARGS=( #--sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --rollout-server-concurrency 1024 --use-slime-router ) diff --git a/scripts/run-deepseek-r1.sh b/scripts/run-deepseek-r1.sh index c307d110ec..dc6d6cb744 100644 --- a/scripts/run-deepseek-r1.sh +++ b/scripts/run-deepseek-r1.sh @@ -126,7 +126,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --rollout-server-concurrency 1024 ) MISC_ARGS=( diff --git a/scripts/run-kimi-k2-Instruct.sh b/scripts/run-kimi-k2-Instruct.sh index 3a591b923a..0c9fc2c3d4 100644 --- a/scripts/run-kimi-k2-Instruct.sh +++ b/scripts/run-kimi-k2-Instruct.sh @@ -132,7 +132,7 @@ SGLANG_ARGS=( # --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --rollout-server-concurrency 1024 ) diff --git a/scripts/run-kimi-k2-Thinking.sh b/scripts/run-kimi-k2-Thinking.sh index 25cc3c475f..533019d697 100644 --- a/scripts/run-kimi-k2-Thinking.sh +++ b/scripts/run-kimi-k2-Thinking.sh @@ -134,7 +134,7 @@ SGLANG_ARGS=( # --sglang-deepep-mode auto # make every dp rank has 128 
concurrency - --sglang-server-concurrency 1024 + --rollout-server-concurrency 1024 ) diff --git a/slime/backends/fsdp_utils/update_weight_utils.py b/slime/backends/fsdp_utils/update_weight_utils.py index 75e217beb9..64a6770c06 100644 --- a/slime/backends/fsdp_utils/update_weight_utils.py +++ b/slime/backends/fsdp_utils/update_weight_utils.py @@ -10,23 +10,31 @@ from ray.actor import ActorHandle from torch.distributed.tensor import DTensor, Replicate -try: - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions # type: ignore[import] -except ImportError: - from sglang.srt.patch_torch import monkey_patch_torch_reductions # type: ignore[import] +from slime.utils.distributed_utils import init_process_group -from sglang.srt.utils import MultiprocessingSerializer -from slime.utils.distributed_utils import init_process_group +logger = logging.getLogger(__name__) -try: - from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] -except ImportError: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +def _import_sglang_weight_sync_utils(): + """Lazy-import SGLang serialization utilities. + Centralizes the try/except version fallbacks so callers get a clean tuple. + Raises ImportError if sglang is not installed at all. 
+ """ + try: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions # type: ignore[import] + except ImportError: + from sglang.srt.patch_torch import monkey_patch_torch_reductions # type: ignore[import] -logger = logging.getLogger(__name__) + from sglang.srt.utils import MultiprocessingSerializer + + try: + from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] + except ImportError: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] + + return monkey_patch_torch_reductions, MultiprocessingSerializer, FlattenedTensorBucket class UpdateWeight(abc.ABC): @@ -135,6 +143,10 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: if self._ipc_gather_group is None: return + monkey_patch_torch_reductions, MultiprocessingSerializer, FlattenedTensorBucket = ( + _import_sglang_weight_sync_utils() + ) + monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") diff --git a/slime/backends/megatron_utils/arguments.py b/slime/backends/megatron_utils/arguments.py index 3e6e7a0d19..eb4fad6e60 100644 --- a/slime/backends/megatron_utils/arguments.py +++ b/slime/backends/megatron_utils/arguments.py @@ -50,6 +50,12 @@ def equal(x, y): ("rope_theta", "rotary_base", equal), ]: if hasattr(hf_config, hf_config_name): + if not hasattr(args, megatron_config_name): + logger.warning( + f"Megatron args missing '{megatron_config_name}' (mapped from HF '{hf_config_name}') , " + f"Skip validate" + ) + continue if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): errors.append( f"{hf_config_name} in hf config {getattr(hf_config, hf_config_name)} is not equal to " diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py index 31db8b0da8..cc7da48133 100644 --- a/slime/backends/megatron_utils/model_provider.py +++ 
b/slime/backends/megatron_utils/model_provider.py @@ -209,12 +209,16 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage def wrap_model_provider_with_freeze(original_provider, args): - def wrapped_provider(pre_process=True, post_process=True, vp_stage=None): + def wrapped_provider(pre_process=True, post_process=True, vp_stage=None, **kwargs): sig = inspect.signature(original_provider) + call_kwargs = {"pre_process": pre_process, "post_process": post_process} if "vp_stage" in sig.parameters: - model = original_provider(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) - else: - model = original_provider(pre_process=pre_process, post_process=post_process) + call_kwargs["vp_stage"] = vp_stage + # Forward any extra kwargs (e.g. config) accepted by the provider + for k, v in kwargs.items(): + if k in sig.parameters: + call_kwargs[k] = v + model = original_provider(**call_kwargs) freeze_model_params(model, args) diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 97c82a31cd..4c045c4817 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -1,4 +1,9 @@ # the file to manage all sglang deps in the megatron actor +# When sglang is installed we prefer its implementations; otherwise we +# fall back to API-compatible reimplementations in +# slime.backends.megatron_utils.weight_sync_utils. 
+ +# ── FP8 quantisation helpers (sglang-only, no local fallback) ─────── try: from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0 from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 @@ -7,19 +12,29 @@ transform_scale_ue8m0 = None should_deepgemm_weight_requant_ue8m0 = None +# ── monkey_patch_torch_reductions ─────────────────────────────────── try: from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions except ImportError: - from sglang.srt.patch_torch import monkey_patch_torch_reductions - - -from sglang.srt.utils import MultiprocessingSerializer + try: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + except ImportError: + from .weight_sync_utils import monkey_patch_torch_reductions +# ── MultiprocessingSerializer ─────────────────────────────────────── +try: + from sglang.srt.utils import MultiprocessingSerializer +except ImportError: + from .weight_sync_utils import MultiprocessingSerializer +# ── FlattenedTensorBucket ─────────────────────────────────────────── try: from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] except ImportError: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] + except ImportError: + from .weight_sync_utils import FlattenedTensorBucket __all__ = [ "quant_weight_ue8m0", diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index a2e4e129bc..529bd793d6 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -12,6 +12,38 @@ from slime.utils.types import ParamInfo +def _cat_partitions( + partitions: list[torch.Tensor], partition_dim: int, stride: int +) -> torch.Tensor: + """Concatenate TP-gathered 
partitions, reversing Megatron-Core's strided layout. + + When ``stride == 1`` each rank holds a single contiguous shard and a plain + ``torch.cat`` along *partition_dim* is sufficient. + + When ``stride > 1``, Megatron-Core's ``_initialize_affine_weight_cpu`` + splits the master weight into ``world_size * stride`` equal sub-chunks and + assigns rank *r* the sub-chunks at indices ``r, r + world_size, + r + 2 * world_size, …`` (i.e. ``weight_list[rank::world_size]``). Each + rank therefore holds ``stride`` sub-chunks concatenated together. + + To reconstruct the original full tensor we split each rank's shard back + into ``stride`` sub-chunks and interleave them: + ``for i in range(stride): for r in range(world_size): append chunk[r][i]`` + """ + if stride == 1: + return torch.cat(partitions, dim=partition_dim) + + world_size = len(partitions) + # Each rank holds exactly `stride` sub-chunks concatenated along partition_dim. + rank_chunks = [p.chunk(stride, dim=partition_dim) for p in partitions] + # Reconstruct original ordering: chunk index (i * world_size + r) + ordered: list[torch.Tensor] = [] + for i in range(stride): + for r in range(world_size): + ordered.append(rank_chunks[r][i]) + return torch.cat(ordered, dim=partition_dim) + + def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: """ All-gather TP-sharded param to full tensor. expert_bias→param, non-TP/duplicated→param.data. @@ -34,17 +66,20 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] dist.all_gather(param_partitions, param.data, group=tp_group) partition_dim = param.partition_dim - assert param.partition_stride == 1, "partition_stride != 1 is not supported" + stride = getattr(param, "partition_stride", 1) # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? # TODO: check only GLU is used. 
if "linear_fc1.weight" in name: param_partitions = [p.chunk(2, dim=0) for p in param_partitions] param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] + # GLU rechunking already reverses the stride=2 interleaving, so use + # plain concatenation from here on. + stride = 1 # this is bug in megatron's grouped moe. if "linear_fc2.weight" in name: if partition_dim == 0: partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + param = _cat_partitions(param_partitions, partition_dim, stride) return param @@ -63,10 +98,10 @@ def all_gather_params_async( for info, param in param_infos_and_params: # Prepare async all_gather if "expert_bias" in info.name: - gather_tasks.append((info, param, None, None, None)) + gather_tasks.append((info, param, None, None, None, 1)) handles.append(None) elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated": - gather_tasks.append((info, param.data, None, None, None)) + gather_tasks.append((info, param.data, None, None, None, 1)) handles.append(None) else: # Start async all_gather @@ -79,7 +114,8 @@ def all_gather_params_async( param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] handle = dist.all_gather(param_partitions, param.data, group=tp_group, async_op=True) - gather_tasks.append((info, None, handle, param_partitions, param.partition_dim)) + partition_stride = getattr(param, "partition_stride", 1) + gather_tasks.append((info, None, handle, param_partitions, param.partition_dim, partition_stride)) handles.append(handle) # Phase 2: Wait for ALL async operations to complete at once @@ -90,23 +126,24 @@ def all_gather_params_async( # Phase 3: Process all results after all communications are done gathered_params = [] - for info, direct_param, handle, param_partitions, partition_dim in gather_tasks: + for info, direct_param, handle, param_partitions, partition_dim, partition_stride in gather_tasks: if handle is None: # No all_gather 
needed param = direct_param else: # Process the gathered partitions (same logic as original all_gather_param) - assert partition_dim is not None, "partition_stride != 1 is not supported" # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? # TODO: check only GLU is used. if "linear_fc1.weight" in info.name: param_partitions = [p.chunk(2, dim=0) for p in param_partitions] param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] + # GLU rechunking already reverses the stride=2 interleaving. + partition_stride = 1 # this is bug in megatron's grouped moe. if "linear_fc2.weight" in info.name: if partition_dim == 0: partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + param = _cat_partitions(param_partitions, partition_dim, partition_stride) gathered_params.append(param) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 03d480801f..e827fac6ba 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -83,14 +83,18 @@ def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, c elif op == "send_packed": from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine, ) - NCCLWeightTransferEngine.trainer_send_weights( - iterator=iter(cmd["named_tensors"]), + trainer_args = NCCLTrainerSendWeightsArgs( group=comm, packed=True, ) + NCCLWeightTransferEngine.trainer_send_weights( + iterator=iter(cmd["named_tensors"]), + trainer_args=trainer_args, + ) torch.cuda.synchronize() conn.send("ok") diff --git a/slime/backends/megatron_utils/weight_sync_utils.py b/slime/backends/megatron_utils/weight_sync_utils.py new file mode 100644 index 0000000000..5fef12e58f --- /dev/null +++ 
b/slime/backends/megatron_utils/weight_sync_utils.py @@ -0,0 +1,268 @@ +""" +Local reimplementations of sglang weight-sync utilities. + +Used as fallback when sglang is not installed (e.g. vLLM-only mode). +The three classes/functions here are API-compatible with their sglang +counterparts so that the rest of the megatron weight-update code can +work unchanged. + +Origin (sglang): + - FlattenedTensorBucket → sglang.srt.weight_sync.tensor_bucket + - MultiprocessingSerializer / SafeUnpickler → sglang.srt.utils.common + - monkey_patch_torch_reductions → sglang.srt.utils.patch_torch +""" + +from __future__ import annotations + +import base64 +import io +import pickle +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import Callable, Union + +import torch +from torch.multiprocessing import reductions + +# ── FlattenedTensorBucket ─────────────────────────────────────────── + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket.""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single uint8 tensor + for efficient serialisation, while preserving all metadata needed + for reconstruction. + + API-compatible with ``sglang.srt.weight_sync.tensor_bucket.FlattenedTensorBucket``. + """ + + # Checked by callers to decide whether to group tensors by dtype. 
+ supports_multi_dtypes = True + + def __init__( + self, + named_tensors: list[tuple[str, torch.Tensor]] | None = None, + flattened_tensor: torch.Tensor | None = None, + metadata: list[FlattenedTensorMetadata] | None = None, + ): + if named_tensors is not None: + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + self.metadata: list[FlattenedTensorMetadata] = [None] * len(named_tensors) + current_idx = 0 + flat_parts: list[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flat = tensor.flatten().view(torch.uint8) + numel = flat.numel() + flat_parts[i] = flat + self.metadata[i] = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + current_idx += numel + + self.flattened_tensor: torch.Tensor = torch.cat(flat_parts, dim=0) + else: + if flattened_tensor is None or metadata is None: + raise ValueError( + "Must provide either named_tensors or both flattened_tensor and metadata" + ) + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Return the single flat uint8 tensor.""" + return self.flattened_tensor + + def get_metadata(self) -> list[FlattenedTensorMetadata]: + """Return per-tensor metadata list.""" + return self.metadata + + def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]: + """Reconstruct the original named tensors from the flat representation.""" + reconstructed = [None] * len(self.metadata) + for i, meta in enumerate(self.metadata): + tensor = ( + self.flattened_tensor[meta.start_idx : meta.end_idx] + .view(meta.dtype) + .reshape(meta.shape) + ) + reconstructed[i] = (meta.name, tensor) + return reconstructed + + +# ── SafeUnpickler / MultiprocessingSerializer ─────────────────────── + + +class SafeUnpickler(pickle.Unpickler): + """ + Unpickler with an allow-list to prevent arbitrary code 
execution. + + API-compatible with the ``SafeUnpickler`` in ``sglang.srt.utils.common``. + """ + + ALLOWED_MODULE_PREFIXES = { + # Python builtins + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # PyTorch + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # torch.distributed + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # multiprocessing + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # HuggingFace / PEFT + "peft.", + "transformers.", + "huggingface_hub.", + # slime local reimplementation + "slime.backends.megatron_utils.weight_sync_utils.", + # sglang (if installed alongside) + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + # NPU + "torch_npu.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module: str, name: str): + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): + return super().find_class(module, name) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + + +class MultiprocessingSerializer: + """ + Serialize / deserialize Python objects via ``ForkingPickler`` so that + CUDA tensors are transferred through shared memory (IPC handles). 
+ + API-compatible with ``sglang.srt.utils.common.MultiprocessingSerializer``. + + Uses stdlib ``base64`` instead of ``pybase64`` to avoid adding a dependency. + """ + + @staticmethod + def serialize(obj, output_str: bool = False): + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + if output_str: + output = base64.b64encode(output).decode("utf-8") + return output + + @staticmethod + def deserialize(data): + if isinstance(data, str): + data = base64.b64decode(data, validate=True) + return SafeUnpickler(io.BytesIO(data)).load() + + +# ── monkey_patch_torch_reductions ─────────────────────────────────── + +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise RuntimeError("Invalid device_uuid=" + device_maybe_uuid) + raise RuntimeError(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return (*t[:index], modifier(t[index]), *t[index + 1 :]) + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def monkey_patch_torch_reductions(): + """ + Monkey-patch ``torch.multiprocessing.reductions`` so that CUDA tensors + are identified by device UUID rather than ordinal index. 
+ + This works around https://github.com/pytorch/pytorch/pull/149248. + + API-compatible with ``sglang.srt.utils.patch_torch.monkey_patch_torch_reductions``. + """ + if hasattr(reductions, "_reduce_tensor_original"): + return # already patched + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + reductions.init_reductions() diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 71543e02d3..9185f53365 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -42,7 +42,6 @@ def add_sglang_arguments(parser): """ parser = add_sglang_router_arguments(parser) parser.set_defaults(router_balance_abs_threshold=10, router_balance_rel_threshold=1.2) - parser.add_argument("--sglang-server-concurrency", type=int, default=512) old_add_argument = parser.add_argument diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 9d1a786ae0..149780e188 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -13,9 +13,13 @@ import torch import yaml from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS -from slime.backends.sglang_utils.sglang_engine import SGLangEngine +# GPU memory-tag constants (originally from sglang.srt.constants). +# Duplicated here to avoid a hard sglang dependency when running the vLLM backend. 
+GPU_MEMORY_TYPE_KV_CACHE = "kv_cache" +GPU_MEMORY_TYPE_WEIGHTS = "weights" +GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph" + from slime.backends.vllm_utils.vllm_engine import VLLMEngine from slime.rollout.base_types import call_rollout_fn from slime.utils import logging_utils @@ -234,6 +238,8 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis pg, reordered_bundle_indices, reordered_gpu_ids = self.pg + from slime.backends.sglang_utils.sglang_engine import SGLangEngine + RolloutRayActor = ray.remote(SGLangEngine) rollout_engines = [] @@ -967,14 +973,14 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool ``force_new`` is False, skip launching and return the existing values. When ``force_new`` is True (multi-model), always allocate a fresh port. """ - if not force_new and args.sglang_router_ip is not None: + if not force_new and getattr(args, "sglang_router_ip", None) is not None: return args.sglang_router_ip, args.sglang_router_port router_ip = _wrap_ipv6(get_host_info()[1]) if force_new: router_port = find_available_port(random.randint(3000, 4000)) else: - router_port = args.sglang_router_port + router_port = getattr(args, "sglang_router_port", None) if router_port is None: router_port = find_available_port(random.randint(3000, 4000)) @@ -989,7 +995,7 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool router_args.sglang_router_port = router_port else: - from sglang_router.launch_router import RouterArgs + from sglang_router.launch_router import RouterArgs # noqa: delayed import — only needed for sglang router from slime.utils.http_utils import run_router diff --git a/slime/rollout/backends/__init__.py b/slime/rollout/backends/__init__.py index 49c4a86eef..ea9ccd8a9c 100644 --- a/slime/rollout/backends/__init__.py +++ b/slime/rollout/backends/__init__.py @@ -1,7 +1,15 @@ from slime.rollout.backends.base_client import BackendCapabilities, RolloutBackendClient -from 
slime.rollout.backends.sglang_client import SGLangClient from slime.rollout.backends.vllm_client import VLLMClient + +def __getattr__(name): + if name == "SGLangClient": + from slime.rollout.backends.sglang_client import SGLangClient + + return SGLangClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "BackendCapabilities", "RolloutBackendClient", diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 0b59be4477..f3aec2d55b 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -76,7 +76,7 @@ def __init__(self, args: Namespace) -> None: self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) self.semaphore = asyncio.Semaphore( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + args.rollout_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) self.sampling_params: dict[str, Any] = dict( temperature=args.rollout_temperature, diff --git a/slime/router/router.py b/slime/router/router.py index 669094e442..b77dd975f4 100644 --- a/slime/router/router.py +++ b/slime/router/router.py @@ -45,7 +45,7 @@ def __init__(self, args, verbose=False): max_connections = getattr(args, "slime_router_max_connections", None) if max_connections is None: max_connections = ( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + args.rollout_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) timeout = getattr(args, "slime_router_timeout", None) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 4bb5d3506d..b7abc68b98 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -5,10 +5,7 @@ from typing import Any import yaml -from sglang_router.launch_router import RouterArgs -from slime.backends.sglang_utils.arguments import sglang_parse_args -from slime.backends.sglang_utils.arguments import 
validate_args as sglang_validate_args from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -270,6 +267,13 @@ def add_rollout_arguments(parser): help="Fraction of GPU memory for vLLM KV cache (default: 0.85). " "Lower this if the training model leaves insufficient free memory.", ) + parser.add_argument( + "--rollout-server-concurrency", + type=int, + default=512, + help="Maximum number of concurrent requests sent to the rollout server. " + "Controls request parallelism for any backend (sglang or vllm).", + ) parser.add_argument( "--rollout-function-path", type=str, @@ -1069,7 +1073,12 @@ def add_router_arguments(parser): default=3, help="Number of consecutive failures before marking a worker as unhealthy.", ) - RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) + try: + from sglang_router.launch_router import RouterArgs + + RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) + except ImportError: + pass # sglang not installed — router args unavailable (vllm-only mode) return parser # wandb @@ -1463,6 +1472,7 @@ def _pre_parse_mode(): temp_parser.add_argument("--debug-rollout-only", action="store_true", default=False) temp_parser.add_argument("--debug-train-only", action="store_true", default=False) temp_parser.add_argument("--load-debug-rollout-data", type=str, default=None) + temp_parser.add_argument("--rollout-backend", type=str, choices=["sglang", "vllm"], default="sglang") temp_args, _ = temp_parser.parse_known_args() return temp_args @@ -1474,12 +1484,18 @@ def parse_args(add_custom_arguments=None): add_slime_arguments = get_slime_extra_args_provider(add_custom_arguments) pre = _pre_parse_mode() - skip_sglang = pre.debug_train_only or pre.load_debug_rollout_data is not None + skip_sglang = ( + pre.debug_train_only + or pre.load_debug_rollout_data is not None + or pre.rollout_backend == "vllm" + ) # 
Phase 1: Parse sglang args independently (separate parser, parse_known_args). - # Skipped when sglang servers are not needed. + # Skipped when sglang servers are not needed or when using vLLM backend. sglang_ns = None if not skip_sglang: + from slime.backends.sglang_utils.arguments import sglang_parse_args + sglang_ns = sglang_parse_args() # Phase 2: Parse megatron/fsdp + slime args. @@ -1517,6 +1533,8 @@ def parse_args(add_custom_arguments=None): megatron_validate_args(args) if not args.debug_train_only and getattr(args, "rollout_backend", "sglang") == "sglang": + from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args + sglang_validate_args(args) elif getattr(args, "rollout_backend", "sglang") == "vllm": # Set sglang aliases that the rest of the codebase expects diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index d7807f3b7a..e6ebe6bf83 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -204,7 +204,7 @@ def init_http_client(args): if not args.rollout_num_gpus: return - _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + _client_concurrency = args.rollout_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency),