diff --git a/astraflow/raas/engine/remote_inf_engine.py b/astraflow/raas/engine/remote_inf_engine.py index 0342e43..8388aca 100644 --- a/astraflow/raas/engine/remote_inf_engine.py +++ b/astraflow/raas/engine/remote_inf_engine.py @@ -306,6 +306,14 @@ def __init__( self.lock = Lock() self.lora_initialized = False + # Versioned LoRA adapter naming: each weight sync loads under a NEW + # name (``lora_v{seq}``) and we never unload. Unloading an adapter that + # still has paused/aborted in-flight requests deadlocks on SGLang's + # ``wait_for_unload`` (aborted requests never release their usage + # counter). New unique names avoid the unload entirely; SGLang's + # mem-pool LRU evicts stale adapters from GPU automatically. + self._lora_seq = 0 + self._current_lora_name: str | None = None self._executor: ProcessPoolExecutor | None = None self._paused: bool = False @@ -654,7 +662,7 @@ async def agenerate(self, req: ModelRequest) -> ModelResponse: f"agenerate() building HTTP request, rid={req.rid}, " f"iteration={iteration}, server_addr={server_addr}" ) - http_req = self.backend.build_generation_request(req, self.lora_initialized) + http_req = self.backend.build_generation_request(req, self._current_lora_name) # Loop until the generation is complete logger.debug( @@ -745,19 +753,33 @@ def load_weights_from_path( For full weights: ``/update_weights_from_disk`` includes ``abort_all_requests: True`` and ``flush_cache`` internally. - For LoRA adapters (``use_lora=True``): unloads the old adapter, - loads the new one, then flushes the KV cache via ``/flush_cache`` - to discard stale entries computed with the old LoRA weights. - Relies on sglang releasing the ``lora_registry`` counter for - aborted requests (fixed upstream in - ``TokenizerManager._handle_abort_finish_reason`` as of 0.5.12). + For LoRA adapters (``use_lora=True``): loads the new adapter under a + fresh versioned name (``lora_v{seq}``) without explicitly unloading the + previous one, then flushes the KV cache. SGLang's registry LRU evicts + old versions once ``max_loaded_loras`` is reached and its mem-pool LRU + reclaims GPU slots (bounded by ``max_loras_per_batch``); an evicted + adapter is transparently re-loaded on next use. + + Historically, explicitly unloading an adapter that still had + paused/aborted in-flight requests deadlocked SGLang's ``wait_for_unload`` + because the adapter's usage counter was never released on abort. That + leak is now fixed at the source by ``LoRACounterLeakPatch`` + (``astraflow/raas/patch/sglang.py``), so unload/eviction is safe. We keep + the fresh-name scheme because it stays correct without draining under + ``lora_update_lock`` on every sync. """ import time as _time _t0 = _time.monotonic() - lora_name = "lora_1" if use_lora: + # Load under a NEW versioned name and do NOT explicitly unload the + # old one. The abort-time usage-counter leak that used to make + # ``wait_for_unload`` (and thus registry-LRU eviction) hang is fixed + # by LoRACounterLeakPatch, so eviction is safe; the fresh name also + # avoids draining under ``lora_update_lock`` on every sync. + self._lora_seq += 1 + lora_name = f"lora_v{self._lora_seq}" logger.info( "load_weights_from_path: sending /load_lora_adapter " "to %d servers (path=%s, lora_name=%s) ...", @@ -766,19 +788,13 @@ def load_weights_from_path( lora_name, ) try: - if self.lora_initialized: - unload_req = HttpRequest( - endpoint="/unload_lora_adapter", - payload={"lora_name": lora_name}, - ) - self._run_request_on_all_servers(unload_req) - load_req = HttpRequest( endpoint="/load_lora_adapter", payload={"lora_name": lora_name, "lora_path": str(path)}, ) self._run_request_on_all_servers(load_req) self.lora_initialized = True + self._current_lora_name = lora_name # Flush stale KV cache entries computed with old LoRA weights. # Safe because caller already paused generation (is_pause=True diff --git a/astraflow/raas/engine/sglang_remote.py b/astraflow/raas/engine/sglang_remote.py index b49f63e..964ad9f 100644 --- a/astraflow/raas/engine/sglang_remote.py +++ b/astraflow/raas/engine/sglang_remote.py @@ -22,9 +22,13 @@ class SGLangBackend: """Backend that translates engine operations into SGLang HTTP API calls.""" def build_generation_request( - self, req: ModelRequest, with_lora: bool + self, req: ModelRequest, lora_name: str | None ) -> HttpRequest: - """Convert a ModelRequest into an SGLang /generate HTTP request.""" + """Convert a ModelRequest into an SGLang /generate HTTP request. + + ``lora_name`` is the currently-active versioned adapter name (e.g. + ``lora_v3``) or ``None`` when no adapter is loaded. + """ gconfig = req.gconfig stop_token_ids = gconfig.stop_token_ids stop = gconfig.stop @@ -55,8 +59,8 @@ def build_generation_request( "stream": False, } - if with_lora: - payload["lora_path"] = "lora_1" + if lora_name: + payload["lora_path"] = lora_name return HttpRequest(endpoint="/generate", payload=payload) diff --git a/astraflow/raas/engine/vllm_remote.py b/astraflow/raas/engine/vllm_remote.py index 68391dc..1c999f1 100644 --- a/astraflow/raas/engine/vllm_remote.py +++ b/astraflow/raas/engine/vllm_remote.py @@ -31,9 +31,14 @@ def __init__(self): pass def build_generation_request( - self, req: ModelRequest, with_lora: bool + self, req: ModelRequest, lora_name: str | None ) -> HttpRequest: - """Convert a ModelRequest into a vLLM completions or chat HTTP request.""" + """Convert a ModelRequest into a vLLM completions or chat HTTP request. + + ``lora_name`` is a truthy marker that a LoRA is active; vLLM selects + the adapter via ``gconfig.lora_name`` (its own naming), so the marker's + value is unused here. + """ gconfig = req.gconfig stop_token_ids = gconfig.stop_token_ids stop = gconfig.stop @@ -54,7 +59,7 @@ def build_generation_request( if stop: payload["stop"] = stop - if with_lora and len(gconfig.lora_name) > 0: + if lora_name and len(gconfig.lora_name) > 0: payload["model"] = gconfig.lora_name if req.vision_msg_vllm: @@ -181,6 +186,10 @@ def __init__(self, config: InferenceEngineConfig): self.config = config self._engine = RemoteInfEngine(config, VLLMBackend()) self._engine.lora_initialized = config.use_lora + # vLLM selects the adapter via gconfig.lora_name; this just marks LoRA + # active so the shared generation-request builder passes a truthy flag. + if config.use_lora: + self._engine._current_lora_name = "vllm_lora" def __getattr__(self, name: str): return getattr(self._engine, name) diff --git a/astraflow/raas/patch/__init__.py b/astraflow/raas/patch/__init__.py index e291e04..2b702cb 100644 --- a/astraflow/raas/patch/__init__.py +++ b/astraflow/raas/patch/__init__.py @@ -83,12 +83,14 @@ def _validate_patch_results(results: Dict[str, bool], strict: bool) -> None: def _run_sglang_patches(strict: bool) -> bool: from astraflow.raas.patch.sglang import ( HttpServerPatch, + LoRACounterLeakPatch, ServerArgsPatch, ) manager = PatchManager() manager.register(ServerArgsPatch()) manager.register(HttpServerPatch()) + manager.register(LoRACounterLeakPatch()) results = manager.apply_all() _log_patch_results(results) diff --git a/astraflow/raas/patch/sglang.py b/astraflow/raas/patch/sglang.py index 03481ec..e20086f 100644 --- a/astraflow/raas/patch/sglang.py +++ b/astraflow/raas/patch/sglang.py @@ -7,6 +7,9 @@ can register with RaaS at startup. 2. HttpServerPatch — register SGLang instance with the rollout manager during ``launch_server``. +3. LoRACounterLeakPatch — guarantee the LoRA adapter usage counter is + released for every request, including aborted / client-disconnected ones, + fixing a weight-sync deadlock at its source (see the class docstring). """ import logging @@ -16,6 +19,37 @@ logger = logging.getLogger(__name__) +async def release_lora_ref_once(tm, sub_obj) -> None: + """Release ``sub_obj``'s LoRA usage counter on ``tm`` (TokenizerManager) + exactly once, if it is still held. + + Idempotency invariant: SGLang's two native release sites both + ``del rid_to_state[rid]`` immediately before releasing, so ``rid in + rid_to_state`` iff the request has NOT yet been released. The membership + check and ``pop`` have no ``await`` between them, so they are atomic on the + single-threaded event loop — guaranteeing release is awaited at most once + per request. This matters because ``ConcurrentCounter.decrement`` has no + floor: a double-release would drive the counter to -1 and make + ``wait_for_zero`` (hence ``wait_for_unload``) hang forever. + """ + if not getattr(tm.server_args, "enable_lora", False): + return + if not getattr(sub_obj, "lora_path", None): + return + rid = getattr(sub_obj, "rid", None) + if rid is None or rid not in tm.rid_to_state: + return + tm.rid_to_state.pop(rid, None) + lora_id = getattr(sub_obj, "lora_id", None) + if lora_id is not None: + try: + await tm.lora_registry.release(lora_id) + except Exception: + logger.exception( + "release_lora_ref_once: release failed for rid=%s", rid + ) + + class ServerArgsPatch(BasePatch): """Add ``--rollout-manager-address`` to SGLang's ServerArgs.""" @@ -94,3 +128,82 @@ def patched_launch_server(server_args, *args, **kwargs): traceback.print_exc() return False + + +class LoRACounterLeakPatch(BasePatch): + """Release the LoRA adapter usage counter on EVERY request teardown. + + Root cause of the LoRA weight-sync deadlock: SGLang's ``LoRARegistry`` keeps + a per-adapter ``ConcurrentCounter`` (``lora/lora_registry.py``). It is + ``acquire()``-ed for every generate request but ``release()``-ed only on two + conditional branches in the tokenizer manager — normal completion + (``_handle_batch_output``) and one scheduler-abort case (``_wait_one_response``, + status SERVICE_UNAVAILABLE / INTERNAL_SERVER_ERROR). Requests that are aborted + or whose client disconnects (which the RaaS per-step drain routinely creates) + exit ``_wait_one_response`` without releasing — via a ``raise`` (client + disconnect, BAD_REQUEST) or a plain ``break`` (waiting-queue abort). The + adapter's counter then never returns to zero, so ``LoRARegistry.wait_for_unload`` + blocks forever. That hangs both an explicit ``/unload_lora_adapter`` AND the + ``load_lora_adapter`` LRU eviction that fires once ``max_loaded_loras`` versioned + adapters accumulate — while holding ``lora_update_lock``, freezing all further + LoRA ops. (The RaaS versioned-name scheme merely defers this to ~``max_loaded_loras`` + steps; this patch removes the leak so unload/eviction is always safe.) + + Fix: wrap ``TokenizerManager.generate_request`` — the single outermost + per-request async generator, where ``acquire`` happens (via + ``_validate_and_resolve_lora``) — and release in a ``finally`` so it runs on + every exit (normal return, raise, ``GeneratorExit``, ``CancelledError``). + Release is idempotent via the invariant that both native release sites + ``del rid_to_state[rid]`` immediately before releasing: ``rid in rid_to_state`` + iff not yet released. The membership check and ``pop`` have no ``await`` + between them, so they are atomic on the single-threaded event loop — no + double-release (which would drive the counter to -1 and hang ``wait_for_zero`` + permanently, since ``ConcurrentCounter.decrement`` has no floor). + """ + + def apply(self) -> bool: + import os + + if os.getenv("ASTRAFLOW_DISABLE_LORA_LEAK_FIX", "0").lower() in ("1", "true"): + logger.warning( + "LoRACounterLeakPatch disabled via ASTRAFLOW_DISABLE_LORA_LEAK_FIX; " + "LoRA weight-sync may deadlock on registry-LRU eviction." + ) + return True + + try: + from sglang.srt.managers.tokenizer_manager import TokenizerManager + except Exception as e: + logger.error(f"LoRACounterLeakPatch failed: {e}") + return False + + original_generate_request = TokenizerManager.generate_request + if self._is_patched(original_generate_request, "generate_request"): + return True + + async def patched_generate_request(self, obj, request=None): + try: + async for response in original_generate_request(self, obj, request): + yield response + finally: + # Guaranteed release on every exit path (normal, raise, + # GeneratorExit, CancelledError). ``obj`` has been normalized by + # ``original_generate_request`` before it reached the scheduler. + try: + if getattr(obj, "is_single", True): + await release_lora_ref_once(self, obj) + else: + # Batch request: release each sub-request that still + # holds its counter. (RaaS rollouts are single; + # best-effort.) + rids = getattr(obj, "rid", None) + if isinstance(rids, (list, tuple)): + for i in range(len(rids)): + await release_lora_ref_once(self, obj[i]) + except Exception: + logger.exception("LoRACounterLeakPatch cleanup error") + + self._mark_as_patched(patched_generate_request, "generate_request") + TokenizerManager.generate_request = patched_generate_request + + return True diff --git a/astraflow/raas/patch/tests/__init__.py b/astraflow/raas/patch/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/astraflow/raas/patch/tests/test_lora_counter_leak.py b/astraflow/raas/patch/tests/test_lora_counter_leak.py new file mode 100644 index 0000000..be27246 --- /dev/null +++ b/astraflow/raas/patch/tests/test_lora_counter_leak.py @@ -0,0 +1,120 @@ +"""Regression tests for the LoRA usage-counter release helper. + +These guard the deadlock fix in ``LoRACounterLeakPatch``: the per-adapter +``ConcurrentCounter`` in SGLang's ``LoRARegistry`` must be released exactly once +per request. A double-release is *fatal* — ``ConcurrentCounter.decrement`` has no +floor, so -1 makes ``wait_for_zero`` (hence ``wait_for_unload`` and the +``max_loaded_loras`` LRU eviction) hang forever. + +The helper is dependency-free (no sglang import), so these tests exercise the +idempotency logic with fakes. +""" + +import asyncio +from types import SimpleNamespace + +from astraflow.raas.patch.sglang import release_lora_ref_once + + +class _RecordingRegistry: + def __init__(self): + self.released = [] + + async def release(self, lora_id): + self.released.append(lora_id) + + +def _make_tm(rid_to_state, enable_lora=True): + reg = _RecordingRegistry() + tm = SimpleNamespace( + server_args=SimpleNamespace(enable_lora=enable_lora), + rid_to_state=rid_to_state, + lora_registry=reg, + ) + return tm, reg + + +def test_release_once_then_idempotent(): + """First call releases; a second call (rid already popped) is a no-op.""" + + async def run(): + rid = "rid-abc" + tm, reg = _make_tm({rid: object()}) + sub = SimpleNamespace(lora_path="/adapter", rid=rid, lora_id="lid-1") + + await release_lora_ref_once(tm, sub) # releases exactly once + await release_lora_ref_once(tm, sub) # no-op: rid no longer tracked + + assert reg.released == ["lid-1"], reg.released + assert rid not in tm.rid_to_state + + asyncio.run(run()) + + +def test_no_release_when_request_already_completed(): + """If normal/scheduler-abort already released (rid absent), never release.""" + + async def run(): + tm, reg = _make_tm({}) # rid_to_state already cleaned by the native path + sub = SimpleNamespace(lora_path="/adapter", rid="rid-xyz", lora_id="lid-1") + + await release_lora_ref_once(tm, sub) + + assert reg.released == [] + + asyncio.run(run()) + + +def test_no_release_when_lora_disabled_or_no_adapter(): + """No release for non-LoRA requests (enable_lora False or lora_path unset).""" + + async def run(): + tm, reg = _make_tm({"r": object()}, enable_lora=False) + await release_lora_ref_once( + tm, SimpleNamespace(lora_path="/adapter", rid="r", lora_id="lid") + ) + assert reg.released == [] + + tm2, reg2 = _make_tm({"r": object()}, enable_lora=True) + await release_lora_ref_once( + tm2, SimpleNamespace(lora_path=None, rid="r", lora_id="lid") + ) + assert reg2.released == [] + + asyncio.run(run()) + + +def test_concurrent_teardown_releases_once(): + """Two coroutines tearing down the same request race to release only once.""" + + async def run(): + rid = "rid-race" + tm, reg = _make_tm({rid: object()}) + sub = SimpleNamespace(lora_path="/adapter", rid=rid, lora_id="lid-1") + + await asyncio.gather( + release_lora_ref_once(tm, sub), + release_lora_ref_once(tm, sub), + ) + + assert reg.released == ["lid-1"], reg.released + + asyncio.run(run()) + + +if __name__ == "__main__": + # Standalone runner (avoids third-party pytest plugins that may be missing + # in the inference image, e.g. hypothesis -> pkg_resources). + import sys + + failures = 0 + for _name, _fn in sorted(globals().items()): + if _name.startswith("test_") and callable(_fn): + try: + _fn() + print(f"PASS {_name}") + except Exception as exc: # noqa: BLE001 + failures += 1 + print(f"FAIL {_name}: {type(exc).__name__}: {exc}") + print(f"{'ALL PASSED' if failures == 0 else f'{failures} FAILED'}") + sys.exit(1 if failures else 0) diff --git a/astraflow/raas/server/manager.py b/astraflow/raas/server/manager.py index 6dcf672..79a38f1 100644 --- a/astraflow/raas/server/manager.py +++ b/astraflow/raas/server/manager.py @@ -1815,11 +1815,15 @@ async def _do_weight_update( model_id, exc_info=True, ) - # Sync LoRA state to eval engines + # Sync LoRA state to eval engines. They share the same SGLang server, + # so they must use the same versioned adapter name in generation. if use_lora: + main_inner = getattr(engine, "_engine", engine) + cur_name = getattr(main_inner, "_current_lora_name", None) for eval_eng in self._eval_engines.values(): inner = getattr(eval_eng, "_engine", eval_eng) inner.lora_initialized = True + inner._current_lora_name = cur_name _timing = ( f"notify_version: loaded {model_id} v={version} " diff --git a/docs/en/recipes/math.md b/docs/en/recipes/math.md index 9d1a5d9..20dcc5b 100644 --- a/docs/en/recipes/math.md +++ b/docs/en/recipes/math.md @@ -43,6 +43,14 @@ bash examples/math/qwen3-1.7b-m2po-2gpus-full/scripts/run_qwen3-1.7b-m2po-2gpus- | Train dataset | DeepScaleR | | Eval datasets | AIME24, AIME25, AMC, Minerva Math, MATH500 | +### LoRA variant + +[`qwen3-1.7b-m2po-2gpus-lora/`](https://github.com/Infini-AI-Lab/astraflow/tree/main/examples/math/qwen3-1.7b-m2po-2gpus-lora) trains a LoRA adapter on the actor instead of full fine-tuning, keeping the same 2-GPU layout. Each step the trainer syncs the adapter to the SGLang server under a fresh versioned name (`lora_v{n}`) and never unloads it — SGLang's memory-pool LRU reclaims old versions — which avoids the unload deadlock that occurs when an adapter still holds aborted in-flight requests. One important caveat: a LoRA update is effectively much larger than a full-fine-tuning step at the same learning rate (the `alpha/rank` scaling), so LoRA needs near-on-policy training to stay stable. The recipe therefore sets `ppo_n_minibatches: 1`, `max_staleness: 1`, and `recompute_logprob: true` (with `lr` 5e-6); with these it shows a clean rising eval curve. On each weight sync the server first pauses generation and drains its in-flight requests (aborting any still running), then loads the new adapter under the fresh versioned name and flushes the stale KV cache before resuming — the old adapter is deliberately never unloaded, because unloading one that still holds aborted requests would block SGLang's `wait_for_unload` forever. Run it with: + +```bash +bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/run_qwen3-1.7b-m2po-2gpus-lora.sh +``` + ## Qwen3-8B — 8 GPUs The full-scale recipe. It needs an 8-GPU node — 4 GPUs for inference, 4 for training — and also comes in full and delta transfer variants: diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/1_astraflow.sh b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/1_astraflow.sh new file mode 100755 index 0000000..fba722c --- /dev/null +++ b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/2_raas.sh b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/2_raas.sh new file mode 100755 index 0000000..a7a3f05 --- /dev/null +++ b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0}" +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/3_trainer_model0.sh b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..4481241 --- /dev/null +++ b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-1}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/run_qwen3-1.7b-m2po-2gpus-lora.sh b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/run_qwen3-1.7b-m2po-2gpus-lora.sh new file mode 100755 index 0000000..a795180 --- /dev/null +++ b/examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/run_qwen3-1.7b-m2po-2gpus-lora.sh @@ -0,0 +1,104 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3-1.7B, M2PO, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Usage: +# bash examples/math/qwen3-1.7b-m2po-2gpus-lora/scripts/run_qwen3-1.7b-m2po-2gpus-lora.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 1 GPU for inference, 1 GPU for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-1}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. +astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow v2 (Qwen3-1.7B, math, M2PO, ctx7k, TCP full, LoRA, 2gpus) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (FSDP dp${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." +CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-lora/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-lora/yaml/experiment.yaml new file mode 100644 index 0000000..f63b03e --- /dev/null +++ b/examples/math/qwen3-1.7b-m2po-2gpus-lora/yaml/experiment.yaml @@ -0,0 +1,179 @@ +# ============================================================================ +# Experiment config -- AstraFlow service + Trainer +# Experiment: math / qwen3-1.7b-m2po-2gpus-lora +# +# Qwen3-1.7B math RL with M2PO, ctx 7k, lr 5e-6, full TCP weight transfer, +# LoRA (rank 32) on the actor. Mirrors the -full recipe; only the actor LoRA +# block and the SGLang LoRA flags (raas.yaml) differ. +# +# GPU layout (default, 2 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0 -> RaaS (model0 dp=1) +# TRAINER_MODEL0_GPUS=1 -> Trainer model0 (FSDP, 1 GPU) +# ============================================================================ + +# -- Experiment: identity, model, shared settings -- +experiment: + experiment_name: astraflow-math + trial_name: qwen3-1.7b-m2po-2gpus-lora + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3-1.7B" + tokenizer_path: "Qwen/Qwen3-1.7B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + # LoRA must use full transfer (delta is not used with LoRA). + weight_transfer_strategies: full + +# -- RaaS: what to generate (inference-level config) -- +# model keys here also determine expected_model_ids for AstraFlow service +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 4000 + min_new_tokens: 0 + +# -- AstraFlow: data pipeline -- +# auto-derives: expected_model_ids from raas.models keys +# auto-derives: dump_dir from experiment.fileroot +dataflow: + host: "0.0.0.0" + port: 8000 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + # LoRA needs near-on-policy rollouts: its alpha/rank scaling makes each + # weight update ~8x larger than full-FT at the same lr, so the staleness + # full-FT tolerates pushes LoRA past the stability edge (entropy collapse, + # destructive adapter). Keep staleness tight; pair with ppo_n_minibatches=1. + max_staleness: 1 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + amc: + dataset_fn: "astraflow.dataflow.dataset.amc24:get_amc_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + minerva: + dataset_fn: "astraflow.dataflow.dataset.minervamath:get_minerva_math_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + +# -- Trainer base: shared config -- +# auto-derives from experiment: experiment_name, trial_name, fileroot, +# tokenizer_path, seed, dtype, weight_transfer_mode +# auto-derives from raas.models.: actor.path, actor.max_new_tokens, +# ref.path +# auto-derives: saver, recover, stats_logger fields from experiment section +# auto-derives: cluster.name_resolve from experiment.fileroot +# auto-derives: trial_name suffix from model_id (e.g. trial_name-model0) +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + n_samples: 8 + engine: + backend: fsdp + data_parallel_size: 1 + + actor: + gradient_checkpointing: true + # Recompute old_logp in the trainer so the PPO ratio is computed against + # the exact policy being optimized (with ppo_n_minibatches=1 this pins the + # importance weight at 1.0 — fully on-policy, required for stable LoRA). + recompute_logprob: true + # -- LoRA -- + # Explicit target_modules (not "all-linear") so the SGLang launch + # `lora_target_modules` in raas.yaml matches deterministically. + use_lora: true + lora_rank: 32 + lora_alpha: 16 + target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj] + mb_spec: + max_tokens_per_mb: 17408 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + # PPO / M2PO algorithm + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + # On-policy single update: avoids within-batch policy drift, which LoRA's + # amplified effective step does not tolerate (see max_staleness note above). + ppo_n_minibatches: 1 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + + ref: + mb_spec: + max_tokens_per_mb: 17408 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +# -- Trainer for model0 -- only overrides -- +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3-1.7b", "tcp", "ctx7k", "lora", "2gpus", "examples", "sglang-d1", "fsdp-d1"] diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-lora/yaml/raas.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-lora/yaml/raas.yaml new file mode 100644 index 0000000..50917a6 --- /dev/null +++ b/examples/math/qwen3-1.7b-m2po-2gpus-lora/yaml/raas.yaml @@ -0,0 +1,56 @@ +# ============================================================================ +# RaaS config -- Inference serving instance (hardware/resources) +# Experiment: math / qwen3-1.7b-m2po-2gpus-lora +# +# Hardware: 1x GPU, TP=1 +# model0: DP=1, TP=1 +# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# experiment.yaml provides: model_path, tokenizer_path, seed, dtype, models/gconfig +# +# LoRA notes: +# - enable_lora + max_lora_rank (>= actor.lora_rank) + lora_target_modules +# must be set at launch so SGLang pre-sizes its adapter memory pool. +# - lora_target_modules uses HF leaf names; SGLang normalizes q/k/v -> qkv_proj +# and gate/up -> gate_up_proj internally, matching the receiver's adapter. +# - lora_backend=triton is required (the LoRA kernels assume the triton path). +# - Versioned pause-and-swap: each step loads a fresh adapter name; the registry +# (max_loaded_loras) and GPU pool (max_loras_per_batch) LRU-evict old versions. +# ============================================================================ + +rollout: + max_concurrent_rollouts: 256 + pause_grace_period: 3 + # Adaptive availability -- drive /availability off sglang /get_load. + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 1 + # Mirror the actor's LoRA setting (drives the inference engine's LoRA mode). + use_lora: true + +sglang: + context_length: 7168 + mem_fraction_static: 0.8 + max_running_requests: 48 + skip_tokenizer_init: true + attention_backend: triton + # -- LoRA -- + enable_lora: true + max_lora_rank: 32 + lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj] + # Versioned adapters accumulate in the registry; we don't explicitly unload. + # max_loaded_loras caps the registry (LRU-evicts old versions once full); + # max_loras_per_batch caps GPU-resident adapters (mem-pool LRU evicts the rest) + # and lets a couple versions coexist at weight-sync boundaries. The registry-LRU + # eviction internally calls SGLang's wait_for_unload, which is now safe: + # LoRACounterLeakPatch (astraflow/raas/patch/sglang.py) fixes the abort-time + # usage-counter leak that used to hang it once this cap was reached. + max_loaded_loras: 256 + max_loras_per_batch: 4 + lora_backend: triton