From c15407873cf697e81d51fe31108635322d6da576 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Mon, 30 Mar 2026 16:52:45 +0800 Subject: [PATCH 01/13] Plan refactor vllm/sglang Signed-off-by: knlnguyen1802 --- rfc-rollout-backend-separation-plan.md | 271 +++++++++++++++++++++++++ 1 file changed, 271 insertions(+) create mode 100644 rfc-rollout-backend-separation-plan.md diff --git a/rfc-rollout-backend-separation-plan.md b/rfc-rollout-backend-separation-plan.md new file mode 100644 index 000000000..ef715c2a6 --- /dev/null +++ b/rfc-rollout-backend-separation-plan.md @@ -0,0 +1,271 @@ +# RFC: Rollout Separation Plan (EngineGroup Generalization + Executor Cleanup) + +## 1. Summary + +This RFC proposes backend separation with minimal churn in runtime orchestration. + +Scope is four items: + +1. Refactor [slime/ray/rollout.py](slime/ray/rollout.py) by generalizing `EngineGroup.start_engines()` and abstracting engine/server creation (no new runtime manager class hierarchy). +2. Refactor [slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) in-place: extract the one SGLang-specific code path (RadixTree) into a strategy hook, rename SGLang-prefixed args to generic names. No class hierarchy, no new files. +3. Refactor [slime/utils/arguments.py](slime/utils/arguments.py) into shared args + backend arg groups/finalizers. +4. Decouple [slime/backends/fsdp_utils/update_weight_utils.py](slime/backends/fsdp_utils/update_weight_utils.py) from SGLang internals so FSDP weight sync works with both SGLang and vLLM engines. + +## 2. Already Done (Reuse, Do Not Rewrite) + +- Unified rollout contracts in [slime/rollout/base_types.py](slime/rollout/base_types.py). +- Backend client abstraction in [slime/rollout/backends/base_client.py](slime/rollout/backends/base_client.py). +- Backend adapters in [slime/rollout/backends/sglang_client.py](slime/rollout/backends/sglang_client.py) and [slime/rollout/backends/vllm_client.py](slime/rollout/backends/vllm_client.py). 
+- Managed vLLM engine actor in [slime/backends/vllm_utils/vllm_engine.py](slime/backends/vllm_utils/vllm_engine.py). +- vLLM translation sidecar in [slime/backends/vllm_utils/vllm_translation_sidecar.py](slime/backends/vllm_utils/vllm_translation_sidecar.py). + +## 3. Problem + +- [slime/ray/rollout.py](slime/ray/rollout.py) mixes shared and backend-specific engine creation paths. +- [slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) is ~95% backend-agnostic but has one inlined SGLang-specific code path (RadixTree, ~14 lines) and uses SGLang-prefixed arg names for generic rollout concepts. +- [slime/utils/arguments.py](slime/utils/arguments.py) still applies SGLang alias behavior in the vLLM path. +- [slime/backends/fsdp_utils/update_weight_utils.py](slime/backends/fsdp_utils/update_weight_utils.py) hard-imports SGLang internals (`FlattenedTensorBucket`, `MultiprocessingSerializer`, `monkey_patch_torch_reductions`) and calls SGLang-specific engine RPC names (`update_weights_from_tensor`, `update_weights_from_distributed`), making FSDP weight sync unusable with vLLM engines. + +## 4. Goals and Non-Goals + +### Goals + +- Keep the runtime refactor minimal and localized to `EngineGroup` + creation abstraction. +- Isolate the one SGLang-specific executor code path behind a strategy hook; keep functions as functions. +- Rename SGLang-prefixed arg names to generic rollout names to eliminate naming coupling. +- Reduce backend leakage in argument finalization and in FSDP weight sync. +- Preserve current external behavior, call sites, and import paths. + +### Non-Goals + +- No rewrite of `SGLangEngine` or `VLLMEngine` internals. +- No algorithmic changes to GRPO/PPO. +- No mandatory feature parity for unsupported backend capabilities. + +## 5. Design + +### 5.1 Runtime: Generalize `EngineGroup` (No New Runtime Manager Classes) + +Keep [slime/ray/rollout.py](slime/ray/rollout.py) as the orchestration entry file. + +Refactor focus: + +1.
Generalize `EngineGroup.start_engines()` to call backend-aware creation hooks. +2. Abstract engine creation and rollout-server assembly helpers. +3. Keep existing startup function API (`start_rollout_servers`) and return shape. + +Proposed helper abstraction points: + +- `create_engine_actor_cls(args, worker_type)` + - returns `ray.remote(SGLangEngine)` or `ray.remote(VLLMEngine)`. +- `create_engine_remote(args, actor_cls, scheduling_strategy, ...)` + - encapsulates `.options(...).remote(...)` with backend-specific init kwargs. +- `build_rollout_server(...)` + - standardizes `RolloutServer` construction from engine groups. + +`EngineGroup` and `RolloutServer` remain the main shared dataclasses. + +### 5.2 Executor: Isolate Backend Logic In-Place (No Class Hierarchy) + +#### Current state analysis + +[slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) (577 lines) is **already ~95% backend-agnostic**: + +| Function / Class | Lines | Backend-specific? | Notes | +|---|---|---|---| +| `_get_backend_client()` | 6 | Factory only | Delegates to existing `RolloutBackendClient` subclasses | +| `_apply_backend_response()` | 20 | No | Uses `RolloutBackendResponse` contract | +| `GenerateState` | 37 | **Naming only** | References `sglang_server_concurrency`, `sglang_dp_size`, `sglang_enable_deterministic_inference` — all are generic rollout concepts with SGLang-prefixed names | +| `generate()` | 58 | **14 lines** | RadixTree middleware path (L170-183) is 100% SGLang-specific; the else branch (L185-195) already uses `RolloutBackendClient` | +| `generate_and_rm()` | 60 | No | Shared orchestration (semaphore, custom func, reward) | +| `generate_and_rm_group()` | 37 | No | Group parallelism + deterministic seeds | +| `abort()` | 38 | No | Already uses `backend.abort()` | +| `generate_rollout_async()` | 71 | No | Main loop, filtering, metrics | +| `eval_rollout()` / `eval_rollout_single_dataset()` | 118 | **Naming only** | `sglang_enable_deterministic_inference` 
reference | +| `generate_rollout()` | 41 | No | Sync entry point | + +**Conclusion**: the actual backend logic that needs isolation is **one code path** (~14 lines) inside `generate()`. Everything else is either already abstracted through `RolloutBackendClient` or is a naming-only coupling (SGLang-prefixed arg names for generic concepts). + +#### Approach + +1. **Extract the RadixTree path into a strategy hook** that `generate()` calls conditionally. +2. **Rename SGLang-prefixed args** to generic names (coordinated with the arguments refactor in §5.3). +3. **Keep functions as functions** — they compose well and callers (`train.py`, OPD, multi-agent) import them directly. + +#### Concrete changes + +**Step 1 — Extract RadixTree strategy from `generate()`** + +Current `generate()` has an `if use_radix: ... else: backend.generate(...)` branch. +Refactor into: + +```python +# slime/rollout/sglang_rollout.py — generate() simplified + +async def generate(args, sample, sampling_params): + ... + input_ids = ...
# shared prompt encoding (unchanged) + + strategy = _get_generate_strategy(args) + resp = await strategy(args, sample, input_ids, sampling_params) + _apply_backend_response(sample, resp, args) + return sample +``` + +Two strategies: + +```python +# Still in sglang_rollout.py (no new file needed) + +def _get_generate_strategy(args): + """Return the generate coroutine to use.""" + if _is_radix_tree_enabled(args): + return _generate_radix_tree # SGLang-only path + return _generate_via_backend_client # Generic path (SGLang or vLLM) + +def _is_radix_tree_enabled(args) -> bool: + return ( + args.use_slime_router + and "RadixTreeMiddleware" in getattr(args, "slime_router_middleware_paths", []) + ) + +async def _generate_radix_tree(args, sample, input_ids, sampling_params) -> RolloutBackendResponse: + """SGLang RadixTree middleware path — returns normalized response.""" + from slime.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + url = f"http://{args.rollout_router_ip}:{args.rollout_router_port}/generate" + payload = { ... } # existing payload construction + output = await post(url, payload, headers=headers) + sample = await postprocess_sample_with_radix_tree(args, sample, output) + return _extract_response_from_sample(sample) # normalize to RolloutBackendResponse + +async def _generate_via_backend_client(args, sample, input_ids, sampling_params) -> RolloutBackendResponse: + """Generic backend client path — works for SGLang and vLLM.""" + backend = _get_backend_client(args) + base_url = f"http://{args.rollout_router_ip}:{args.rollout_router_port}" + req = RolloutBackendRequest(...) 
+ return await backend.generate(req, base_url, headers=headers) +``` + +**Step 2 — Rename SGLang-prefixed args to generic names** + +| Current name | New name | Reason | +|---|---|---| +| `sglang_server_concurrency` | `rollout_concurrency` | Controls request parallelism for any backend | +| `sglang_dp_size` | `rollout_dp_size` | Data-parallel sharding, not SGLang-specific | +| `sglang_router_ip` / `sglang_router_port` | `rollout_router_ip` / `rollout_router_port` | Router endpoint, backend-agnostic | +| `sglang_router_policy` | `rollout_router_policy` | Routing strategy | +| `sglang_enable_deterministic_inference` | `rollout_deterministic_inference` | Seed-based determinism | +| `vllm_base_url` | (remove) | Folded into `rollout_router_ip` / `rollout_router_port`, no special case | + +Legacy aliases are kept in [slime/utils/arguments.py](slime/utils/arguments.py) for one release cycle (coordinated with the arguments refactor in §5.3). + +**Step 3 — No new files** + +The file stays as [slime/rollout/sglang_rollout.py](slime/rollout/sglang_rollout.py) during this phase. +Optionally rename to `slime/rollout/rollout.py` in a later cleanup pass, since the file is backend-agnostic after the refactor. + +### 5.3 Arguments/Config Refactor + +In [slime/utils/arguments.py](slime/utils/arguments.py), split into: + +1. Shared rollout args. +2. SGLang backend args/validation. +3. vLLM backend args/validation. + +Add backend finalizers: + +- `finalize_sglang_args(args)` +- `finalize_vllm_args(args)` + +Move the SGLang alias fallback out of the shared finalize flow.
+ +### 5.4 Weight Sync: Decouple FSDP `update_weight_utils.py` from SGLang Internals + +#### Current state analysis + +[slime/backends/fsdp_utils/update_weight_utils.py](slime/backends/fsdp_utils/update_weight_utils.py) (287 lines) has two concrete classes: + +| Class | Weight-push method | SGLang coupling | +|---|---|---| +| `UpdateWeightFromTensor` | IPC via Gloo gather → `engine.update_weights_from_tensor.remote()` | Imports `FlattenedTensorBucket`, `MultiprocessingSerializer`, `monkey_patch_torch_reductions` directly from `sglang.srt.*` | +| `UpdateWeightFromDistributed` | NCCL broadcast → `engine.update_weights_from_distributed.remote()` | Calls `engine.init_weights_update_group.remote()` — SGLang engine API | + +The abstract base `UpdateWeight` itself is clean (only PyTorch + Ray). + +The Megatron side already solved this: [slime/backends/megatron_utils/sglang.py](slime/backends/megatron_utils/sglang.py) centralizes all SGLang imports into one shim. The FSDP side duplicates these imports inline. + +#### Coupling points + +1. **SGLang utility imports (L13-26)** — `monkey_patch_torch_reductions`, `MultiprocessingSerializer`, `FlattenedTensorBucket` are imported directly from `sglang.srt.*` with try/except version fallbacks. +2. **`UpdateWeightFromTensor.update_bucket_weights()`** — uses `FlattenedTensorBucket` to flatten tensors, `MultiprocessingSerializer` to serialize, then calls `engine.update_weights_from_tensor.remote()`. +3. **`UpdateWeightFromDistributed.connect_rollout_engines()`** — calls `engine.init_weights_update_group.remote()`, which is an SGLang engine method. +4. **`UpdateWeightFromDistributed.update_bucket_weights()`** — calls `engine.update_weights_from_distributed.remote()`, which is an SGLang engine method. + +Point 1 requires SGLang to be importable, and points 2–4 assume the engine actor exposes SGLang's RPC interface. vLLM engines are not guaranteed to expose the same method names.
+ +#### Approach + +**Step 1 — Centralize SGLang imports via a shim (same pattern as Megatron)** + +Reuse or mirror the existing [slime/backends/megatron_utils/sglang.py](slime/backends/megatron_utils/sglang.py) pattern: + +```python +# slime/backends/fsdp_utils/sglang_compat.py (new, ~15 lines) +try: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions +except ImportError: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + +from sglang.srt.utils import MultiprocessingSerializer + +try: + from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket +except ImportError: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket +``` + +Then `update_weight_utils.py` imports from `sglang_compat` — one import line instead of five, and only loaded when SGLang is the active backend. + +**Step 2 — Abstract engine RPC calls behind a protocol** + +The engine actors (`SGLangEngine`, `VLLMEngine`) already expose weight-sync methods but with different names/signatures. 
Introduce a lightweight protocol or adapter: + +```python +# In update_weight_utils.py or a small helper + +def _call_engine_update_tensor(engine, backend: str, **kwargs): + """Dispatch IPC weight update to the correct engine method.""" + if backend == "vllm": + return engine.update_weights_from_tensor.remote(**kwargs) # vLLM uses same name via NcclBridge + return engine.update_weights_from_tensor.remote(**kwargs) # SGLang native + +def _call_engine_update_distributed(engine, backend: str, **kwargs): + if backend == "vllm": + return engine.update_weights_from_distributed.remote(**kwargs) + return engine.update_weights_from_distributed.remote(**kwargs) + +def _call_engine_init_weight_group(engine, backend: str, **kwargs): + if backend == "vllm": + return engine.init_weights_update_group.remote(**kwargs) + return engine.init_weights_update_group.remote(**kwargs) +``` + +> Note: Currently `VLLMEngine` already mirrors these method names (it wraps them via `NcclBridge`), so the dispatch functions may initially be identical. The value is making the indirection explicit so future method-name divergence is handled in one place. + +**Step 3 — Lazy-import SGLang utilities only when backend is SGLang** + +Move the `FlattenedTensorBucket` / `MultiprocessingSerializer` imports inside `UpdateWeightFromTensor.update_bucket_weights()` behind a lazy import, so the module can be loaded in a vLLM-only environment without SGLang installed. + +#### What changes + +| File | Change | +|---|---| +| `slime/backends/fsdp_utils/sglang_compat.py` | New shim file (~15 lines) centralizing SGLang imports | +| `slime/backends/fsdp_utils/update_weight_utils.py` | Replace 5 inline SGLang imports with one `from .sglang_compat import ...`; add backend-aware engine dispatch helpers | + +#### What does NOT change + +- `UpdateWeight` abstract base class — already clean. +- `UpdateWeightFromDistributed` NCCL logic — the broadcast itself is pure PyTorch; only the engine RPC dispatch gets a thin wrapper. 
+- Megatron-side weight sync — already has its own shim, not touched. + From 1a2dcf55acd2e4d3e408e5ce534020d3a25ed4bd Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 31 Mar 2026 09:43:21 +0800 Subject: [PATCH 02/13] Code implemented Signed-off-by: knlnguyen1802 --- .../fsdp_utils/update_weight_utils.py | 34 +++++++++++++------ slime/ray/rollout.py | 12 +++++-- slime/rollout/backends/__init__.py | 10 +++++- slime/utils/arguments.py | 23 +++++++++---- 4 files changed, 58 insertions(+), 21 deletions(-) diff --git a/slime/backends/fsdp_utils/update_weight_utils.py b/slime/backends/fsdp_utils/update_weight_utils.py index 75e217beb..64a6770c0 100644 --- a/slime/backends/fsdp_utils/update_weight_utils.py +++ b/slime/backends/fsdp_utils/update_weight_utils.py @@ -10,23 +10,31 @@ from ray.actor import ActorHandle from torch.distributed.tensor import DTensor, Replicate -try: - from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions # type: ignore[import] -except ImportError: - from sglang.srt.patch_torch import monkey_patch_torch_reductions # type: ignore[import] +from slime.utils.distributed_utils import init_process_group -from sglang.srt.utils import MultiprocessingSerializer -from slime.utils.distributed_utils import init_process_group +logger = logging.getLogger(__name__) -try: - from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] -except ImportError: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +def _import_sglang_weight_sync_utils(): + """Lazy-import SGLang serialization utilities. + Centralizes the try/except version fallbacks so callers get a clean tuple. + Raises ImportError if sglang is not installed at all. 
+ """ + try: + from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions # type: ignore[import] + except ImportError: + from sglang.srt.patch_torch import monkey_patch_torch_reductions # type: ignore[import] -logger = logging.getLogger(__name__) + from sglang.srt.utils import MultiprocessingSerializer + + try: + from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] + except ImportError: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] + + return monkey_patch_torch_reductions, MultiprocessingSerializer, FlattenedTensorBucket class UpdateWeight(abc.ABC): @@ -135,6 +143,10 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: if self._ipc_gather_group is None: return + monkey_patch_torch_reductions, MultiprocessingSerializer, FlattenedTensorBucket = ( + _import_sglang_weight_sync_utils() + ) + monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 9d1a786ae..ee6d2da50 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -13,9 +13,13 @@ import torch import yaml from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS -from slime.backends.sglang_utils.sglang_engine import SGLangEngine +# GPU memory-tag constants (originally from sglang.srt.constants). +# Duplicated here to avoid a hard sglang dependency when running the vLLM backend. 
+GPU_MEMORY_TYPE_KV_CACHE = "kv_cache" +GPU_MEMORY_TYPE_WEIGHTS = "weights" +GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph" + from slime.backends.vllm_utils.vllm_engine import VLLMEngine from slime.rollout.base_types import call_rollout_fn from slime.utils import logging_utils @@ -234,6 +238,8 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis pg, reordered_bundle_indices, reordered_gpu_ids = self.pg + from slime.backends.sglang_utils.sglang_engine import SGLangEngine + RolloutRayActor = ray.remote(SGLangEngine) rollout_engines = [] @@ -989,7 +995,7 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool router_args.sglang_router_port = router_port else: - from sglang_router.launch_router import RouterArgs + from sglang_router.launch_router import RouterArgs # noqa: delayed import — only needed for sglang router from slime.utils.http_utils import run_router diff --git a/slime/rollout/backends/__init__.py b/slime/rollout/backends/__init__.py index 49c4a86ee..ea9ccd8a9 100644 --- a/slime/rollout/backends/__init__.py +++ b/slime/rollout/backends/__init__.py @@ -1,7 +1,15 @@ from slime.rollout.backends.base_client import BackendCapabilities, RolloutBackendClient -from slime.rollout.backends.sglang_client import SGLangClient from slime.rollout.backends.vllm_client import VLLMClient + +def __getattr__(name): + if name == "SGLangClient": + from slime.rollout.backends.sglang_client import SGLangClient + + return SGLangClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "BackendCapabilities", "RolloutBackendClient", diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 4bb5d3506..7874237ab 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -5,10 +5,7 @@ from typing import Any import yaml -from sglang_router.launch_router import RouterArgs -from slime.backends.sglang_utils.arguments import sglang_parse_args -from 
slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -1069,7 +1066,12 @@ def add_router_arguments(parser): default=3, help="Number of consecutive failures before marking a worker as unhealthy.", ) - RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) + try: + from sglang_router.launch_router import RouterArgs + + RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) + except ImportError: + pass # sglang not installed — router args unavailable (vllm-only mode) return parser # wandb @@ -1463,6 +1465,7 @@ def _pre_parse_mode(): temp_parser.add_argument("--debug-rollout-only", action="store_true", default=False) temp_parser.add_argument("--debug-train-only", action="store_true", default=False) temp_parser.add_argument("--load-debug-rollout-data", type=str, default=None) + temp_parser.add_argument("--rollout-backend", type=str, choices=["sglang", "vllm"], default="sglang") temp_args, _ = temp_parser.parse_known_args() return temp_args @@ -1474,12 +1477,18 @@ def parse_args(add_custom_arguments=None): add_slime_arguments = get_slime_extra_args_provider(add_custom_arguments) pre = _pre_parse_mode() - skip_sglang = pre.debug_train_only or pre.load_debug_rollout_data is not None + skip_sglang = ( + pre.debug_train_only + or pre.load_debug_rollout_data is not None + or pre.rollout_backend == "vllm" + ) # Phase 1: Parse sglang args independently (separate parser, parse_known_args). - # Skipped when sglang servers are not needed. + # Skipped when sglang servers are not needed or when using vLLM backend. sglang_ns = None if not skip_sglang: + from slime.backends.sglang_utils.arguments import sglang_parse_args + sglang_ns = sglang_parse_args() # Phase 2: Parse megatron/fsdp + slime args. 
@@ -1517,6 +1526,8 @@ def parse_args(add_custom_arguments=None): megatron_validate_args(args) if not args.debug_train_only and getattr(args, "rollout_backend", "sglang") == "sglang": + from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args + sglang_validate_args(args) elif getattr(args, "rollout_backend", "sglang") == "vllm": # Set sglang aliases that the rest of the codebase expects From 8addb37546e33e55f22ee977e6d2c958bf393646 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 31 Mar 2026 14:42:28 +0800 Subject: [PATCH 03/13] Fix bug Signed-off-by: knlnguyen1802 --- slime/backends/megatron_utils/sglang.py | 25 +- .../megatron_utils/weight_sync_utils.py | 268 ++++++++++++++++++ 2 files changed, 288 insertions(+), 5 deletions(-) create mode 100644 slime/backends/megatron_utils/weight_sync_utils.py diff --git a/slime/backends/megatron_utils/sglang.py b/slime/backends/megatron_utils/sglang.py index 97c82a31c..4c045c481 100644 --- a/slime/backends/megatron_utils/sglang.py +++ b/slime/backends/megatron_utils/sglang.py @@ -1,4 +1,9 @@ # the file to manage all sglang deps in the megatron actor +# When sglang is installed we prefer its implementations; otherwise we +# fall back to API-compatible reimplementations in +# slime.backends.megatron_utils.weight_sync_utils. 
+ +# ── FP8 quantisation helpers (sglang-only, no local fallback) ─────── try: from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0 from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 @@ -7,19 +12,29 @@ transform_scale_ue8m0 = None should_deepgemm_weight_requant_ue8m0 = None +# ── monkey_patch_torch_reductions ─────────────────────────────────── try: from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions except ImportError: - from sglang.srt.patch_torch import monkey_patch_torch_reductions - - -from sglang.srt.utils import MultiprocessingSerializer + try: + from sglang.srt.patch_torch import monkey_patch_torch_reductions + except ImportError: + from .weight_sync_utils import monkey_patch_torch_reductions +# ── MultiprocessingSerializer ─────────────────────────────────────── +try: + from sglang.srt.utils import MultiprocessingSerializer +except ImportError: + from .weight_sync_utils import MultiprocessingSerializer +# ── FlattenedTensorBucket ─────────────────────────────────────────── try: from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import] except ImportError: - from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] + try: + from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] + except ImportError: + from .weight_sync_utils import FlattenedTensorBucket __all__ = [ "quant_weight_ue8m0", diff --git a/slime/backends/megatron_utils/weight_sync_utils.py b/slime/backends/megatron_utils/weight_sync_utils.py new file mode 100644 index 000000000..5fef12e58 --- /dev/null +++ b/slime/backends/megatron_utils/weight_sync_utils.py @@ -0,0 +1,268 @@ +""" +Local reimplementations of sglang weight-sync utilities. + +Used as fallback when sglang is not installed (e.g. vLLM-only mode). 
+The three classes/functions here are API-compatible with their sglang +counterparts so that the rest of the megatron weight-update code can +work unchanged. + +Origin (sglang): + - FlattenedTensorBucket → sglang.srt.weight_sync.tensor_bucket + - MultiprocessingSerializer / SafeUnpickler → sglang.srt.utils.common + - monkey_patch_torch_reductions → sglang.srt.utils.patch_torch +""" + +from __future__ import annotations + +import base64 +import io +import pickle +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import Callable, Union + +import torch +from torch.multiprocessing import reductions + +# ── FlattenedTensorBucket ─────────────────────────────────────────── + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket.""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single uint8 tensor + for efficient serialisation, while preserving all metadata needed + for reconstruction. + + API-compatible with ``sglang.srt.weight_sync.tensor_bucket.FlattenedTensorBucket``. + """ + + # Checked by callers to decide whether to group tensors by dtype. 
+ supports_multi_dtypes = True + + def __init__( + self, + named_tensors: list[tuple[str, torch.Tensor]] | None = None, + flattened_tensor: torch.Tensor | None = None, + metadata: list[FlattenedTensorMetadata] | None = None, + ): + if named_tensors is not None: + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + self.metadata: list[FlattenedTensorMetadata] = [None] * len(named_tensors) + current_idx = 0 + flat_parts: list[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flat = tensor.flatten().view(torch.uint8) + numel = flat.numel() + flat_parts[i] = flat + self.metadata[i] = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + current_idx += numel + + self.flattened_tensor: torch.Tensor = torch.cat(flat_parts, dim=0) + else: + if flattened_tensor is None or metadata is None: + raise ValueError( + "Must provide either named_tensors or both flattened_tensor and metadata" + ) + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Return the single flat uint8 tensor.""" + return self.flattened_tensor + + def get_metadata(self) -> list[FlattenedTensorMetadata]: + """Return per-tensor metadata list.""" + return self.metadata + + def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]: + """Reconstruct the original named tensors from the flat representation.""" + reconstructed = [None] * len(self.metadata) + for i, meta in enumerate(self.metadata): + tensor = ( + self.flattened_tensor[meta.start_idx : meta.end_idx] + .view(meta.dtype) + .reshape(meta.shape) + ) + reconstructed[i] = (meta.name, tensor) + return reconstructed + + +# ── SafeUnpickler / MultiprocessingSerializer ─────────────────────── + + +class SafeUnpickler(pickle.Unpickler): + """ + Unpickler with an allow-list to prevent arbitrary code 
execution. + + API-compatible with the ``SafeUnpickler`` in ``sglang.srt.utils.common``. + """ + + ALLOWED_MODULE_PREFIXES = { + # Python builtins + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # PyTorch + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # torch.distributed + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # multiprocessing + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # HuggingFace / PEFT + "peft.", + "transformers.", + "huggingface_hub.", + # slime local reimplementation + "slime.backends.megatron_utils.weight_sync_utils.", + # sglang (if installed alongside) + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + # NPU + "torch_npu.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module: str, name: str): + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): + return super().find_class(module, name) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " + f"to prevent exploitation of CVE-2025-10164" + ) + + +class MultiprocessingSerializer: + """ + Serialize / deserialize Python objects via ``ForkingPickler`` so that + CUDA tensors are transferred through shared memory (IPC handles). 
+ + API-compatible with ``sglang.srt.utils.common.MultiprocessingSerializer``. + + Uses stdlib ``base64`` instead of ``pybase64`` to avoid adding a dependency. + """ + + @staticmethod + def serialize(obj, output_str: bool = False): + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + if output_str: + output = base64.b64encode(output).decode("utf-8") + return output + + @staticmethod + def deserialize(data): + if isinstance(data, str): + data = base64.b64decode(data, validate=True) + return SafeUnpickler(io.BytesIO(data)).load() + + +# ── monkey_patch_torch_reductions ─────────────────────────────────── + +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise RuntimeError("Invalid device_uuid=" + device_maybe_uuid) + raise RuntimeError(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return (*t[:index], modifier(t[index]), *t[index + 1 :]) + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def monkey_patch_torch_reductions(): + """ + Monkey-patch ``torch.multiprocessing.reductions`` so that CUDA tensors + are identified by device UUID rather than ordinal index. 
+ + This works around https://github.com/pytorch/pytorch/pull/149248. + + API-compatible with ``sglang.srt.utils.patch_torch.monkey_patch_torch_reductions``. + """ + if hasattr(reductions, "_reduce_tensor_original"): + return # already patched + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + reductions.init_reductions() From 9689a970eb3174f26203454722cd8321103b9490 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 31 Mar 2026 14:43:08 +0800 Subject: [PATCH 04/13] Fix bug Signed-off-by: knlnguyen1802 --- slime/utils/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 7874237ab..5ef842406 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1500,7 +1500,7 @@ def parse_args(add_custom_arguments=None): args = megatron_parse_args( extra_args_provider=add_slime_arguments, - skip_hf_validate=pre.debug_rollout_only, + skip_hf_validate=True, #pre.debug_rollout_only, ) else: logger.warning( From 3753432aaea1d70402e05f5343188a1c6972a8e3 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 31 Mar 2026 14:59:43 +0800 Subject: [PATCH 05/13] Fix bug Signed-off-by: knlnguyen1802 --- slime/backends/megatron_utils/model_provider.py | 12 ++++++++---- slime/utils/http_utils.py | 3 ++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py index 31db8b0da..cc7da4813 100644 --- a/slime/backends/megatron_utils/model_provider.py +++ b/slime/backends/megatron_utils/model_provider.py @@ -209,12 +209,16 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage def wrap_model_provider_with_freeze(original_provider, args): - def 
wrapped_provider(pre_process=True, post_process=True, vp_stage=None): + def wrapped_provider(pre_process=True, post_process=True, vp_stage=None, **kwargs): sig = inspect.signature(original_provider) + call_kwargs = {"pre_process": pre_process, "post_process": post_process} if "vp_stage" in sig.parameters: - model = original_provider(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) - else: - model = original_provider(pre_process=pre_process, post_process=post_process) + call_kwargs["vp_stage"] = vp_stage + # Forward any extra kwargs (e.g. config) accepted by the provider + for k, v in kwargs.items(): + if k in sig.parameters: + call_kwargs[k] = v + model = original_provider(**call_kwargs) freeze_model_params(model, args) diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index d7807f3b7..f393cbb50 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -204,7 +204,8 @@ def init_http_client(args): if not args.rollout_num_gpus: return - _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + server_concurrency = getattr(args, "sglang_server_concurrency", 256) + _client_concurrency = server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency), From f1e75542f292bcfdcddb977b5688c45fe9a5c2cf Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 31 Mar 2026 15:10:46 +0800 Subject: [PATCH 06/13] Fix port Signed-off-by: knlnguyen1802 --- slime/ray/rollout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ee6d2da50..149780e18 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -973,14 +973,14 @@ def _start_router(args, *, has_pd_disaggregation: bool = False, force_new: bool ``force_new`` is False, skip launching and return the existing values. 
When ``force_new`` is True (multi-model), always allocate a fresh port. """ - if not force_new and args.sglang_router_ip is not None: + if not force_new and getattr(args, "sglang_router_ip", None) is not None: return args.sglang_router_ip, args.sglang_router_port router_ip = _wrap_ipv6(get_host_info()[1]) if force_new: router_port = find_available_port(random.randint(3000, 4000)) else: - router_port = args.sglang_router_port + router_port = getattr(args, "sglang_router_port", None) if router_port is None: router_port = find_available_port(random.randint(3000, 4000)) From 91cc780f764a775fc5ded42e6695cccf166b9827 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 31 Mar 2026 17:50:07 +0800 Subject: [PATCH 07/13] Fix config Signed-off-by: knlnguyen1802 --- docs/en/examples/deepseek-r1.md | 4 ++-- docs/en/examples/qwen3-4B.md | 4 ++-- docs/zh/examples/deepseek-r1.md | 4 ++-- docs/zh/examples/qwen3-4B.md | 4 ++-- examples/fully_async/fully_async_rollout.py | 2 +- examples/tau-bench/run_qwen3_4B.sh | 2 +- run-qwen2.5-0.5B-vllm.sh | 2 +- scripts/low_precision/run-kimi-k2-Thinking-int4.sh | 2 +- scripts/run-deepseek-r1.sh | 2 +- scripts/run-kimi-k2-Instruct.sh | 2 +- scripts/run-kimi-k2-Thinking.sh | 2 +- slime/backends/sglang_utils/arguments.py | 1 - slime/rollout/sglang_rollout.py | 2 +- slime/router/router.py | 2 +- slime/utils/arguments.py | 7 +++++++ slime/utils/http_utils.py | 3 +-- 16 files changed, 25 insertions(+), 20 deletions(-) diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index ab10502ee..48cca18d3 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -171,7 +171,7 @@ OPTIMIZER_ARGS=( These are the parameters required by sglang. Here, `--rollout-num-gpus-per-engine` basically corresponds to sglang's `tp_size`. Other sglang parameters are passed to slime by adding a `--sglang-` prefix. 
To fully leverage sglang's large EP inference capabilities, we have added configurations like ep64, dp\_attention dp8, and deepep mode auto. -The final `--sglang-server-concurrency` is a parameter specific to slime. It is used to prevent the sglang server's concurrent requests from becoming too large and crashing the HTTP server. The default is 512. However, since we now have one server for 8 nodes, we have adjusted it to 1024 to ensure that each dp rank can have a concurrency of 128. +The final `--server-concurrency` is a parameter specific to slime. It is used to prevent the sglang server's concurrent requests from becoming too large and crashing the HTTP server. The default is 512. However, since we now have one server for 8 nodes, we have adjusted it to 1024 to ensure that each dp rank can have a concurrency of 128. ```bash SGLANG_ARGS=( @@ -190,7 +190,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank have 128 concurrency - --sglang-server-concurrency 1024 + --server-concurrency 1024 ) ``` diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 56711878f..8a4022ab9 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -280,10 +280,10 @@ In this case, 2 GPUs will be allocated for training, and 6 GPUs will be allocate ⚠️ If the concurrency on each sglang server is too high, it may exceed sglang's default CUDA graph concurrency limit (the default maximum is 160), which will affect inference speed. You can adjust this in the following two ways: -1. Use `--sglang-server-concurrency` to limit the maximum number of concurrent requests sent to a single sglang server. For example: +1. Use `--server-concurrency` to limit the maximum number of concurrent requests sent to a single sglang server. For example: ```bash - --sglang-server-concurrency 160 + --server-concurrency 160 ``` 2. 
Use `--sglang-cuda-graph-bs` (which corresponds to sglang's native `--cuda-graph-bs` argument) to increase the number of CUDA graphs initialized by sglang. For example: diff --git a/docs/zh/examples/deepseek-r1.md b/docs/zh/examples/deepseek-r1.md index 368653844..fba2e736a 100644 --- a/docs/zh/examples/deepseek-r1.md +++ b/docs/zh/examples/deepseek-r1.md @@ -171,7 +171,7 @@ OPTIMIZER_ARGS=( sglang 所需的参数,这里 `--rollout-num-gpus-per-engine` 基本对应 sglang 的 `tp_size`,除此之外的 sglang 参数均通过添加 `--sglang-` 的前缀来传给 slime。为了充分利用 sglang 的大 EP 推理能力,我们加上了 ep64、dp_attention dp8、deepep mode auto 等配置。 -最后的 `--sglang-server-concurrency` 是 slime 的特有参数,是为了方式同时发给 sglang server 的并发太大打爆 http server,默认为 512。但是我们现在是 8 机一个 server,为了保证每个 dp rank 能有 128 的并发,我们调整为 1024。 +最后的 `--server-concurrency` 是 slime 的特有参数,是为了方式同时发给 sglang server 的并发太大打爆 http server,默认为 512。但是我们现在是 8 机一个 server,为了保证每个 dp rank 能有 128 的并发,我们调整为 1024。 ```bash SGLANG_ARGS=( @@ -190,7 +190,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --server-concurrency 1024 ) ``` diff --git a/docs/zh/examples/qwen3-4B.md b/docs/zh/examples/qwen3-4B.md index 2afc70bfb..ac0fad5b6 100644 --- a/docs/zh/examples/qwen3-4B.md +++ b/docs/zh/examples/qwen3-4B.md @@ -280,10 +280,10 @@ ray job submit ... \ ⚠️ 在进行训推分离的时候,每个 sglang server 上的并发度太大,超过了 sglang 默认的 cuda graph 的并发度(默认最大 160),影响推理速度。可以用以下 2 种方式进行调整: -1. 通过 `--sglang-server-concurrency` 限制发给一个 sglang server 的最大并发量,例如: +1. 通过 `--server-concurrency` 限制发给一个 sglang server 的最大并发量,例如: ```bash - --sglang-server-concurrency 160 + --server-concurrency 160 ``` 2. 
使用 `--sglang-cuda-graph-bs`,即 sglang 原生的 `--cuda-graph-bs`, 增大 sglang 初始化的 cuda graph 数量,例如: diff --git a/examples/fully_async/fully_async_rollout.py b/examples/fully_async/fully_async_rollout.py index 7208365c1..25b832709 100644 --- a/examples/fully_async/fully_async_rollout.py +++ b/examples/fully_async/fully_async_rollout.py @@ -20,7 +20,7 @@ def get_global_worker(args, data_buffer): with _worker_lock: if _global_worker is None or not _global_worker.worker_thread.is_alive(): print("Creating new global async worker...") - _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.sglang_server_concurrency) + _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.server_concurrency) _global_worker.start() return _global_worker diff --git a/examples/tau-bench/run_qwen3_4B.sh b/examples/tau-bench/run_qwen3_4B.sh index a82173401..81ad6f63f 100644 --- a/examples/tau-bench/run_qwen3_4B.sh +++ b/examples/tau-bench/run_qwen3_4B.sh @@ -100,7 +100,7 @@ SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 --sglang-mem-fraction-static 0.7 # If gemini API reports concurrency limit error, set this parameter to reduce the concurrency - # --sglang-server-concurrency 32 + # --server-concurrency 32 ) MISC_ARGS=( diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh index a31559080..3fb6c896a 100644 --- a/run-qwen2.5-0.5B-vllm.sh +++ b/run-qwen2.5-0.5B-vllm.sh @@ -94,7 +94,7 @@ WANDB_ARGS=( VLLM_ARGS=( --rollout-backend vllm --rollout-num-gpus-per-engine 1 - --sglang-server-concurrency 512 + --server-concurrency 512 --use-slime-router --slime-router-middleware-paths slime.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware ) diff --git a/scripts/low_precision/run-kimi-k2-Thinking-int4.sh b/scripts/low_precision/run-kimi-k2-Thinking-int4.sh index c41ea3df8..d09a8f3bf 100644 --- a/scripts/low_precision/run-kimi-k2-Thinking-int4.sh +++ b/scripts/low_precision/run-kimi-k2-Thinking-int4.sh @@ -134,7 +134,7 @@ SGLANG_ARGS=( 
#--sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --server-concurrency 1024 --use-slime-router ) diff --git a/scripts/run-deepseek-r1.sh b/scripts/run-deepseek-r1.sh index c307d110e..9b8bc64b0 100644 --- a/scripts/run-deepseek-r1.sh +++ b/scripts/run-deepseek-r1.sh @@ -126,7 +126,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --server-concurrency 1024 ) MISC_ARGS=( diff --git a/scripts/run-kimi-k2-Instruct.sh b/scripts/run-kimi-k2-Instruct.sh index 3a591b923..506b576d2 100644 --- a/scripts/run-kimi-k2-Instruct.sh +++ b/scripts/run-kimi-k2-Instruct.sh @@ -132,7 +132,7 @@ SGLANG_ARGS=( # --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --server-concurrency 1024 ) diff --git a/scripts/run-kimi-k2-Thinking.sh b/scripts/run-kimi-k2-Thinking.sh index 25cc3c475..088317761 100644 --- a/scripts/run-kimi-k2-Thinking.sh +++ b/scripts/run-kimi-k2-Thinking.sh @@ -134,7 +134,7 @@ SGLANG_ARGS=( # --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --sglang-server-concurrency 1024 + --server-concurrency 1024 ) diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 71543e02d..9185f5336 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -42,7 +42,6 @@ def add_sglang_arguments(parser): """ parser = add_sglang_router_arguments(parser) parser.set_defaults(router_balance_abs_threshold=10, router_balance_rel_threshold=1.2) - parser.add_argument("--sglang-server-concurrency", type=int, default=512) old_add_argument = parser.add_argument diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 0b59be447..416d38407 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -76,7 +76,7 @@ def __init__(self, args: Namespace) -> None: 
self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) self.semaphore = asyncio.Semaphore( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + args.server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) self.sampling_params: dict[str, Any] = dict( temperature=args.rollout_temperature, diff --git a/slime/router/router.py b/slime/router/router.py index 669094e44..9f660e24b 100644 --- a/slime/router/router.py +++ b/slime/router/router.py @@ -45,7 +45,7 @@ def __init__(self, args, verbose=False): max_connections = getattr(args, "slime_router_max_connections", None) if max_connections is None: max_connections = ( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + args.server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) timeout = getattr(args, "slime_router_timeout", None) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 5ef842406..2699072f6 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -267,6 +267,13 @@ def add_rollout_arguments(parser): help="Fraction of GPU memory for vLLM KV cache (default: 0.85). " "Lower this if the training model leaves insufficient free memory.", ) + parser.add_argument( + "--server-concurrency", + type=int, + default=512, + help="Maximum number of concurrent requests sent to the rollout server. 
" + "Controls request parallelism for any backend (sglang or vllm).", + ) parser.add_argument( "--rollout-function-path", type=str, diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index f393cbb50..d8b92dc69 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -204,8 +204,7 @@ def init_http_client(args): if not args.rollout_num_gpus: return - server_concurrency = getattr(args, "sglang_server_concurrency", 256) - _client_concurrency = server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + _client_concurrency = args.server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency), From be1ecd490a77bf59b897b9929c18da72f4f6a2de Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 1 Apr 2026 14:14:15 +0800 Subject: [PATCH 08/13] Fix bug MOE weight sync --- .../megatron_utils/update_weight/common.py | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index a2e4e129b..e0a24cf1c 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -12,6 +12,36 @@ from slime.utils.types import ParamInfo +def _cat_partitions( + partitions: list[torch.Tensor], partition_dim: int, stride: int +) -> torch.Tensor: + """Concatenate TP-gathered partitions, handling interleaved (stride > 1) layouts. + + When ``stride == 1`` each rank holds a single contiguous shard and a plain + ``torch.cat`` along *partition_dim* is sufficient. + + When ``stride > 1`` (e.g. Megatron-Core GroupedMLP expert weights) each + rank's shard contains multiple interleaved groups of ``stride`` consecutive + elements. 
We split each shard into those groups and interleave across + ranks to reconstruct the original full tensor. + """ + if stride == 1: + return torch.cat(partitions, dim=partition_dim) + + # Number of contiguous groups stored on each rank + num_groups = partitions[0].shape[partition_dim] // stride + if num_groups <= 0: + # Fallback: stride >= shard size — simple concat should still be correct + return torch.cat(partitions, dim=partition_dim) + + rank_groups = [p.chunk(num_groups, dim=partition_dim) for p in partitions] + interleaved: list[torch.Tensor] = [] + for g in range(num_groups): + for r in range(len(partitions)): + interleaved.append(rank_groups[r][g]) + return torch.cat(interleaved, dim=partition_dim) + + def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: """ All-gather TP-sharded param to full tensor. expert_bias→param, non-TP/duplicated→param.data. @@ -34,7 +64,7 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] dist.all_gather(param_partitions, param.data, group=tp_group) partition_dim = param.partition_dim - assert param.partition_stride == 1, "partition_stride != 1 is not supported" + stride = getattr(param, "partition_stride", 1) # TODO: here we did an extra copy during concat, maybe merge this with convert_to_hf is better? # TODO: check only GLU is used. 
if "linear_fc1.weight" in name: @@ -44,7 +74,7 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: if "linear_fc2.weight" in name: if partition_dim == 0: partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + param = _cat_partitions(param_partitions, partition_dim, stride) return param @@ -63,10 +93,10 @@ def all_gather_params_async( for info, param in param_infos_and_params: # Prepare async all_gather if "expert_bias" in info.name: - gather_tasks.append((info, param, None, None, None)) + gather_tasks.append((info, param, None, None, None, 1)) handles.append(None) elif not param.tensor_model_parallel or getattr(param, "parallel_mode", None) == "duplicated": - gather_tasks.append((info, param.data, None, None, None)) + gather_tasks.append((info, param.data, None, None, None, 1)) handles.append(None) else: # Start async all_gather @@ -79,7 +109,8 @@ def all_gather_params_async( param_partitions = [torch.empty_like(param.data) for _ in range(tp_size)] handle = dist.all_gather(param_partitions, param.data, group=tp_group, async_op=True) - gather_tasks.append((info, None, handle, param_partitions, param.partition_dim)) + partition_stride = getattr(param, "partition_stride", 1) + gather_tasks.append((info, None, handle, param_partitions, param.partition_dim, partition_stride)) handles.append(handle) # Phase 2: Wait for ALL async operations to complete at once @@ -90,13 +121,12 @@ def all_gather_params_async( # Phase 3: Process all results after all communications are done gathered_params = [] - for info, direct_param, handle, param_partitions, partition_dim in gather_tasks: + for info, direct_param, handle, param_partitions, partition_dim, partition_stride in gather_tasks: if handle is None: # No all_gather needed param = direct_param else: # Process the gathered partitions (same logic as original all_gather_param) - assert partition_dim is not None, "partition_stride != 1 is not supported" # TODO: here we did an extra 
copy during concat, maybe merge this with convert_to_hf is better? # TODO: check only GLU is used. if "linear_fc1.weight" in info.name: @@ -106,7 +136,7 @@ def all_gather_params_async( if "linear_fc2.weight" in info.name: if partition_dim == 0: partition_dim = 1 - param = torch.cat(param_partitions, dim=partition_dim) + param = _cat_partitions(param_partitions, partition_dim, partition_stride) gathered_params.append(param) From e7216d8356614479db66693818263312f6823195 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 1 Apr 2026 14:29:17 +0800 Subject: [PATCH 09/13] Fix bug vllm transfer weight Signed-off-by: knlnguyen1802 --- .../update_weight/update_weight_from_distributed.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 03d480801..e827fac6b 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -83,14 +83,18 @@ def _nccl_bridge_worker(conn, master_address, master_port, world_size, device, c elif op == "send_packed": from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine, ) - NCCLWeightTransferEngine.trainer_send_weights( - iterator=iter(cmd["named_tensors"]), + trainer_args = NCCLTrainerSendWeightsArgs( group=comm, packed=True, ) + NCCLWeightTransferEngine.trainer_send_weights( + iterator=iter(cmd["named_tensors"]), + trainer_args=trainer_args, + ) torch.cuda.synchronize() conn.send("ok") From 2343647735876472553fdfee9954f8bc0d96dbed Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 1 Apr 2026 15:07:07 +0800 Subject: [PATCH 10/13] Fix weight sync Signed-off-by: knlnguyen1802 --- .../megatron_utils/update_weight/common.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 
deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index e0a24cf1c..35faf97c1 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -15,31 +15,33 @@ def _cat_partitions( partitions: list[torch.Tensor], partition_dim: int, stride: int ) -> torch.Tensor: - """Concatenate TP-gathered partitions, handling interleaved (stride > 1) layouts. + """Concatenate TP-gathered partitions, reversing Megatron-Core's strided layout. When ``stride == 1`` each rank holds a single contiguous shard and a plain ``torch.cat`` along *partition_dim* is sufficient. - When ``stride > 1`` (e.g. Megatron-Core GroupedMLP expert weights) each - rank's shard contains multiple interleaved groups of ``stride`` consecutive - elements. We split each shard into those groups and interleave across - ranks to reconstruct the original full tensor. + When ``stride > 1``, Megatron-Core's ``_initialize_affine_weight_cpu`` + splits the master weight into ``world_size * stride`` equal sub-chunks and + assigns rank *r* the sub-chunks at indices ``r, r + world_size, + r + 2 * world_size, …`` (i.e. ``weight_list[rank::world_size]``). Each + rank therefore holds ``stride`` sub-chunks concatenated together. 
+ + To reconstruct the original full tensor we split each rank's shard back + into ``stride`` sub-chunks and interleave them: + ``for i in range(stride): for r in range(world_size): append chunk[r][i]`` """ if stride == 1: return torch.cat(partitions, dim=partition_dim) - # Number of contiguous groups stored on each rank - num_groups = partitions[0].shape[partition_dim] // stride - if num_groups <= 0: - # Fallback: stride >= shard size — simple concat should still be correct - return torch.cat(partitions, dim=partition_dim) - - rank_groups = [p.chunk(num_groups, dim=partition_dim) for p in partitions] - interleaved: list[torch.Tensor] = [] - for g in range(num_groups): - for r in range(len(partitions)): - interleaved.append(rank_groups[r][g]) - return torch.cat(interleaved, dim=partition_dim) + world_size = len(partitions) + # Each rank holds exactly `stride` sub-chunks concatenated along partition_dim. + rank_chunks = [p.chunk(stride, dim=partition_dim) for p in partitions] + # Reconstruct original ordering: chunk index (i * world_size + r) + ordered: list[torch.Tensor] = [] + for i in range(stride): + for r in range(world_size): + ordered.append(rank_chunks[r][i]) + return torch.cat(ordered, dim=partition_dim) def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: From 8498b7b497026bb78426a8daa951f882105e600a Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 1 Apr 2026 15:24:12 +0800 Subject: [PATCH 11/13] Fix Signed-off-by: knlnguyen1802 --- slime/backends/megatron_utils/update_weight/common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/slime/backends/megatron_utils/update_weight/common.py b/slime/backends/megatron_utils/update_weight/common.py index 35faf97c1..529bd793d 100644 --- a/slime/backends/megatron_utils/update_weight/common.py +++ b/slime/backends/megatron_utils/update_weight/common.py @@ -72,6 +72,9 @@ def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor: if "linear_fc1.weight" in name: 
param_partitions = [p.chunk(2, dim=0) for p in param_partitions] param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] + # GLU rechunking already reverses the stride=2 interleaving, so use + # plain concatenation from here on. + stride = 1 # this is bug in megatron's grouped moe. if "linear_fc2.weight" in name: if partition_dim == 0: @@ -134,6 +137,8 @@ def all_gather_params_async( if "linear_fc1.weight" in info.name: param_partitions = [p.chunk(2, dim=0) for p in param_partitions] param_partitions = [p[0] for p in param_partitions] + [p[1] for p in param_partitions] + # GLU rechunking already reverses the stride=2 interleaving. + partition_stride = 1 # this is bug in megatron's grouped moe. if "linear_fc2.weight" in info.name: if partition_dim == 0: From 8a41184c35c03c51720b8b67cd2dbbbfdc7e3e9b Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Thu, 2 Apr 2026 13:26:38 +0000 Subject: [PATCH 12/13] Fix config Signed-off-by: knlnguyen1802 --- slime/backends/megatron_utils/arguments.py | 6 ++++++ slime/utils/arguments.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/slime/backends/megatron_utils/arguments.py b/slime/backends/megatron_utils/arguments.py index 3e6e7a0d1..eb4fad6e6 100644 --- a/slime/backends/megatron_utils/arguments.py +++ b/slime/backends/megatron_utils/arguments.py @@ -50,6 +50,12 @@ def equal(x, y): ("rope_theta", "rotary_base", equal), ]: if hasattr(hf_config, hf_config_name): + if not hasattr(args, megatron_config_name): + logger.warning( + f"Megatron args missing '{megatron_config_name}' (mapped from HF '{hf_config_name}') , " + f"Skip validate" + ) + continue if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): errors.append( f"{hf_config_name} in hf config {getattr(hf_config, hf_config_name)} is not equal to " diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 2699072f6..7d61297d8 100644 --- a/slime/utils/arguments.py +++ 
b/slime/utils/arguments.py @@ -1507,7 +1507,7 @@ def parse_args(add_custom_arguments=None): args = megatron_parse_args( extra_args_provider=add_slime_arguments, - skip_hf_validate=True, #pre.debug_rollout_only, + skip_hf_validate=pre.debug_rollout_only, ) else: logger.warning( From acc969034bd5fe2ebcf3213b89a63ae33a2295c2 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Thu, 2 Apr 2026 13:45:58 +0800 Subject: [PATCH 13/13] Change name config Signed-off-by: knlnguyen1802 --- docs/en/examples/deepseek-r1.md | 4 ++-- docs/en/examples/qwen3-4B.md | 4 ++-- docs/zh/examples/deepseek-r1.md | 4 ++-- docs/zh/examples/qwen3-4B.md | 4 ++-- examples/fully_async/fully_async_rollout.py | 2 +- examples/tau-bench/run_qwen3_4B.sh | 2 +- run-qwen2.5-0.5B-vllm.sh | 2 +- scripts/low_precision/run-kimi-k2-Thinking-int4.sh | 2 +- scripts/run-deepseek-r1.sh | 2 +- scripts/run-kimi-k2-Instruct.sh | 2 +- scripts/run-kimi-k2-Thinking.sh | 2 +- slime/rollout/sglang_rollout.py | 2 +- slime/router/router.py | 2 +- slime/utils/arguments.py | 2 +- slime/utils/http_utils.py | 2 +- 15 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index 48cca18d3..22e3fd814 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -171,7 +171,7 @@ OPTIMIZER_ARGS=( These are the parameters required by sglang. Here, `--rollout-num-gpus-per-engine` basically corresponds to sglang's `tp_size`. Other sglang parameters are passed to slime by adding a `--sglang-` prefix. To fully leverage sglang's large EP inference capabilities, we have added configurations like ep64, dp\_attention dp8, and deepep mode auto. -The final `--server-concurrency` is a parameter specific to slime. It is used to prevent the sglang server's concurrent requests from becoming too large and crashing the HTTP server. The default is 512. 
However, since we now have one server for 8 nodes, we have adjusted it to 1024 to ensure that each dp rank can have a concurrency of 128. +The final `--rollout-server-concurrency` is a parameter specific to slime. It is used to prevent the sglang server's concurrent requests from becoming too large and crashing the HTTP server. The default is 512. However, since we now have one server for 8 nodes, we have adjusted it to 1024 to ensure that each dp rank can have a concurrency of 128. ```bash SGLANG_ARGS=( @@ -190,7 +190,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank have 128 concurrency - --server-concurrency 1024 + --rollout-server-concurrency 1024 ) ``` diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 8a4022ab9..688b30b03 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -280,10 +280,10 @@ In this case, 2 GPUs will be allocated for training, and 6 GPUs will be allocate ⚠️ If the concurrency on each sglang server is too high, it may exceed sglang's default CUDA graph concurrency limit (the default maximum is 160), which will affect inference speed. You can adjust this in the following two ways: -1. Use `--server-concurrency` to limit the maximum number of concurrent requests sent to a single sglang server. For example: +1. Use `--rollout-server-concurrency` to limit the maximum number of concurrent requests sent to a single sglang server. For example: ```bash - --server-concurrency 160 + --rollout-server-concurrency 160 ``` 2. Use `--sglang-cuda-graph-bs` (which corresponds to sglang's native `--cuda-graph-bs` argument) to increase the number of CUDA graphs initialized by sglang. 
For example: diff --git a/docs/zh/examples/deepseek-r1.md b/docs/zh/examples/deepseek-r1.md index fba2e736a..703a7730e 100644 --- a/docs/zh/examples/deepseek-r1.md +++ b/docs/zh/examples/deepseek-r1.md @@ -171,7 +171,7 @@ OPTIMIZER_ARGS=( sglang 所需的参数,这里 `--rollout-num-gpus-per-engine` 基本对应 sglang 的 `tp_size`,除此之外的 sglang 参数均通过添加 `--sglang-` 的前缀来传给 slime。为了充分利用 sglang 的大 EP 推理能力,我们加上了 ep64、dp_attention dp8、deepep mode auto 等配置。 -最后的 `--server-concurrency` 是 slime 的特有参数,是为了方式同时发给 sglang server 的并发太大打爆 http server,默认为 512。但是我们现在是 8 机一个 server,为了保证每个 dp rank 能有 128 的并发,我们调整为 1024。 +最后的 `--rollout-server-concurrency` 是 slime 的特有参数,是为了防止同时发给 sglang server 的并发太大打爆 http server,默认为 512。但是我们现在是 8 机一个 server,为了保证每个 dp rank 能有 128 的并发,我们调整为 1024。 ```bash SGLANG_ARGS=( @@ -190,7 +190,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --server-concurrency 1024 + --rollout-server-concurrency 1024 ) ``` diff --git a/docs/zh/examples/qwen3-4B.md b/docs/zh/examples/qwen3-4B.md index ac0fad5b6..1f3adc9c5 100644 --- a/docs/zh/examples/qwen3-4B.md +++ b/docs/zh/examples/qwen3-4B.md @@ -280,10 +280,10 @@ ray job submit ... \ ⚠️ 在进行训推分离的时候,每个 sglang server 上的并发度太大,超过了 sglang 默认的 cuda graph 的并发度(默认最大 160),影响推理速度。可以用以下 2 种方式进行调整: -1. 通过 `--server-concurrency` 限制发给一个 sglang server 的最大并发量,例如: +1. 通过 `--rollout-server-concurrency` 限制发给一个 sglang server 的最大并发量,例如: ```bash - --server-concurrency 160 + --rollout-server-concurrency 160 ``` 2. 
使用 `--sglang-cuda-graph-bs`,即 sglang 原生的 `--cuda-graph-bs`, 增大 sglang 初始化的 cuda graph 数量,例如: diff --git a/examples/fully_async/fully_async_rollout.py b/examples/fully_async/fully_async_rollout.py index 25b832709..1df005ed0 100644 --- a/examples/fully_async/fully_async_rollout.py +++ b/examples/fully_async/fully_async_rollout.py @@ -20,7 +20,7 @@ def get_global_worker(args, data_buffer): with _worker_lock: if _global_worker is None or not _global_worker.worker_thread.is_alive(): print("Creating new global async worker...") - _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.server_concurrency) + _global_worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.rollout_server_concurrency) _global_worker.start() return _global_worker diff --git a/examples/tau-bench/run_qwen3_4B.sh b/examples/tau-bench/run_qwen3_4B.sh index 81ad6f63f..fa36af552 100644 --- a/examples/tau-bench/run_qwen3_4B.sh +++ b/examples/tau-bench/run_qwen3_4B.sh @@ -100,7 +100,7 @@ SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 --sglang-mem-fraction-static 0.7 # If gemini API reports concurrency limit error, set this parameter to reduce the concurrency - # --server-concurrency 32 + # --rollout-server-concurrency 32 ) MISC_ARGS=( diff --git a/run-qwen2.5-0.5B-vllm.sh b/run-qwen2.5-0.5B-vllm.sh index 3fb6c896a..b352fa37d 100644 --- a/run-qwen2.5-0.5B-vllm.sh +++ b/run-qwen2.5-0.5B-vllm.sh @@ -94,7 +94,7 @@ WANDB_ARGS=( VLLM_ARGS=( --rollout-backend vllm --rollout-num-gpus-per-engine 1 - --server-concurrency 512 + --rollout-server-concurrency 512 --use-slime-router --slime-router-middleware-paths slime.router.middleware_hub.radix_tree_middleware.RadixTreeMiddleware ) diff --git a/scripts/low_precision/run-kimi-k2-Thinking-int4.sh b/scripts/low_precision/run-kimi-k2-Thinking-int4.sh index d09a8f3bf..16bc65eae 100644 --- a/scripts/low_precision/run-kimi-k2-Thinking-int4.sh +++ b/scripts/low_precision/run-kimi-k2-Thinking-int4.sh @@ -134,7 +134,7 @@ SGLANG_ARGS=( 
#--sglang-deepep-mode auto # make every dp rank has 128 concurrency - --server-concurrency 1024 + --rollout-server-concurrency 1024 --use-slime-router ) diff --git a/scripts/run-deepseek-r1.sh b/scripts/run-deepseek-r1.sh index 9b8bc64b0..dc6d6cb74 100644 --- a/scripts/run-deepseek-r1.sh +++ b/scripts/run-deepseek-r1.sh @@ -126,7 +126,7 @@ SGLANG_ARGS=( --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --server-concurrency 1024 + --rollout-server-concurrency 1024 ) MISC_ARGS=( diff --git a/scripts/run-kimi-k2-Instruct.sh b/scripts/run-kimi-k2-Instruct.sh index 506b576d2..0c9fc2c3d 100644 --- a/scripts/run-kimi-k2-Instruct.sh +++ b/scripts/run-kimi-k2-Instruct.sh @@ -132,7 +132,7 @@ SGLANG_ARGS=( # --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --server-concurrency 1024 + --rollout-server-concurrency 1024 ) diff --git a/scripts/run-kimi-k2-Thinking.sh b/scripts/run-kimi-k2-Thinking.sh index 088317761..533019d69 100644 --- a/scripts/run-kimi-k2-Thinking.sh +++ b/scripts/run-kimi-k2-Thinking.sh @@ -134,7 +134,7 @@ SGLANG_ARGS=( # --sglang-deepep-mode auto # make every dp rank has 128 concurrency - --server-concurrency 1024 + --rollout-server-concurrency 1024 ) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 416d38407..f3aec2d55 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -76,7 +76,7 @@ def __init__(self, args: Namespace) -> None: self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) self.semaphore = asyncio.Semaphore( - args.server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + args.rollout_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) self.sampling_params: dict[str, Any] = dict( temperature=args.rollout_temperature, diff --git a/slime/router/router.py b/slime/router/router.py index 9f660e24b..b77dd975f 100644 --- a/slime/router/router.py +++ b/slime/router/router.py 
@@ -45,7 +45,7 @@ def __init__(self, args, verbose=False): max_connections = getattr(args, "slime_router_max_connections", None) if max_connections is None: max_connections = ( - args.server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + args.rollout_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) timeout = getattr(args, "slime_router_timeout", None) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 7d61297d8..b7abc68b9 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -268,7 +268,7 @@ def add_rollout_arguments(parser): "Lower this if the training model leaves insufficient free memory.", ) parser.add_argument( - "--server-concurrency", + "--rollout-server-concurrency", type=int, default=512, help="Maximum number of concurrent requests sent to the rollout server. " diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index d8b92dc69..e6ebe6bf8 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -204,7 +204,7 @@ def init_http_client(args): if not args.rollout_num_gpus: return - _client_concurrency = args.server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + _client_concurrency = args.rollout_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency),