diff --git a/astraflow/train_worker/api/cli_args.py b/astraflow/train_worker/api/cli_args.py index 096324e..41b1e6f 100644 --- a/astraflow/train_worker/api/cli_args.py +++ b/astraflow/train_worker/api/cli_args.py @@ -473,10 +473,22 @@ class TrainEngineConfig: trial_name: str = "" path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"}) attn_impl: str = field( - default="flash_attention_2", + default="kernels-community/flash-attn2", metadata={ - "help": "Attention implementation for huggingface transformers model.", - "choices": ["flash_attention_2"], + "help": ( + "Attention implementation for huggingface transformers model. " + "Default pulls a prebuilt FlashAttention-2 kernel from the HF kernels " + "hub (ABI-matched to torch, incl. varlen for packed sequences). The " + "literal 'flash_attention_2' loads the local flash-attn wheel, which is " + "ABI-broken on torch>=2.11+cu13; 'sdpa' works but relies on position_ids " + "resets for packed block-diagonal masking." + ), + "choices": [ + "kernels-community/flash-attn2", + "flash_attention_2", + "sdpa", + "eager", + ], }, ) init_from_scratch: bool = field( diff --git a/astraflow/train_worker/engine/fsdp_engine.py b/astraflow/train_worker/engine/fsdp_engine.py index 35c60a5..1385277 100644 --- a/astraflow/train_worker/engine/fsdp_engine.py +++ b/astraflow/train_worker/engine/fsdp_engine.py @@ -94,8 +94,6 @@ from astraflow.train_worker.utils.model import ( disable_dropout_in_model, is_gemma3_model, - is_qwen3_moe_model, - is_qwen3_vl_model, is_qwen_vl_model, is_valid_vision_model, ) @@ -1206,16 +1204,15 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: ] mb["use_cache"] = False padded_mb["use_cache"] = False - if is_qwen3_moe_model(self.model_config.model_type) or is_qwen3_vl_model( - self.model_config.model_type - ): - mb["attention_mask"] = None - padded_mb["attention_mask"] = None - else: - mb["attention_mask"] = dict(full_attention=None, sliding_attention=None) - padded_mb["attention_mask"] = dict( - full_attention=None, sliding_attention=None - ) + # Always pass attention_mask=None for the packed/varlen forward: per-sequence + # causal masking is driven by cu_seqlens + position_ids, and the model builds + # the right mask from None. The old dict(full_attention=None, + # sliding_attention=None) form is a transformers-4.x relic: on transformers>=5 + # a dense model (qwen3 / qwen2) treats that dict as a *precomputed* mask, skips + # creation, and crashes. Passing None lets the model build its mask from + # cu_seqlens + position_ids instead. + mb["attention_mask"] = None + padded_mb["attention_mask"] = None if "multi_modal_input" in mb: image_grid_thw_list = [ item["image_grid_thw"] diff --git a/astraflow/train_worker/utils/model.py b/astraflow/train_worker/utils/model.py index d7d3d29..491f3d1 100644 --- a/astraflow/train_worker/utils/model.py +++ b/astraflow/train_worker/utils/model.py @@ -5,6 +5,9 @@ "qwen2_vl", "qwen2_5_vl", "qwen3_vl", + # Qwen3.5 dense math checkpoints ship as Qwen3_5ForConditionalGeneration, so they + # load via the ImageTextToText path even though these recipes train text-only. + "qwen3_5", "gemma3", ] # Registry of vision models verified to work with this framework. diff --git a/examples/math/README.md b/examples/math/README.md index 0b2735e..cc4a81f 100644 --- a/examples/math/README.md +++ b/examples/math/README.md @@ -15,3 +15,21 @@ Complete guidance: [`docs/en/recipes/math.md`](../../docs/en/recipes/math.md). Most math recipes default to one 8xH100 node. The `qwen3-1.7b-m2po-2gpus-*` recipes are smaller 2xH100 variants. + +--- +**Attention kernel** + +The dense Qwen3 recipes (`qwen3-1.7b-m2po-2gpus-*`, `qwen3-8b-m2po-*`) set +`attn_impl: kernels-community/flash-attn2` — a prebuilt, ABI-matched +FlashAttention-2 kernel pulled from the Hugging Face `kernels` hub (fetched and +cached on first use; no source build). This is the working FA2 on the validated +stack (`torch 2.11+cu130`): the literal `attn_impl: flash_attention_2` would +instead load the local `flash-attn` wheel and crash with an `undefined symbol` +ABI error (`is_flash_attn_2_available()` is metadata-only, so it never catches +the broken import). It is also the same kernel as `cli_args.py`'s default, so +recipes that omit `attn_impl` get it too. + +`sdpa` and `eager` remain available; `sdpa` works but relies on per-sequence +`position_ids` resets for packed block-diagonal masking, whereas FA2 varlen +derives the block-diagonal mask from `cu_seqlens` directly. The Qwen3.5 recipes +use `sdpa` (hybrid Gated-DeltaNet + attention model). diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml index e758135..f1a0a27 100644 --- a/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml +++ b/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml @@ -110,6 +110,7 @@ trainer_base: data_parallel_size: 1 actor: + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -135,6 +136,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml index 766c242..1a8cd02 100644 --- a/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml +++ b/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml @@ -109,6 +109,7 @@ trainer_base: data_parallel_size: 1 actor: + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -134,6 +135,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml b/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml index 1614b93..1629c36 100644 --- a/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml +++ b/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml @@ -110,6 +110,7 @@ trainer_base: data_parallel_size: 4 actor: + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -135,6 +136,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml b/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml index f00f8d2..8fa6c90 100644 --- a/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml +++ b/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml @@ -109,6 +109,7 @@ trainer_base: data_parallel_size: 4 actor: + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -134,6 +135,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3.5-4b-m2po-delta/README.md b/examples/math/qwen3.5-4b-m2po-delta/README.md new file mode 100644 index 0000000..971a1eb --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/README.md @@ -0,0 +1,17 @@ +# Qwen3.5-4B — Math RL (M2PO), delta weight transfer + +Same recipe as [`qwen3.5-4b-m2po-full`](../qwen3.5-4b-m2po-full/README.md), but +the trainer pushes **only changed weights** to the inference engine each sync +(`weight_transfer_strategies: delta`) instead of the full model. + +See the [full recipe's README](../qwen3.5-4b-m2po-full/README.md) for the +validated environment (transformers 5.8.1 / kernels 0.14.1 / SGLang +`0.5.13.post1` with `qwen3_5`, `attention_backend: flashinfer` / `fla` 0.5.0 / +torch 2.11.0+cu130), +GPU layout, install note, and validation results. + +## Run + +```bash +bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh +``` diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh new file mode 100755 index 0000000..851e565 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh new file mode 100755 index 0000000..b9cfe79 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..11ca6fe --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh new file mode 100755 index 0000000..e6a5ed3 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh @@ -0,0 +1,107 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3.5-4B, M2PO, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Requires: transformers>=5.8 (+ flash-linear-attention for training), +# SGLang main (qwen3_5 model). See yaml/raas.yaml for the backend note. +# +# Usage: +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 4 GPUs for inference, 4 for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. +astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow (Qwen3.5-4B, math, M2PO, ctx8k, TCP delta) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (FSDP dp${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." +CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml b/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml new file mode 100644 index 0000000..60cbbf3 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml @@ -0,0 +1,170 @@ +# ============================================================================ +# Experiment config — AstraFlow service + Trainer +# Experiment: math / qwen3.5-4b-m2po-delta +# +# Qwen3.5-4B math RL with M2PO, ctx 8k, lr 5e-6, delta TCP weight transfer. +# +# NOTE: Qwen3.5-4B is a HYBRID (Gated-DeltaNet + attention) multimodal model +# (architecture Qwen3_5ForConditionalGeneration, model_type qwen3_5), trained +# here TEXT-ONLY for math RL. Requires transformers>=5.8 (+ `fla` kernels for +# training) and SGLang main (qwen3_5 + TritonGDNKernel) for inference. +# attn_impl=sdpa (prebuilt flash-attn is not ABI-compatible with this torch). +# +# GPU layout (default, 8 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS (model0 dp=4) +# TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) +# ============================================================================ + +# ── Experiment: identity, model, shared settings ── +experiment: + experiment_name: astraflow-math + trial_name: qwen3.5-4b-m2po-delta + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3.5-4B" + tokenizer_path: "Qwen/Qwen3.5-4B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + weight_transfer_strategies: delta + +# ── RaaS: what to generate (inference-level config) ── +# model keys here also determine expected_model_ids for AstraFlow service +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 4000 + min_new_tokens: 0 + +# ── AstraFlow: data pipeline ── +# auto-derives: expected_model_ids from raas.models keys +# auto-derives: dump_dir from experiment.fileroot +dataflow: + host: "0.0.0.0" + port: 8000 + delta_full_sync_interval: 10 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + max_staleness: 8 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + amc: + dataset_fn: "astraflow.dataflow.dataset.amc24:get_amc_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + minerva: + dataset_fn: "astraflow.dataflow.dataset.minervamath:get_minerva_math_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + +# ── Trainer base: shared config ── +# auto-derives from experiment: experiment_name, trial_name, fileroot, +# tokenizer_path, seed, dtype, weight_transfer_mode +# auto-derives from raas.models.: actor.path, actor.max_new_tokens, +# ref.path +# auto-derives: saver, recover, stats_logger fields from experiment section +# auto-derives: cluster.name_resolve from experiment.fileroot +# auto-derives: trial_name suffix from model_id (e.g. trial_name-model0) +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + n_samples: 8 + engine: + backend: fsdp + data_parallel_size: 4 + + actor: + # sdpa: prebuilt flash-attn isn't ABI-compatible with this torch/cu130 build; + # Qwen3.5's GDN linear-attn uses fla kernels and full-attn blocks use sdpa. + attn_impl: sdpa + gradient_checkpointing: true + mb_spec: + max_tokens_per_mb: 8192 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + # PPO / M2PO algorithm + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + ppo_n_minibatches: 4 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + + ref: + attn_impl: sdpa + mb_spec: + max_tokens_per_mb: 8192 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +# ── Trainer for model0 — only overrides ── +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3.5-4b", "tcp", "ctx8k", "hybrid-gdn", "delta"] diff --git a/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml b/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml new file mode 100644 index 0000000..d8bb9fe --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml @@ -0,0 +1,52 @@ +# ============================================================================ +# RaaS config — Inference serving instance (hardware/resources) +# Experiment: math / qwen3.5-4b-m2po-delta +# +# Hardware: 4x GPU, TP=1 +# model0: DP=4, TP=1 +# +# Qwen3.5-4B is a hybrid Gated-DeltaNet model: SGLang allocates a Mamba/SSM +# state cache (~10 GB) in addition to the KV cache, so keep mem_fraction_static +# conservative. Served by SGLang main via its TritonGDNKernel backend. +# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# ============================================================================ + +rollout: + max_concurrent_rollouts: 512 + # Cap concurrent eval prefills to bound peak KV pressure during the + # ~3.5k-item eval burst (5 datasets x repeat=4) — default 128 OOMs sglang. + max_concurrent_evals: 64 + pause_grace_period: 3 + # Adaptive availability — drive /availability off sglang /get_load. + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 4 + +sglang: + context_length: 8192 + mem_fraction_static: 0.7 + # Attention backend. NOTE: this is specific to Qwen3.5's HYBRID + # Gated-DeltaNet architecture, NOT a general L40 limitation: + # - Plain dense Qwen3 (full-attention) runs fine with the default fa3 on + # L40/Ada. But for Qwen3.5's GDN path, fa3 dispatches a Hopper-only + # kernel (hopper/flash_fwd_launch_template.h) that fails on Ada (sm_89) + # with "CUDA error: invalid argument" under real load. + # - On non-Hopper archs SGLang auto-selects flashinfer (full-attn) + + # triton (the GDN/linear-attn + mamba layers). Both flashinfer and + # triton are verified working for Qwen3.5-4B here; flashinfer is set + # explicitly (SGLang's literal auto-default on Ada/L40). triton is an + # equally-valid alternative (set attention_backend: triton). + attention_backend: flashinfer + # Cap concurrency: with n_samples=8 x max_new_tokens=4000 and an + # unbounded queue, the hybrid GDN's KV + Mamba state cache overflows + # -> 'KV cache pool full, retract' -> CUDA OOM on a 44GB L40. 32 keeps + # peak KV bounded while still saturating the engine. + max_running_requests: 32 + skip_tokenizer_init: true diff --git a/examples/math/qwen3.5-4b-m2po-full/README.md b/examples/math/qwen3.5-4b-m2po-full/README.md new file mode 100644 index 0000000..3d0dd25 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/README.md @@ -0,0 +1,79 @@ +# Qwen3.5-4B — Math RL (M2PO) + +Text-only math RL on **Qwen/Qwen3.5-4B** with **M2PO**, context 8k, lr 5e-6, +DeepScaleR data, `math_verify` reward. + +Qwen3.5-4B is a **hybrid Gated-DeltaNet + attention multimodal** checkpoint +(architecture `Qwen3_5ForConditionalGeneration`, `model_type: qwen3_5`); these +recipes train it **text-only**. The checkpoint ships as an image-text-to-text +model, so AstraFlow loads it via the `AutoModelForImageTextToText` path (the +`model_type` is registered in `VALID_VISION_MODELS`), and the trainer uses +`attn_impl: sdpa` because a prebuilt flash-attn is not ABI-compatible with this +torch build. + +Two variants: + +| recipe | weight transfer | +|---|---| +| `qwen3.5-4b-m2po-full` | full (push the whole model each sync) | +| `qwen3.5-4b-m2po-delta` | delta (push only changed weights) | + +## Validated environment + +These recipes were validated end-to-end on the following stack (8× L40 / +Ada). The model and GDN kernels come from pip dependencies — there is no +hand-patched framework source: + +| package | version | +|---|---| +| `torch` | `2.11.0+cu130` | +| `transformers` | `5.8.1` | +| `kernels` | `0.14.1` | +| `sglang` | `0.5.13.post1` (published release with `qwen3_5` support), served with `attention_backend: flashinfer` | +| `flash-linear-attention` (`fla`) | `0.5.0` | +| `flashinfer-python` | `0.6.12` (pulled by sglang) | +| attention impl | `sdpa` (set in `experiment.yaml`) | + +> **Install note.** `pyproject.toml` pins the full validated stack: +> `transformers==5.8.1` (with `kernels>=0.14,<0.15`), `torch==2.11.0`, and +> `sglang==0.5.13.post1` — the published release that ships `qwen3_5` support (the +> older `0.5.12.post1` predated it). It pulls `flashinfer 0.6.12` in automatically, +> so `uv pip install -e ".[sglang]"` resolves the validated environment directly. + +## GPU layout (default, 8 GPUs) + +```text +SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS / SGLang inference (model0, dp=4) +TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) +``` + +Override those env vars to use different GPUs. + +## Run + +```bash +bash examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh +# delta variant: +bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh +``` + +The launcher starts three processes (AstraFlow HTTP service, RaaS/SGLang +inference, FSDP trainer). See `scripts/` for the per-process scripts and +`yaml/` for the experiment / RaaS configs. + +## Validation + +Trained end-to-end on the stack above; eval rises steadily over training. +Qwen3.5-4B-full, overall metrics across the eval suite: + +| metric | step 0 | step 80 | Δ | +|---|---|---|---| +| overall avg@k | 47.8% | 57.4% | +9.6 | +| overall pass@k | 56.5% | 67.4% | +10.9 | + +The table above was produced on the predecessor SGLang git build. Both variants +(`full` and `delta`) were subsequently re-validated end-to-end on the pinned +`sglang==0.5.13.post1` release: training completes with no crashes, full +(`shard_copy`) and delta (~7× compressed) weight transfer both function, and +eval holds at the baseline (overall avg@k ≈ 49–51% over a short run) — i.e. the +published-release pin introduces no regression versus the git build. diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh new file mode 100755 index 0000000..4fe0968 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh new file mode 100755 index 0000000..be71a71 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..900e1ab --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh new file mode 100755 index 0000000..08d6681 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh @@ -0,0 +1,107 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3.5-4B, M2PO, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Requires: transformers>=5.8 (+ flash-linear-attention for training), +# SGLang main (qwen3_5 model). See yaml/raas.yaml for the backend note. +# +# Usage: +# bash examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 4 GPUs for inference, 4 for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. +astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow (Qwen3.5-4B, math, M2PO, ctx8k, TCP full) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (FSDP dp${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." +CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml b/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml new file mode 100644 index 0000000..739613a --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml @@ -0,0 +1,170 @@ +# ============================================================================ +# Experiment config — AstraFlow service + Trainer +# Experiment: math / qwen3.5-4b-m2po-full +# +# Qwen3.5-4B math RL with M2PO, ctx 8k, lr 5e-6, full TCP weight transfer. +# +# NOTE: Qwen3.5-4B is a HYBRID (Gated-DeltaNet + attention) multimodal model +# (architecture Qwen3_5ForConditionalGeneration, model_type qwen3_5), trained +# here TEXT-ONLY for math RL. Requires transformers>=5.8 (+ `fla` kernels for +# training) and SGLang main (qwen3_5 + TritonGDNKernel) for inference. +# attn_impl=sdpa (prebuilt flash-attn is not ABI-compatible with this torch). +# +# GPU layout (default, 8 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS (model0 dp=4) +# TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) +# ============================================================================ + +# ── Experiment: identity, model, shared settings ── +experiment: + experiment_name: astraflow-math + trial_name: qwen3.5-4b-m2po-full + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3.5-4B" + tokenizer_path: "Qwen/Qwen3.5-4B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + weight_transfer_strategies: full + +# ── RaaS: what to generate (inference-level config) ── +# model keys here also determine expected_model_ids for AstraFlow service +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 4000 + min_new_tokens: 0 + +# ── AstraFlow: data pipeline ── +# auto-derives: expected_model_ids from raas.models keys +# auto-derives: dump_dir from experiment.fileroot +dataflow: + host: "0.0.0.0" + port: 8000 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + max_staleness: 8 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + amc: + dataset_fn: "astraflow.dataflow.dataset.amc24:get_amc_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + minerva: + dataset_fn: "astraflow.dataflow.dataset.minervamath:get_minerva_math_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + +# ── Trainer base: shared config ── +# auto-derives from experiment: experiment_name, trial_name, fileroot, +# tokenizer_path, seed, dtype, weight_transfer_mode +# auto-derives from raas.models.: actor.path, actor.max_new_tokens, +# ref.path +# auto-derives: saver, recover, stats_logger fields from experiment section +# auto-derives: cluster.name_resolve from experiment.fileroot +# auto-derives: trial_name suffix from model_id (e.g. trial_name-model0) +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + n_samples: 8 + engine: + backend: fsdp + data_parallel_size: 4 + + actor: + # Prebuilt flash-attn isn't ABI-compatible with this torch/cu130 build, so + # use sdpa: Qwen3.5's GDN linear-attn uses fla kernels and full-attn blocks + # use sdpa. + attn_impl: sdpa + gradient_checkpointing: true + mb_spec: + max_tokens_per_mb: 8192 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + # PPO / M2PO algorithm + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + ppo_n_minibatches: 4 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + + ref: + attn_impl: sdpa + mb_spec: + max_tokens_per_mb: 8192 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +# ── Trainer for model0 — only overrides ── +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3.5-4b", "tcp", "ctx8k", "hybrid-gdn"] diff --git a/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml b/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml new file mode 100644 index 0000000..274a1bf --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml @@ -0,0 +1,51 @@ +# ============================================================================ +# RaaS config — Inference serving instance (hardware/resources) +# Experiment: math / qwen3.5-4b-m2po-full +# +# Hardware: 4x GPU, TP=1 +# model0: DP=4, TP=1 +# +# Qwen3.5-4B is a hybrid Gated-DeltaNet model: SGLang allocates a Mamba/SSM +# state cache (~10 GB) in addition to the KV cache, so keep mem_fraction_static +# conservative. Served by SGLang main via its TritonGDNKernel backend. +# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# experiment.yaml provides: model_path, tokenizer_path, seed, dtype, models/gconfig +# ============================================================================ + +rollout: + max_concurrent_rollouts: 512 + max_concurrent_evals: 64 + pause_grace_period: 3 + # Adaptive availability — drive /availability off sglang /get_load. + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 4 + +sglang: + context_length: 8192 + mem_fraction_static: 0.7 + # Attention backend. NOTE: this is specific to Qwen3.5's HYBRID + # Gated-DeltaNet architecture, NOT a general L40 limitation: + # - Plain dense Qwen3 (full-attention) runs fine with the default fa3 on + # L40/Ada. But for Qwen3.5's GDN path, fa3 dispatches a Hopper-only + # kernel (hopper/flash_fwd_launch_template.h) that fails on Ada (sm_89) + # with "CUDA error: invalid argument" under real load. + # - On non-Hopper archs SGLang auto-selects flashinfer (full-attn) + + # triton (the GDN/linear-attn + mamba layers). Both flashinfer and + # triton are verified working for Qwen3.5-4B here; flashinfer is set + # explicitly (SGLang's literal auto-default on Ada/L40). triton is an + # equally-valid alternative (set attention_backend: triton). + attention_backend: flashinfer + # Cap concurrency: with n_samples=8 x max_new_tokens=4000 and an + # unbounded queue, the hybrid GDN's KV + Mamba state cache overflows + # -> 'KV cache pool full, retract' -> CUDA OOM on a 44GB L40. 32 keeps + # peak KV bounded while still saturating the engine. + max_running_requests: 32 + skip_tokenizer_init: true diff --git a/pyproject.toml b/pyproject.toml index 1f6b4b4..35e599f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "torchdata", "huggingface_hub", "datasets>=3.0.0", - "transformers==5.6.1", + "transformers==5.8.1", "megatron-core==0.13.1", "mbridge==0.13.0", "torch_memory_saver==0.0.9.post1", @@ -86,7 +86,7 @@ dependencies = [ "numba", "blosc", "pybind11>=2.10.0", - "networkx==3.3", + "networkx==3.6.1", "aiofiles", "aiohttp>=3.11.10", "httpx>=0.28.1", @@ -140,7 +140,17 @@ te = [ ] sglang = [ - "sglang==0.5.12.post1", + # Pinned to the validated published release 0.5.13.post1: it ships qwen3_5 + # support (sglang/srt/models/qwen3_5.py) + the TritonGDNKernel and requires + # transformers==5.8.1. (A git build was pinned earlier while the then-latest + # 0.5.12.post1 still predated qwen3_5; 0.5.13.post1 supersedes it and is the + # release the Qwen3.5 + Qwen3-dense math recipes are validated against -- see + # examples/math/qwen3.5-4b-m2po-full/README.md.) + "sglang==0.5.13.post1", + # Fast Gated-DeltaNet kernels for Qwen3.5 GDN training. Optional (transformers + # falls back to a slower pure-torch path when absent), but the validated runs + # used it; pin to the validated version. + "flash-linear-attention==0.5.0", ] vllm = [ @@ -192,23 +202,21 @@ include = ["astraflow*"] exclude = ["tests*", "docs*", "examples*"] [tool.uv] -# sglang 0.5.12 depends on flash-attn-4>=4.0.0b9 (a pre-release wheel, pulled -# in automatically as a sglang dependency). Without this, `uv pip install -# -e ".[sglang]"` fails to resolve with "pre-releases weren't enabled". +# The pinned sglang release (0.5.13.post1) depends on flash-attn-4>=4.0.0b9 (a +# pre-release wheel -- 4.0.0b15 in the validated env -- pulled in automatically as +# a sglang dependency). Without this, `uv pip install -e ".[sglang]"` fails to +# resolve with "pre-releases weren't enabled". prerelease = "allow" exclude-dependencies = ["flash-attn"] override-dependencies=[ "outlines-core==0.1.26", - # sglang 0.5.12 pins transformers==5.6.0, which has a flash-attention bug - # (unconditional s_aux.to() crashes non-sink models like Qwen3). 5.6.1 is - # a patch release that fixes it; override sglang's exact pin to pick it up. - "transformers==5.6.1", - # sglang requires an unbounded "kernels", so uv resolves the latest (0.15+), - # but transformers 5.6.1 only supports kernels<0.13 (its hub_kernels module - # calls LayerRepository() without a revision/version, which 0.15 rejects -> - # `import sglang` crashes with "Either a revision or a version must be - # specified."). Pin to the range transformers 5.6.1 was built against. - "kernels>=0.12.0,<0.13", + # Pin transformers to 5.8.1, the version the pinned sglang release requires and + # the Qwen3.5 + Qwen3-dense math recipes are validated on + # (see examples/math/qwen3.5-4b-m2po-full/README.md). + "transformers==5.8.1", + # The pinned sglang release requires kernels<0.15; pin >=0.14 so uv selects the + # validated 0.14.x rather than an older kernels release. + "kernels>=0.14.0,<0.15", ] [tool.uv.extra-build-dependencies]