From 021a836f1acf95f72c36752e8eda79ad44a3a726 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 1 Jun 2026 22:57:09 -0400 Subject: [PATCH 01/11] Add Qwen3.5-4B math RL recipes (full + delta) + Qwen3.5 enablement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New recipes examples/math/qwen3.5-4b-m2po-{full,delta} for training Qwen3.5 (dense, hybrid Gated-DeltaNet text backbone; model_type qwen3_5) with M2PO on AstraFlow, mirroring the existing qwen3-8b-m2po recipe structure. Trained text-only for math (the checkpoint ships as Qwen3_5ForConditionalGeneration). Verified end-to-end on NVIDIA L40 (Ada, sm_89): a full run trained 86+ steps with no crash and steadily rising eval — overall avg@k 47.8% -> 57.4% (+9.6) and pass@k 56.5% -> 67.4% over the first 80 steps (AIME24/AIME25/AMC/Minerva/MATH500, eval every 10 steps). Minimal framework changes for Qwen3.5 / transformers>=5 compatibility: - model.py: register qwen3_5 + is_qwen3_5_model() - fsdp_engine.py: pass attention_mask=None for qwen3_5 (transformers>=5 create_causal_mask calls .ndim; the old dict form raised AttributeError) - fsdp/__init__.py: normalize _no_split_modules set->list (qwen3_5 exposes a set) - rlvr.py: unwrap BatchEncoding from apply_chat_template (transformers>=5) Recipes use the standard packed training forward. attention_backend=flashinfer (fa3 dispatches a Hopper-only kernel that fails on Ada/L40 for the GDN path; flashinfer + triton both verified); max_running_requests=32, mem_fraction_static=0.7 on inference and FSDP dp=4 + max_tokens_per_mb=8192 on the trainer to fit 44GB L40. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- astraflow/train_worker/engine/fsdp_engine.py | 7 +- astraflow/train_worker/utils/model.py | 5 + .../scripts/1_astraflow.sh | 36 ++++ .../qwen3.5-4b-m2po-delta/scripts/2_raas.sh | 44 +++++ .../scripts/3_trainer_model0.sh | 47 ++++++ .../scripts/run_qwen3.5-4b-m2po-delta.sh | 107 ++++++++++++ .../yaml/experiment.yaml | 154 ++++++++++++++++++ .../math/qwen3.5-4b-m2po-delta/yaml/raas.yaml | 49 ++++++ .../scripts/1_astraflow.sh | 36 ++++ .../qwen3.5-4b-m2po-full/scripts/2_raas.sh | 44 +++++ .../scripts/3_trainer_model0.sh | 47 ++++++ .../scripts/run_qwen3.5-4b-m2po-full.sh | 107 ++++++++++++ .../qwen3.5-4b-m2po-full/yaml/experiment.yaml | 153 +++++++++++++++++ .../math/qwen3.5-4b-m2po-full/yaml/raas.yaml | 49 ++++++ 14 files changed, 883 insertions(+), 2 deletions(-) create mode 100755 examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh create mode 100755 examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh create mode 100755 examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh create mode 100755 examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh create mode 100644 examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml create mode 100644 examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml create mode 100755 examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh create mode 100755 examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh create mode 100755 examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh create mode 100755 examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh create mode 100644 examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml create mode 100644 examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml diff --git a/astraflow/train_worker/engine/fsdp_engine.py b/astraflow/train_worker/engine/fsdp_engine.py index 35c60a5..437a131 100644 --- a/astraflow/train_worker/engine/fsdp_engine.py +++ b/astraflow/train_worker/engine/fsdp_engine.py @@ 
-94,6 +94,7 @@ from astraflow.train_worker.utils.model import ( disable_dropout_in_model, is_gemma3_model, + is_qwen3_5_model, is_qwen3_moe_model, is_qwen3_vl_model, is_qwen_vl_model, @@ -1206,8 +1207,10 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: ] mb["use_cache"] = False padded_mb["use_cache"] = False - if is_qwen3_moe_model(self.model_config.model_type) or is_qwen3_vl_model( - self.model_config.model_type + if ( + is_qwen3_moe_model(self.model_config.model_type) + or is_qwen3_vl_model(self.model_config.model_type) + or is_qwen3_5_model(self.model_config.model_type) ): mb["attention_mask"] = None padded_mb["attention_mask"] = None diff --git a/astraflow/train_worker/utils/model.py b/astraflow/train_worker/utils/model.py index d7d3d29..c5c51f7 100644 --- a/astraflow/train_worker/utils/model.py +++ b/astraflow/train_worker/utils/model.py @@ -5,6 +5,7 @@ "qwen2_vl", "qwen2_5_vl", "qwen3_vl", + "qwen3_5", "gemma3", ] # Registry of vision models verified to work with this framework. @@ -25,6 +26,10 @@ def is_qwen3_vl_model(model_type: str) -> bool: return model_type in ["qwen3_vl"] +def is_qwen3_5_model(model_type: str) -> bool: + return model_type in ["qwen3_5"] + + def is_qwen_vl_model(model_type: str) -> bool: return is_qwen2_vl_model(model_type) or is_qwen3_vl_model(model_type) diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh new file mode 100755 index 0000000..851e565 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." 
&& pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh new file mode 100755 index 0000000..b9cfe79 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. 
+astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..11ca6fe --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. 
+astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh b/examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh new file mode 100755 index 0000000..e6a5ed3 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh @@ -0,0 +1,107 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3.5-4B, M2PO, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. 
Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Requires: transformers>=5.8 (+ flash-linear-attention for training), +# SGLang main (qwen3_5 model). See yaml/raas.yaml for the backend note. +# +# Usage: +# bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 4 GPUs for inference, 4 for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. 
+astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow (Qwen3.5-4B, math, M2PO, ctx8k, TCP delta) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (FSDP dp${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." 
+CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml b/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml new file mode 100644 index 0000000..704382e --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml @@ -0,0 +1,154 @@ +# ============================================================================ +# Experiment config — AstraFlow service + Trainer +# Experiment: math / qwen3.5-4b-m2po-delta +# +# Qwen3.5-4B math RL with M2PO, ctx 8k, lr 5e-6, delta TCP weight transfer. +# +# NOTE: Qwen3.5-4B is a HYBRID (Gated-DeltaNet + attention) multimodal model +# (architecture Qwen3_5ForConditionalGeneration, model_type qwen3_5), trained +# here TEXT-ONLY for math RL. Requires transformers>=5.8 (+ `fla` kernels for +# training) and SGLang main (qwen3_5 + TritonGDNKernel) for inference. +# attn_impl=sdpa (prebuilt flash-attn is not ABI-compatible with this torch). 
+# +# GPU layout (default, 8 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS (model0 dp=4) +# TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) +# ============================================================================ + +experiment: + experiment_name: astraflow-math + trial_name: qwen3.5-4b-m2po-delta + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3.5-4B" + tokenizer_path: "Qwen/Qwen3.5-4B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + weight_transfer_strategies: delta + +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 4000 + min_new_tokens: 0 + +dataflow: + host: "0.0.0.0" + port: 8000 + delta_full_sync_interval: 10 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + max_staleness: 8 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + amc: + dataset_fn: "astraflow.dataflow.dataset.amc24:get_amc_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + minerva: + dataset_fn: "astraflow.dataflow.dataset.minervamath:get_minerva_math_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 
+ repeat: 4 + eval_workflow: math_eval + +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + n_samples: 8 + engine: + backend: fsdp + data_parallel_size: 4 + + actor: + # sdpa: prebuilt flash-attn isn't ABI-compatible with this torch/cu130 build; + # Qwen3.5's GDN linear-attn uses fla kernels and full-attn blocks use sdpa. + attn_impl: sdpa + gradient_checkpointing: true + mb_spec: + max_tokens_per_mb: 8192 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + ppo_n_minibatches: 4 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + + ref: + attn_impl: sdpa + mb_spec: + max_tokens_per_mb: 8192 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3.5-4b", "tcp", "ctx8k", "hybrid-gdn", "delta"] diff --git a/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml b/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml new file mode 100644 index 0000000..9e949b9 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml @@ -0,0 +1,49 @@ +# ============================================================================ +# RaaS config — Inference serving instance (hardware/resources) +# Experiment: math / qwen3.5-4b-m2po-delta +# +# Hardware: 4x GPU, TP=1 +# model0: DP=4, TP=1 +# +# Qwen3.5-4B is a hybrid Gated-DeltaNet model: SGLang allocates a Mamba/SSM +# state cache (~10 GB) in addition to the KV cache, so keep mem_fraction_static +# conservative. Served by SGLang main via its TritonGDNKernel backend. 
+# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# ============================================================================ + +rollout: + max_concurrent_rollouts: 512 + max_concurrent_evals: 64 + pause_grace_period: 3 + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 4 + +sglang: + context_length: 8192 + mem_fraction_static: 0.7 + # Attention backend. NOTE: this is specific to Qwen3.5's HYBRID + # Gated-DeltaNet architecture, NOT a general L40 limitation: + # - Plain dense Qwen3 (full-attention) runs fine with the default fa3 on + # L40/Ada. But for Qwen3.5's GDN path, fa3 dispatches a Hopper-only + # kernel (hopper/flash_fwd_launch_template.h) that fails on Ada (sm_89) + # with "CUDA error: invalid argument" under real load. + # - On non-Hopper archs SGLang auto-selects flashinfer (full-attn) + + # triton (the GDN/linear-attn + mamba layers). Both flashinfer and + # triton are verified working for Qwen3.5-4B here; flashinfer is set + # explicitly (SGLang's literal auto-default on Ada/L40). triton is an + # equally-valid alternative (set attention_backend: triton). + attention_backend: flashinfer + # Cap concurrency: with n_samples=8 x max_new_tokens=4000 and an + # unbounded queue, the hybrid GDN's KV + Mamba state cache overflows + # -> 'KV cache pool full, retract' -> CUDA OOM on a 44GB L40. 32 keeps + # peak KV bounded while still saturating the engine. 
+ max_running_requests: 32 + skip_tokenizer_init: true diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh new file mode 100755 index 0000000..4fe0968 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -euo pipefail +# [1/3] Launch AstraFlow HTTP service +# +# Usage (terminal 1): +# bash examples/math/qwen3.5-4b-m2po-full/scripts/1_astraflow.sh + +export CUDA_VISIBLE_DEVICES="" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. 
+astraflow_setup_env + +echo "=== AstraFlow HTTP Service ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "Port : ${ASTRAFLOW_PORT}" +echo "===============================" + +python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh new file mode 100755 index 0000000..be71a71 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -euo pipefail +# [2/3] Launch RaaS inference server (SGLang + TCP receiver) +# +# Usage (terminal 2, after AstraFlow is ready): +# bash examples/math/qwen3.5-4b-m2po-full/scripts/2_raas.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="${ASTRAFLOW_URL:-http://127.0.0.1:${ASTRAFLOW_PORT}}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. 
+astraflow_setup_env + +echo "=== RaaS Inference Server (SGLang + TCP receiver) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES}" +echo "Port : ${RAAS_PORT}" +echo "AstraFlow URL : ${ASTRAFLOW_URL}" +echo "=======================================================" + +python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh new file mode 100755 index 0000000..900e1ab --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -euo pipefail +# [3/3] Launch Trainer for model0 (TCP, sender_agent on local_rank 0) +# +# Usage (terminal 3, after AstraFlow and RaaS are ready): +# bash examples/math/qwen3.5-4b-m2po-full/scripts/3_trainer_model0.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. 
+astraflow_load_experiment_env + +export CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +TRAINER0_NPROC="$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F',' '{print NF}')" + +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +# sender_agent (in trainer) listens on this HTTP port +export WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. Defined in examples/_common/utils.sh. +astraflow_setup_env + +echo "=== Trainer model0 (TCP) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "GPUs : ${CUDA_VISIBLE_DEVICES} (FSDP dp${TRAINER0_NPROC})" +echo "AstraFlow : ${ASTRAFLOW_URL}" +echo "RaaS : ${ASTRAFLOW_RAAS_URL}" +echo "Sender HTTP : ${WEIGHT_TRANSFER_HTTP_PORT}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================" + +torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh b/examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh new file mode 100755 index 0000000..08d6681 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh @@ -0,0 +1,107 @@ +#!/bin/bash +set -euo pipefail +# All-in-one launcher for AstraFlow v2 math training (Qwen3.5-4B, M2PO, TCP). +# +# Launches 3 processes: +# 1. AstraFlow HTTP service (CPU-only) +# 2. RaaS inference server (SGLang, SERVICE_CUDA_VISIBLE_DEVICES) +# 3. 
Trainer model0 (math, TRAINER_MODEL0_GPUS) +# +# Requires: transformers>=5.8 (+ flash-linear-attention for training), +# SGLang main (qwen3_5 model). See yaml/raas.yaml for the backend note. +# +# Usage: +# bash examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh + +# ============================================================================= +# Part 1: Load env and settings +# ============================================================================= +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +cd "${REPO_ROOT}" +export PYTHONPATH="${REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +YAML_DIR="${SCRIPT_DIR}/yaml" +export EXPERIMENT_CONFIG="${EXPERIMENT_CONFIG:-${YAML_DIR}/experiment.yaml}" +export RAAS_CONFIG="${RAAS_CONFIG:-${YAML_DIR}/raas.yaml}" +source "${REPO_ROOT}/examples/_common/utils.sh" +# Export EXP_NAME and TRIAL_NAME from the experiment YAML. +# Defined in examples/_common/utils.sh. +astraflow_load_experiment_env + +# ============================================================================= +# Part 2: Set up env +# ============================================================================= +# GPU assignments (default: 4 GPUs for inference, 4 for training) +export SERVICE_CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export TRAINER_MODEL0_GPUS="${TRAINER_MODEL0_GPUS:-4,5,6,7}" +# Ports / URLs (each component gets its own port) +export RAAS_HOST="${RAAS_HOST:-0.0.0.0}" +export RAAS_PORT="${RAAS_PORT:-19190}" +export ASTRAFLOW_HOST="${ASTRAFLOW_HOST:-0.0.0.0}" +export ASTRAFLOW_PORT="${ASTRAFLOW_PORT:-8000}" +export ASTRAFLOW_URL="http://127.0.0.1:${ASTRAFLOW_PORT}" +export WEIGHT_TRANSFER_HTTP_PORT_MODEL0="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0:-19861}" + +TRAINER0_NPROC="$(echo "${TRAINER_MODEL0_GPUS}" | awk -F',' '{print NF}')" + +# NCCL / PYTORCH / WANDB tweaks + LOG_DIR. +# Defined in examples/_common/utils.sh. 
+astraflow_setup_env + +# ============================================================================= +# Part 3: Print info and clean up +# ============================================================================= +echo "=== AstraFlow (Qwen3.5-4B, math, M2PO, ctx8k, TCP full) ===" +echo "Experiment config : ${EXPERIMENT_CONFIG}" +echo "RaaS config : ${RAAS_CONFIG}" +echo "RaaS GPUs : ${SERVICE_CUDA_VISIBLE_DEVICES}" +echo "Trainer model0 GPUs : ${TRAINER_MODEL0_GPUS} (FSDP dp${TRAINER0_NPROC})" +echo "RaaS port : ${RAAS_PORT}" +echo "AstraFlow port : ${ASTRAFLOW_PORT}" +echo "Sender HTTP model0 : ${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" +echo "WANDB mode : ${WANDB_MODE:-online}" +echo "==========================================================" + +trap astraflow_cleanup_trap EXIT INT TERM + +# Kill leftover processes and shared memory from prior runs. +# Defined in examples/_common/utils.sh. +astraflow_kill_stale + +# ============================================================================= +# Part 4: Launch training +# ============================================================================= +echo "[1/3] Starting AstraFlow HTTP service..." +CUDA_VISIBLE_DEVICES="" \ + python3 -u -m astraflow \ + --config "${EXPERIMENT_CONFIG}" \ + --port "${ASTRAFLOW_PORT}" \ + --host "${ASTRAFLOW_HOST}" \ + 2>&1 | tee "${LOG_DIR}/astraflow.log" & +sleep 5 + +echo "[2/3] Starting RaaS inference server (SGLang + TCP receiver)..." +CUDA_VISIBLE_DEVICES="${SERVICE_CUDA_VISIBLE_DEVICES}" \ + python3 -u -m astraflow.raas.server \ + --host "${RAAS_HOST}" \ + --port "${RAAS_PORT}" \ + --config "${EXPERIMENT_CONFIG}" \ + --config "${RAAS_CONFIG}" \ + --engine-id "${ENGINE_ID:-default}" \ + --astraflow-url "${ASTRAFLOW_URL}" \ + 2>&1 | tee "${LOG_DIR}/raas.log" & +sleep 15 + +export ASTRAFLOW_RAAS_URL="http://127.0.0.1:${RAAS_PORT}" + +echo "[3/3] Starting trainer model0..." 
+CUDA_VISIBLE_DEVICES="${TRAINER_MODEL0_GPUS}" \ +WEIGHT_TRANSFER_HTTP_PORT="${WEIGHT_TRANSFER_HTTP_PORT_MODEL0}" \ + torchrun --nnodes 1 --nproc-per-node "${TRAINER0_NPROC}" \ + --master-addr "${MASTER_ADDR:-127.0.0.1}" --master-port "${MASTER_PORT_MODEL0:-29541}" \ + examples/launch_trainer.py \ + --config "${EXPERIMENT_CONFIG}" \ + --trainer trainer_model0 \ + "$@" \ + 2>&1 | tee "${LOG_DIR}/trainer_model0.log" diff --git a/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml b/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml new file mode 100644 index 0000000..e9ecbb9 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml @@ -0,0 +1,153 @@ +# ============================================================================ +# Experiment config — AstraFlow service + Trainer +# Experiment: math / qwen3.5-4b-m2po-full +# +# Qwen3.5-4B math RL with M2PO, ctx 8k, lr 5e-6, full TCP weight transfer. +# +# NOTE: Qwen3.5-4B is a HYBRID (Gated-DeltaNet + attention) multimodal model +# (architecture Qwen3_5ForConditionalGeneration, model_type qwen3_5), trained +# here TEXT-ONLY for math RL. Requires transformers>=5.8 (+ `fla` kernels for +# training) and SGLang main (qwen3_5 + TritonGDNKernel) for inference. +# attn_impl=sdpa (prebuilt flash-attn is not ABI-compatible with this torch). 
+# +# GPU layout (default, 8 GPUs): +# SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS (model0 dp=4) +# TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) +# ============================================================================ + +experiment: + experiment_name: astraflow-math + trial_name: qwen3.5-4b-m2po-full + fileroot: ./data-experiments/${experiment.experiment_name}/${experiment.trial_name} + + model_path: "Qwen/Qwen3.5-4B" + tokenizer_path: "Qwen/Qwen3.5-4B" + seed: 1 + dtype: bfloat16 + weight_transfer_mode: tcp + weight_transfer_strategies: full + +raas: + models: + model0: + backend: sglang + gconfig: + n_samples: 8 + temperature: 1.0 + max_new_tokens: 4000 + min_new_tokens: 0 + +dataflow: + host: "0.0.0.0" + port: 8000 + + buffer: + size: 10000 + replay_size: 10000 + replay_ratio: 0 + max_staleness: 8 + filter_function: filter_zero_adv + + rollout_dataset: + dataset_fn: "astraflow.dataflow.dataset.deepscaler:get_deepscaler_rl_dataset" + max_length: 2000 + + workflow_spec: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + + eval_workflows: + math_eval: + workflow_cls: "rlvr" + reward_fn: "math_verify" + enable_thinking: false + gconfig_overrides: + temperature: 0.6 + n_samples: 1 + + eval_datasets: + aime24: + dataset_fn: "astraflow.dataflow.dataset.aime24x4:get_aime_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + aime25: + dataset_fn: "astraflow.dataflow.dataset.aime25x4:get_aime_2025x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + amc: + dataset_fn: "astraflow.dataflow.dataset.amc24:get_amc_2024x4_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + minerva: + dataset_fn: "astraflow.dataflow.dataset.minervamath:get_minerva_math_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: math_eval + math500: + dataset_fn: "astraflow.dataflow.dataset.math500:get_math500_test_dataset" + max_length: 2000 + repeat: 4 + eval_workflow: 
math_eval + +trainer_base: + total_train_steps: 800 + train_batch_size: 256 + n_samples: 8 + engine: + backend: fsdp + data_parallel_size: 4 + + actor: + # sdpa: prebuilt flash-attn isn't ABI-compatible with this torch/cu130 build; + # Qwen3.5's GDN linear-attn uses fla kernels and full-attn blocks use sdpa. + attn_impl: sdpa + gradient_checkpointing: true + mb_spec: + max_tokens_per_mb: 8192 + optimizer: + type: adam + lr: 5e-6 + weight_decay: 0.01 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + m2_threshold: 0.01 + eps_clip: 100.0 + eps_clip_higher: 100.0 + reward_scaling: 1 + reward_bias: 0 + kl_ctl: 0.00 + kl_penalty_coef: 0.001 + ppo_n_minibatches: 4 + reward_norm: { mean_level: group, std_level: group } + adv_norm: { mean_level: batch, std_level: batch } + + ref: + attn_impl: sdpa + mb_spec: + max_tokens_per_mb: 8192 + + recover: + mode: auto + freq_steps: 25 + + evaluator: + eval_at_start: false + freq_steps: 25 + + stats_logger: + wandb: + mode: online + id_suffix: "uid" + +trainer_model0: + model_id: model0 + stats_logger: + wandb: + tags: ["m2po", "math", "astraflow-v2", "qwen3.5-4b", "tcp", "ctx8k", "hybrid-gdn"] diff --git a/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml b/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml new file mode 100644 index 0000000..cb5538f --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml @@ -0,0 +1,49 @@ +# ============================================================================ +# RaaS config — Inference serving instance (hardware/resources) +# Experiment: math / qwen3.5-4b-m2po-full +# +# Hardware: 4x GPU, TP=1 +# model0: DP=4, TP=1 +# +# Qwen3.5-4B is a hybrid Gated-DeltaNet model: SGLang allocates a Mamba/SSM +# state cache (~10 GB) in addition to the KV cache, so keep mem_fraction_static +# conservative. Served by SGLang main via its TritonGDNKernel backend. 
+# +# Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# ============================================================================ + +rollout: + max_concurrent_rollouts: 512 + max_concurrent_evals: 64 + pause_grace_period: 3 + enable_adaptive_availability: true + target_waiting_queue_per_dp: 4 + adaptive_step_size: 4 + load_cache_ttl_ms: 100 + +engine: + model0: + backend: sglang + data_parallel_size: 4 + +sglang: + context_length: 8192 + mem_fraction_static: 0.7 + # Attention backend. NOTE: this is specific to Qwen3.5's HYBRID + # Gated-DeltaNet architecture, NOT a general L40 limitation: + # - Plain dense Qwen3 (full-attention) runs fine with the default fa3 on + # L40/Ada. But for Qwen3.5's GDN path, fa3 dispatches a Hopper-only + # kernel (hopper/flash_fwd_launch_template.h) that fails on Ada (sm_89) + # with "CUDA error: invalid argument" under real load. + # - On non-Hopper archs SGLang auto-selects flashinfer (full-attn) + + # triton (the GDN/linear-attn + mamba layers). Both flashinfer and + # triton are verified working for Qwen3.5-4B here; flashinfer is set + # explicitly (SGLang's literal auto-default on Ada/L40). triton is an + # equally-valid alternative (set attention_backend: triton). + attention_backend: flashinfer + # Cap concurrency: with n_samples=8 x max_new_tokens=4000 and an + # unbounded queue, the hybrid GDN's KV + Mamba state cache overflows + # -> 'KV cache pool full, retract' -> CUDA OOM on a 44GB L40. 32 keeps + # peak KV bounded while still saturating the engine. 
+ max_running_requests: 32 + skip_tokenizer_init: true From f40a14d5c7db285deb709c011202c816f78b6cf2 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Wed, 17 Jun 2026 15:00:28 -0400 Subject: [PATCH 02/11] Make Qwen3 dense recipes run on transformers-5.8 / torch-2.11 stack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The version bump (transformers 5.8.1, torch 2.11+cu130, sglang dev) broke the existing plain-Qwen3 math recipes (qwen3-1.7b, qwen3-8b). Two fixes: 1. fsdp_engine.py: always pass attention_mask=None for the packed/varlen forward. The old dict(full_attention=None, sliding_attention=None) form is a transformers-4.x relic; on transformers>=5 a dense model (qwen3/qwen2) treats that dict as a *precomputed* mask, skips creation, and crashes ('dict' object has no attribute 'ndim'). None is correct for all archs (dense, moe, vl, qwen3.5/GDN) — masking is driven by cu_seqlens + position_ids. Subsumes the prior qwen3_5/moe/vl special-case (drop now-unused imports). 2. recipes + cli_args: set attn_impl: sdpa in actor+ref for qwen3-1.7b and qwen3-8b recipes (flash_attn is ABI-broken vs torch 2.11+cu130 -> import crash; default was flash_attention_2). Expand attn_impl choices to sdpa/eager. Verified end-to-end: qwen3-1.7b full + delta train cleanly to step 100 on the bumped stack; math500 avg@k rises 73.0->77.0 (full) / 72.2->78.8 (delta); importance_weight=1.0000 (packed forward correct); zero crashes. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- astraflow/train_worker/api/cli_args.py | 2 +- astraflow/train_worker/engine/fsdp_engine.py | 23 +++++++------------ .../yaml/experiment.yaml | 2 ++ .../yaml/experiment.yaml | 2 ++ .../qwen3-8b-m2po-delta/yaml/experiment.yaml | 2 ++ .../qwen3-8b-m2po-full/yaml/experiment.yaml | 2 ++ 6 files changed, 17 insertions(+), 16 deletions(-) diff --git a/astraflow/train_worker/api/cli_args.py b/astraflow/train_worker/api/cli_args.py index 096324e..af4e5a7 100644 --- a/astraflow/train_worker/api/cli_args.py +++ b/astraflow/train_worker/api/cli_args.py @@ -476,7 +476,7 @@ class TrainEngineConfig: default="flash_attention_2", metadata={ "help": "Attention implementation for huggingface transformers model.", - "choices": ["flash_attention_2"], + "choices": ["flash_attention_2", "sdpa", "eager"], }, ) init_from_scratch: bool = field( diff --git a/astraflow/train_worker/engine/fsdp_engine.py b/astraflow/train_worker/engine/fsdp_engine.py index 437a131..1a52c74 100644 --- a/astraflow/train_worker/engine/fsdp_engine.py +++ b/astraflow/train_worker/engine/fsdp_engine.py @@ -94,9 +94,6 @@ from astraflow.train_worker.utils.model import ( disable_dropout_in_model, is_gemma3_model, - is_qwen3_5_model, - is_qwen3_moe_model, - is_qwen3_vl_model, is_qwen_vl_model, is_valid_vision_model, ) @@ -1207,18 +1204,14 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: ] mb["use_cache"] = False padded_mb["use_cache"] = False - if ( - is_qwen3_moe_model(self.model_config.model_type) - or is_qwen3_vl_model(self.model_config.model_type) - or is_qwen3_5_model(self.model_config.model_type) - ): - mb["attention_mask"] = None - padded_mb["attention_mask"] = None - else: - mb["attention_mask"] = dict(full_attention=None, sliding_attention=None) - padded_mb["attention_mask"] = dict( - full_attention=None, sliding_attention=None - ) + # Always pass attention_mask=None for the packed/varlen forward: per-sequence + # causal masking is driven 
by cu_seqlens + position_ids, and the model builds + # the right mask from None. The old dict(full_attention=None, sliding_attention= + # None) form is a transformers-4.x relic: on transformers>=5 a dense model + # (qwen3 / qwen2) treats that dict as a *precomputed* mask, skips creation, and + # crashes. None is correct for all archs (dense, moe, vl, qwen3.5/GDN). + mb["attention_mask"] = None + padded_mb["attention_mask"] = None if "multi_modal_input" in mb: image_grid_thw_list = [ item["image_grid_thw"] diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml index e758135..659b199 100644 --- a/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml +++ b/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml @@ -110,6 +110,7 @@ trainer_base: data_parallel_size: 1 actor: + attn_impl: sdpa gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -135,6 +136,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: sdpa mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml index 766c242..9b00759 100644 --- a/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml +++ b/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml @@ -109,6 +109,7 @@ trainer_base: data_parallel_size: 1 actor: + attn_impl: sdpa gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -134,6 +135,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: sdpa mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml b/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml index 1614b93..e25f7d6 100644 --- a/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml +++ b/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml @@ -110,6 +110,7 @@ 
trainer_base: data_parallel_size: 4 actor: + attn_impl: sdpa gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -135,6 +136,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: sdpa mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml b/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml index f00f8d2..29984df 100644 --- a/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml +++ b/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml @@ -109,6 +109,7 @@ trainer_base: data_parallel_size: 4 actor: + attn_impl: sdpa gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -134,6 +135,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: + attn_impl: sdpa mb_spec: max_tokens_per_mb: 17408 From e11e1888c5186a46eb8af72ffd173d3bce151d00 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 22 Jun 2026 10:55:33 -0400 Subject: [PATCH 03/11] hygiene: drop dead is_qwen3_5_model; document qwen3_5 model routing - Remove the unused is_qwen3_5_model() helper (its only caller, the mask-branch special-case, went away when the packed-forward mask was made unconditional). - Note in model.py why qwen3_5 is in VALID_VISION_MODELS (the checkpoint ships as Qwen3_5ForConditionalGeneration -> loaded via the ImageTextToText path). - Simplify the packed-forward attention_mask comment in fsdp_engine.py. No behavior change. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- astraflow/train_worker/engine/fsdp_engine.py | 3 ++- astraflow/train_worker/utils/model.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/astraflow/train_worker/engine/fsdp_engine.py b/astraflow/train_worker/engine/fsdp_engine.py index 1a52c74..652fe06 100644 --- a/astraflow/train_worker/engine/fsdp_engine.py +++ b/astraflow/train_worker/engine/fsdp_engine.py @@ -1209,7 +1209,8 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: # the right mask from None. The old dict(full_attention=None, sliding_attention= # None) form is a transformers-4.x relic: on transformers>=5 a dense model # (qwen3 / qwen2) treats that dict as a *precomputed* mask, skips creation, and - # crashes. None is correct for all archs (dense, moe, vl, qwen3.5/GDN). + # crashes. Passing None lets the model build its mask from cu_seqlens + + # position_ids instead. mb["attention_mask"] = None padded_mb["attention_mask"] = None if "multi_modal_input" in mb: diff --git a/astraflow/train_worker/utils/model.py b/astraflow/train_worker/utils/model.py index c5c51f7..b29053e 100644 --- a/astraflow/train_worker/utils/model.py +++ b/astraflow/train_worker/utils/model.py @@ -5,6 +5,8 @@ "qwen2_vl", "qwen2_5_vl", "qwen3_vl", + # qwen3.5 dense math checkpoints ship as Qwen3_5ForConditionalGeneration, so they + # load via the ImageTextToText path even though these recipes train text-only. 
"qwen3_5", "gemma3", ] @@ -26,10 +28,6 @@ def is_qwen3_vl_model(model_type: str) -> bool: return model_type in ["qwen3_vl"] -def is_qwen3_5_model(model_type: str) -> bool: - return model_type in ["qwen3_5"] - - def is_qwen_vl_model(model_type: str) -> bool: return is_qwen2_vl_model(model_type) or is_qwen3_vl_model(model_type) From 56ea7056f45f061213ccafd4a61a05670c7c82da Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 22 Jun 2026 10:55:33 -0400 Subject: [PATCH 04/11] docs: add Qwen3.5-4B math recipe README (validated stack + eval) Document the validated runtime stack (transformers 5.8.1, kernels 0.14.1, SGLang dev with qwen3_5 + TritonGDNKernel, fla 0.5.0, flashinfer 0.6.11.post1, torch 2.11.0+cu130, attn_impl sdpa), GPU layout, run commands, and the validated eval (overall avg@k 47.8 -> 57.4, pass@k 56.5 -> 67.4 over 80 steps). The default pyproject pins load qwen3_5, but the validated stack is installed out of band; a pin bump is deferred to a separate, tested PR (the validated SGLang is a dev build, not a clean release). Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/math/qwen3.5-4b-m2po-delta/README.md | 16 ++++ examples/math/qwen3.5-4b-m2po-full/README.md | 75 +++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 examples/math/qwen3.5-4b-m2po-delta/README.md create mode 100644 examples/math/qwen3.5-4b-m2po-full/README.md diff --git a/examples/math/qwen3.5-4b-m2po-delta/README.md b/examples/math/qwen3.5-4b-m2po-delta/README.md new file mode 100644 index 0000000..f72e977 --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-delta/README.md @@ -0,0 +1,16 @@ +# Qwen3.5-4B — Math RL (M2PO), delta weight transfer + +Same recipe as [`qwen3.5-4b-m2po-full`](../qwen3.5-4b-m2po-full/README.md), but +the trainer pushes **only changed weights** to the inference engine each sync +(`weight_transfer_strategies: delta`) instead of the full model. 
+ +See the [full recipe's README](../qwen3.5-4b-m2po-full/README.md) for the +validated environment (transformers 5.8.1 / kernels 0.14.1 / SGLang dev with +`qwen3_5`, `attention_backend: flashinfer` / `fla` 0.5.0 / torch 2.11.0+cu130), +GPU layout, install note, and validation results. + +## Run + +```bash +bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh +``` diff --git a/examples/math/qwen3.5-4b-m2po-full/README.md b/examples/math/qwen3.5-4b-m2po-full/README.md new file mode 100644 index 0000000..668d16c --- /dev/null +++ b/examples/math/qwen3.5-4b-m2po-full/README.md @@ -0,0 +1,75 @@ +# Qwen3.5-4B — Math RL (M2PO) + +Text-only math RL on **Qwen/Qwen3.5-4B** with **M2PO**, context 8k, lr 5e-6, +DeepScaleR data, `math_verify` reward. + +Qwen3.5-4B is a **hybrid Gated-DeltaNet + attention multimodal** checkpoint +(architecture `Qwen3_5ForConditionalGeneration`, `model_type: qwen3_5`); these +recipes train it **text-only**. The checkpoint ships as an image-text-to-text +model, so AstraFlow loads it via the `AutoModelForImageTextToText` path (the +`model_type` is registered in `VALID_VISION_MODELS`), and the trainer uses +`attn_impl: sdpa` because a prebuilt flash-attn is not ABI-compatible with this +torch build. + +Two variants: + +| recipe | weight transfer | +|---|---| +| `qwen3.5-4b-m2po-full` | full (push the whole model each sync) | +| `qwen3.5-4b-m2po-delta` | delta (push only changed weights) | + +## Validated environment + +These recipes were validated end-to-end on the following stack (8× L40 / +Ada). 
The model and GDN kernels come from pip dependencies — there is no +hand-patched framework source: + +| package | version | +|---|---| +| `torch` | `2.11.0+cu130` | +| `transformers` | `5.8.1` | +| `kernels` | `0.14.1` | +| `sglang` | main/dev with `qwen3_5` support, served with `attention_backend: flashinfer` (validated build `0.5.6.post3.dev5643`) | +| `flash-linear-attention` (`fla`) | `0.5.0` | +| `flashinfer-python` | `0.6.11.post1` | +| attention impl | `sdpa` (set in `experiment.yaml`) | + +> **Install note.** The repo's default `pyproject.toml` pins +> (`transformers==5.6.1`, `sglang==0.5.12.post1`) resolve and *load* `qwen3_5`, +> but the recipe was trained on the stack above — install it out of band +> (e.g. a dedicated env). A `pyproject` pin bump is intentionally **not** part +> of this PR: the validated SGLang is a dev build (older than the pinned +> `0.5.12.post1`), so it cannot be pinned to a clean release yet. Bumping the +> pins is deferred to a separate, explicitly-tested PR once a published SGLang +> release with `qwen3_5` support is available. + +## GPU layout (default, 8 GPUs) + +``` +SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS / SGLang inference (model0, dp=4) +TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) +``` + +Override those env vars to use different GPUs. + +## Run + +```bash +bash examples/math/qwen3.5-4b-m2po-full/scripts/run_qwen3.5-4b-m2po-full.sh +# delta variant: +bash examples/math/qwen3.5-4b-m2po-delta/scripts/run_qwen3.5-4b-m2po-delta.sh +``` + +The launcher starts three processes (AstraFlow HTTP service, RaaS/SGLang +inference, FSDP trainer). See `scripts/` for the per-process scripts and +`yaml/` for the experiment / RaaS configs. + +## Validation + +Trained end-to-end on the stack above; eval rises steadily over training. 
+Qwen3.5-4B-full, overall metrics across the eval suite: + +| metric | step 0 | step 80 | Δ | +|---|---|---|---| +| overall avg@k | 47.8% | 57.4% | +9.6 | +| overall pass@k | 56.5% | 67.4% | +10.9 | From 6607b11b5789fcc85a14e7d352f34f6231ef2578 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 22 Jun 2026 11:22:04 -0400 Subject: [PATCH 05/11] deps: pin transformers 5.8.1 (validated stack) + kernels 0.14 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bump the pyproject transformers pin (core dep + uv override) from 5.6.1 to the 5.8.1 that the Qwen3.5 and Qwen3-dense math recipes were validated on, and move the coupled kernels constraint from <0.13 to >=0.14,<0.15 (what transformers 5.8.1 was validated against). torch is already 2.11.0; flashinfer comes in transitively via sglang. sglang stays pinned at the published 0.5.12.post1 — the Qwen3.5 inference path was validated on an sglang dev build that ships qwen3_5, as noted in the recipe README. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/math/qwen3.5-4b-m2po-full/README.md | 15 +++++++-------- pyproject.toml | 20 +++++++++----------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/examples/math/qwen3.5-4b-m2po-full/README.md b/examples/math/qwen3.5-4b-m2po-full/README.md index 668d16c..fef4dbf 100644 --- a/examples/math/qwen3.5-4b-m2po-full/README.md +++ b/examples/math/qwen3.5-4b-m2po-full/README.md @@ -34,14 +34,13 @@ hand-patched framework source: | `flashinfer-python` | `0.6.11.post1` | | attention impl | `sdpa` (set in `experiment.yaml`) | -> **Install note.** The repo's default `pyproject.toml` pins -> (`transformers==5.6.1`, `sglang==0.5.12.post1`) resolve and *load* `qwen3_5`, -> but the recipe was trained on the stack above — install it out of band -> (e.g. a dedicated env). 
A `pyproject` pin bump is intentionally **not** part -> of this PR: the validated SGLang is a dev build (older than the pinned -> `0.5.12.post1`), so it cannot be pinned to a clean release yet. Bumping the -> pins is deferred to a separate, explicitly-tested PR once a published SGLang -> release with `qwen3_5` support is available. +> **Install note.** `pyproject.toml` pins `transformers==5.8.1` (the validated +> training version) with `kernels>=0.14,<0.15`; `torch` is already `2.11.0` and +> `flashinfer` is pulled in automatically as an SGLang dependency. SGLang itself +> stays pinned at the published `0.5.12.post1` — the Qwen3.5 *inference* path +> above was validated on an SGLang main/dev build that ships `qwen3_5` + +> `TritonGDNKernel`, so if your installed SGLang doesn't serve `qwen3_5`, install +> a build that does. ## GPU layout (default, 8 GPUs) diff --git a/pyproject.toml b/pyproject.toml index 1f6b4b4..0c7b667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "torchdata", "huggingface_hub", "datasets>=3.0.0", - "transformers==5.6.1", + "transformers==5.8.1", "megatron-core==0.13.1", "mbridge==0.13.0", "torch_memory_saver==0.0.9.post1", @@ -199,16 +199,14 @@ prerelease = "allow" exclude-dependencies = ["flash-attn"] override-dependencies=[ "outlines-core==0.1.26", - # sglang 0.5.12 pins transformers==5.6.0, which has a flash-attention bug - # (unconditional s_aux.to() crashes non-sink models like Qwen3). 5.6.1 is - # a patch release that fixes it; override sglang's exact pin to pick it up. - "transformers==5.6.1", - # sglang requires an unbounded "kernels", so uv resolves the latest (0.15+), - # but transformers 5.6.1 only supports kernels<0.13 (its hub_kernels module - # calls LayerRepository() without a revision/version, which 0.15 rejects -> - # `import sglang` crashes with "Either a revision or a version must be - # specified."). Pin to the range transformers 5.6.1 was built against. 
- "kernels>=0.12.0,<0.13", + # The Qwen3.5 + Qwen3-dense math recipes are validated on transformers 5.8.1 + # (see examples/math/qwen3.5-4b-m2po-full/README.md). sglang 0.5.12 pins + # transformers==5.6.0; override that exact pin to install the validated 5.8.1. + "transformers==5.8.1", + # transformers 5.8.1 is validated against kernels 0.14.x (sglang otherwise + # resolves kernels 0.15+, whose LayerRepository() revision handling breaks + # `import sglang`). Pin to the range the validated stack was built against. + "kernels>=0.14.0,<0.15", ] [tool.uv.extra-build-dependencies] From 1c111c29e250146230be990c54abb070c6fb1099 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 22 Jun 2026 11:35:36 -0400 Subject: [PATCH 06/11] deps: pin sglang to the validated git build (qwen3_5 + transformers 5.8.1) The published sglang 0.5.12.post1 predates qwen3_5 and pins transformers 5.6.x, so it is incompatible with the transformers==5.8.1 the recipes require. Pin sglang to the validated main-branch build (sgl-project/sglang @ 373cadc9): it ships qwen3_5 + TritonGDNKernel and itself requires transformers==5.8.1, flashinfer 0.6.11.post1, torch 2.11.0, kernels<0.15 -- all matching the validated env. Update the [tool.uv] comments to match. Verified: every pin/override matches the working astraflow35 env (transformers 5.8.1, kernels 0.14.1, outlines-core 0.1.26, torch 2.11.0, sglang @ g373cadc92). Co-Authored-By: Claude Opus 4.8 (1M context) --- pyproject.toml | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0c7b667..093facb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,7 +140,12 @@ te = [ ] sglang = [ - "sglang==0.5.12.post1", + # Pinned to the validated main-branch build (commit 373cadc9): it ships qwen3_5 + # support + the TritonGDNKernel and requires transformers==5.8.1. 
The published + # 0.5.12.post1 release predates qwen3_5 and pins transformers 5.6.x, so it is + # not usable for the Qwen3.5 recipe. Installs from source (sgl-project/sglang, + # subdirectory=python); see examples/math/qwen3.5-4b-m2po-full/README.md. + "sglang @ git+https://github.com/sgl-project/sglang.git@373cadc92ea421710e32c395e8c0e931f000c707#subdirectory=python", ] vllm = [ @@ -192,20 +197,19 @@ include = ["astraflow*"] exclude = ["tests*", "docs*", "examples*"] [tool.uv] -# sglang 0.5.12 depends on flash-attn-4>=4.0.0b9 (a pre-release wheel, pulled -# in automatically as a sglang dependency). Without this, `uv pip install +# The pinned sglang build depends on flash-attn-4>=4.0.0b9 (a pre-release wheel, +# pulled in automatically as a sglang dependency). Without this, `uv pip install # -e ".[sglang]"` fails to resolve with "pre-releases weren't enabled". prerelease = "allow" exclude-dependencies = ["flash-attn"] override-dependencies=[ "outlines-core==0.1.26", - # The Qwen3.5 + Qwen3-dense math recipes are validated on transformers 5.8.1 - # (see examples/math/qwen3.5-4b-m2po-full/README.md). sglang 0.5.12 pins - # transformers==5.6.0; override that exact pin to install the validated 5.8.1. + # Pin transformers to 5.8.1, the version the pinned sglang build requires and + # the Qwen3.5 + Qwen3-dense math recipes are validated on + # (see examples/math/qwen3.5-4b-m2po-full/README.md). "transformers==5.8.1", - # transformers 5.8.1 is validated against kernels 0.14.x (sglang otherwise - # resolves kernels 0.15+, whose LayerRepository() revision handling breaks - # `import sglang`). Pin to the range the validated stack was built against. + # The pinned sglang build requires kernels<0.15; pin >=0.14 so uv selects the + # validated 0.14.x rather than an older kernels release. 
"kernels>=0.14.0,<0.15", ] From 705f95461c0238b199ccb5dd3ef09f2caf1b603e Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 22 Jun 2026 11:42:18 -0400 Subject: [PATCH 07/11] style: align comments with repo conventions; refresh sglang install note Comment-only alignment of the Qwen3.5 recipes + touched code to the repo's conventions (qwen3-8b recipes / surrounding code as the baseline): - recipe experiment.yaml/raas.yaml: restore the "# -- ... --" section banners and the auto-derives / adaptive-availability rationale comments the dense recipes use; restyle the attn_impl comment to sentence-case prose. - model.py: capitalize the qwen3_5 registry comment to match the file. - fsdp_engine.py: reflow the attention_mask comment so dict(...) is not split mid-token, matching the file's other multi-line comments. - pyproject.toml: reflow the sglang-extra comment to the column band. - README: tag the GPU-layout code fence (text). No functional content changed (verified: all yaml/sh/py/toml parse; only comment lines were added/removed). Also refresh the recipe README install note, which still described SGLang as pinned at 0.5.12.post1 before the git-build pin landed. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- astraflow/train_worker/engine/fsdp_engine.py | 10 ++++----- astraflow/train_worker/utils/model.py | 2 +- .../yaml/experiment.yaml | 16 ++++++++++++++ .../math/qwen3.5-4b-m2po-delta/yaml/raas.yaml | 3 +++ examples/math/qwen3.5-4b-m2po-full/README.md | 15 +++++++------ .../qwen3.5-4b-m2po-full/yaml/experiment.yaml | 21 +++++++++++++++++-- .../math/qwen3.5-4b-m2po-full/yaml/raas.yaml | 2 ++ pyproject.toml | 11 +++++----- 8 files changed, 59 insertions(+), 21 deletions(-) diff --git a/astraflow/train_worker/engine/fsdp_engine.py b/astraflow/train_worker/engine/fsdp_engine.py index 652fe06..1385277 100644 --- a/astraflow/train_worker/engine/fsdp_engine.py +++ b/astraflow/train_worker/engine/fsdp_engine.py @@ -1206,11 +1206,11 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList: padded_mb["use_cache"] = False # Always pass attention_mask=None for the packed/varlen forward: per-sequence # causal masking is driven by cu_seqlens + position_ids, and the model builds - # the right mask from None. The old dict(full_attention=None, sliding_attention= - # None) form is a transformers-4.x relic: on transformers>=5 a dense model - # (qwen3 / qwen2) treats that dict as a *precomputed* mask, skips creation, and - # crashes. Passing None lets the model build its mask from cu_seqlens + - # position_ids instead. + # the right mask from None. The old dict(full_attention=None, + # sliding_attention=None) form is a transformers-4.x relic: on transformers>=5 + # a dense model (qwen3 / qwen2) treats that dict as a *precomputed* mask, skips + # creation, and crashes. Passing None lets the model build its mask from + # cu_seqlens + position_ids instead. 
mb["attention_mask"] = None padded_mb["attention_mask"] = None if "multi_modal_input" in mb: diff --git a/astraflow/train_worker/utils/model.py b/astraflow/train_worker/utils/model.py index b29053e..491f3d1 100644 --- a/astraflow/train_worker/utils/model.py +++ b/astraflow/train_worker/utils/model.py @@ -5,7 +5,7 @@ "qwen2_vl", "qwen2_5_vl", "qwen3_vl", - # qwen3.5 dense math checkpoints ship as Qwen3_5ForConditionalGeneration, so they + # Qwen3.5 dense math checkpoints ship as Qwen3_5ForConditionalGeneration, so they # load via the ImageTextToText path even though these recipes train text-only. "qwen3_5", "gemma3", diff --git a/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml b/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml index 704382e..60cbbf3 100644 --- a/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml +++ b/examples/math/qwen3.5-4b-m2po-delta/yaml/experiment.yaml @@ -15,6 +15,7 @@ # TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) # ============================================================================ +# ── Experiment: identity, model, shared settings ── experiment: experiment_name: astraflow-math trial_name: qwen3.5-4b-m2po-delta @@ -27,6 +28,8 @@ experiment: weight_transfer_mode: tcp weight_transfer_strategies: delta +# ── RaaS: what to generate (inference-level config) ── +# model keys here also determine expected_model_ids for AstraFlow service raas: models: model0: @@ -37,6 +40,9 @@ raas: max_new_tokens: 4000 min_new_tokens: 0 +# ── AstraFlow: data pipeline ── +# auto-derives: expected_model_ids from raas.models keys +# auto-derives: dump_dir from experiment.fileroot dataflow: host: "0.0.0.0" port: 8000 @@ -94,6 +100,14 @@ dataflow: repeat: 4 eval_workflow: math_eval +# ── Trainer base: shared config ── +# auto-derives from experiment: experiment_name, trial_name, fileroot, +# tokenizer_path, seed, dtype, weight_transfer_mode +# auto-derives from raas.models.: actor.path, actor.max_new_tokens, +# ref.path +# 
auto-derives: saver, recover, stats_logger fields from experiment section +# auto-derives: cluster.name_resolve from experiment.fileroot +# auto-derives: trial_name suffix from model_id (e.g. trial_name-model0) trainer_base: total_train_steps: 800 train_batch_size: 256 @@ -118,6 +132,7 @@ trainer_base: eps: 1e-8 lr_scheduler_type: constant gradient_clipping: 1.0 + # PPO / M2PO algorithm m2_threshold: 0.01 eps_clip: 100.0 eps_clip_higher: 100.0 @@ -147,6 +162,7 @@ trainer_base: mode: online id_suffix: "uid" +# ── Trainer for model0 — only overrides ── trainer_model0: model_id: model0 stats_logger: diff --git a/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml b/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml index 9e949b9..d8bb9fe 100644 --- a/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml +++ b/examples/math/qwen3.5-4b-m2po-delta/yaml/raas.yaml @@ -14,8 +14,11 @@ rollout: max_concurrent_rollouts: 512 + # Cap concurrent eval prefills to bound peak KV pressure during the + # ~3.5k-item eval burst (5 datasets x repeat=4) — default 128 OOMs sglang. max_concurrent_evals: 64 pause_grace_period: 3 + # Adaptive availability — drive /availability off sglang /get_load. enable_adaptive_availability: true target_waiting_queue_per_dp: 4 adaptive_step_size: 4 diff --git a/examples/math/qwen3.5-4b-m2po-full/README.md b/examples/math/qwen3.5-4b-m2po-full/README.md index fef4dbf..270e5c0 100644 --- a/examples/math/qwen3.5-4b-m2po-full/README.md +++ b/examples/math/qwen3.5-4b-m2po-full/README.md @@ -34,17 +34,16 @@ hand-patched framework source: | `flashinfer-python` | `0.6.11.post1` | | attention impl | `sdpa` (set in `experiment.yaml`) | -> **Install note.** `pyproject.toml` pins `transformers==5.8.1` (the validated -> training version) with `kernels>=0.14,<0.15`; `torch` is already `2.11.0` and -> `flashinfer` is pulled in automatically as an SGLang dependency. 
SGLang itself -> stays pinned at the published `0.5.12.post1` — the Qwen3.5 *inference* path -> above was validated on an SGLang main/dev build that ships `qwen3_5` + -> `TritonGDNKernel`, so if your installed SGLang doesn't serve `qwen3_5`, install -> a build that does. +> **Install note.** `pyproject.toml` pins the full validated stack: +> `transformers==5.8.1` (with `kernels>=0.14,<0.15`), `torch==2.11.0`, and SGLang +> pinned to the validated main-branch build (`sgl-project/sglang` @ `373cadc9`) — +> the published `0.5.12.post1` release predates `qwen3_5`, so the git build is +> required. It installs from source and pulls `flashinfer` in automatically, so +> `uv pip install -e ".[sglang]"` resolves the validated environment directly. ## GPU layout (default, 8 GPUs) -``` +```text SERVICE_CUDA_VISIBLE_DEVICES=0,1,2,3 -> RaaS / SGLang inference (model0, dp=4) TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) ``` diff --git a/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml b/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml index e9ecbb9..739613a 100644 --- a/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml +++ b/examples/math/qwen3.5-4b-m2po-full/yaml/experiment.yaml @@ -15,6 +15,7 @@ # TRAINER_MODEL0_GPUS=4,5,6,7 -> Trainer model0 (FSDP, 4 GPUs) # ============================================================================ +# ── Experiment: identity, model, shared settings ── experiment: experiment_name: astraflow-math trial_name: qwen3.5-4b-m2po-full @@ -27,6 +28,8 @@ experiment: weight_transfer_mode: tcp weight_transfer_strategies: full +# ── RaaS: what to generate (inference-level config) ── +# model keys here also determine expected_model_ids for AstraFlow service raas: models: model0: @@ -37,6 +40,9 @@ raas: max_new_tokens: 4000 min_new_tokens: 0 +# ── AstraFlow: data pipeline ── +# auto-derives: expected_model_ids from raas.models keys +# auto-derives: dump_dir from experiment.fileroot dataflow: host: "0.0.0.0" port: 8000 @@ 
-93,6 +99,14 @@ dataflow: repeat: 4 eval_workflow: math_eval +# ── Trainer base: shared config ── +# auto-derives from experiment: experiment_name, trial_name, fileroot, +# tokenizer_path, seed, dtype, weight_transfer_mode +# auto-derives from raas.models.<model_id>: actor.path, actor.max_new_tokens, +# ref.path +# auto-derives: saver, recover, stats_logger fields from experiment section +# auto-derives: cluster.name_resolve from experiment.fileroot +# auto-derives: trial_name suffix from model_id (e.g. trial_name-model0) trainer_base: total_train_steps: 800 train_batch_size: 256 @@ -102,8 +116,9 @@ trainer_base: data_parallel_size: 4 actor: - # sdpa: prebuilt flash-attn isn't ABI-compatible with this torch/cu130 build; - # Qwen3.5's GDN linear-attn uses fla kernels and full-attn blocks use sdpa. + # Prebuilt flash-attn isn't ABI-compatible with this torch/cu130 build, so + # use sdpa: Qwen3.5's GDN linear-attn uses fla kernels and full-attn blocks + # use sdpa. attn_impl: sdpa gradient_checkpointing: true mb_spec: @@ -117,6 +132,7 @@ trainer_base: eps: 1e-8 lr_scheduler_type: constant gradient_clipping: 1.0 + # PPO / M2PO algorithm m2_threshold: 0.01 eps_clip: 100.0 eps_clip_higher: 100.0 @@ -146,6 +162,7 @@ trainer_base: mode: online id_suffix: "uid" +# ── Trainer for model0 — only overrides ── trainer_model0: model_id: model0 stats_logger: diff --git a/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml b/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml index cb5538f..274a1bf 100644 --- a/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml +++ b/examples/math/qwen3.5-4b-m2po-full/yaml/raas.yaml @@ -10,12 +10,14 @@ # conservative. Served by SGLang main via its TritonGDNKernel backend.
# # Merged with experiment.yaml at launch (--config experiment.yaml --config raas.yaml) +# experiment.yaml provides: model_path, tokenizer_path, seed, dtype, models/gconfig # ============================================================================ rollout: max_concurrent_rollouts: 512 max_concurrent_evals: 64 pause_grace_period: 3 + # Adaptive availability — drive /availability off sglang /get_load. enable_adaptive_availability: true target_waiting_queue_per_dp: 4 adaptive_step_size: 4 diff --git a/pyproject.toml b/pyproject.toml index 093facb..c67feec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,11 +140,12 @@ te = [ ] sglang = [ - # Pinned to the validated main-branch build (commit 373cadc9): it ships qwen3_5 - # support + the TritonGDNKernel and requires transformers==5.8.1. The published - # 0.5.12.post1 release predates qwen3_5 and pins transformers 5.6.x, so it is - # not usable for the Qwen3.5 recipe. Installs from source (sgl-project/sglang, - # subdirectory=python); see examples/math/qwen3.5-4b-m2po-full/README.md. + # Pinned to the validated main-branch build (commit 373cadc9): it ships + # qwen3_5 support + the TritonGDNKernel and requires transformers==5.8.1. The + # published 0.5.12.post1 release predates qwen3_5 and pins transformers 5.6.x, + # so it is not usable for the Qwen3.5 recipe. Installs from source + # (sgl-project/sglang, subdirectory=python); see + # examples/math/qwen3.5-4b-m2po-full/README.md. 
"sglang @ git+https://github.com/sgl-project/sglang.git@373cadc92ea421710e32c395e8c0e931f000c707#subdirectory=python", ] From 9c844d6892001691f72ddaf4e5db1f72acdedc08 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Mon, 22 Jun 2026 15:43:15 -0400 Subject: [PATCH 08/11] deps: match astraflow35 -- pin flash-linear-attention 0.5.0; networkx 3.6.1 Two pyproject fixes so a clean install reproduces the validated astraflow35 env: - Add flash-linear-attention==0.5.0 to the sglang extra: the fast Gated-DeltaNet kernels Qwen3.5 GDN training used in the validated runs (optional -- transformers falls back to a slower pure-torch path when absent). Pulls fla-core==0.5.0. - networkx==3.3 -> 3.6.1 to match the validated env (the only pinned version that differed from astraflow35). Verified: uv pip install --dry-run -e ".[sglang]" resolves cleanly (299 packages, exit 0) to the validated versions (flash-linear-attention 0.5.0, fla-core 0.5.0, networkx 3.6.1, transformers 5.8.1, sglang @ 373cadc9, ...). All other pinned ML versions already matched astraflow35; the ~60 loose utility deps are left flexible per the repo's convention. Co-Authored-By: Claude Opus 4.8 (1M context) --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c67feec..28f338d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ dependencies = [ "numba", "blosc", "pybind11>=2.10.0", - "networkx==3.3", + "networkx==3.6.1", "aiofiles", "aiohttp>=3.11.10", "httpx>=0.28.1", @@ -147,6 +147,10 @@ sglang = [ # (sgl-project/sglang, subdirectory=python); see # examples/math/qwen3.5-4b-m2po-full/README.md. "sglang @ git+https://github.com/sgl-project/sglang.git@373cadc92ea421710e32c395e8c0e931f000c707#subdirectory=python", + # Fast Gated-DeltaNet kernels for Qwen3.5 GDN training. Optional (transformers + # falls back to a slower pure-torch path when absent), but the validated runs + # used it; pin to the validated version. 
+ "flash-linear-attention==0.5.0", ] vllm = [ From 3d5b0ccd1bfbcc8740c4ac764f74e843c1174cf2 Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Tue, 23 Jun 2026 12:42:08 -0400 Subject: [PATCH 09/11] recipes: default dense Qwen3 to kernels-hub FlashAttention-2 Switch the dense Qwen3 math recipes (qwen3-1.7b-m2po-2gpus-{full,delta}, qwen3-8b-m2po-{full,delta}) and the cli_args attn_impl default to kernels-community/flash-attn2 -- a prebuilt, ABI-matched FlashAttention-2 kernel from the HF kernels hub (cached on first use, no source build). The literal flash_attention_2 loads the local flash-attn wheel, which is ABI-broken on torch 2.11+cu130 (undefined symbol); is_flash_attn_2_available() only checks package metadata so it reports available and then crashes on import. The kernels-hub variant is ABI-matched and is the upstream-faithful FA2 the recipes were tuned with. A paired step-25 A/B/C on qwen3-1.7b-m2po-2gpus-full gave overall avg@k FA2 >= sdpa >= sdpa+recompute_logprob, all within eval noise; FA2 varlen also derives the packed block-diagonal mask from cu_seqlens, avoiding sdpa's reliance on position_ids resets. Qwen3.5 recipes stay on sdpa. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- astraflow/train_worker/api/cli_args.py | 18 +++++++++++++++--- examples/math/README.md | 18 ++++++++++++++++++ .../yaml/experiment.yaml | 4 ++-- .../yaml/experiment.yaml | 4 ++-- .../qwen3-8b-m2po-delta/yaml/experiment.yaml | 4 ++-- .../qwen3-8b-m2po-full/yaml/experiment.yaml | 4 ++-- 6 files changed, 41 insertions(+), 11 deletions(-) diff --git a/astraflow/train_worker/api/cli_args.py b/astraflow/train_worker/api/cli_args.py index af4e5a7..41b1e6f 100644 --- a/astraflow/train_worker/api/cli_args.py +++ b/astraflow/train_worker/api/cli_args.py @@ -473,10 +473,22 @@ class TrainEngineConfig: trial_name: str = "" path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"}) attn_impl: str = field( - default="flash_attention_2", + default="kernels-community/flash-attn2", metadata={ - "help": "Attention implementation for huggingface transformers model.", - "choices": ["flash_attention_2", "sdpa", "eager"], + "help": ( + "Attention implementation for huggingface transformers model. " + "Default pulls a prebuilt FlashAttention-2 kernel from the HF kernels " + "hub (ABI-matched to torch, incl. varlen for packed sequences). The " + "literal 'flash_attention_2' loads the local flash-attn wheel, which is " + "ABI-broken on torch>=2.11+cu13; 'sdpa' works but relies on position_ids " + "resets for packed block-diagonal masking." + ), + "choices": [ + "kernels-community/flash-attn2", + "flash_attention_2", + "sdpa", + "eager", + ], }, ) init_from_scratch: bool = field( diff --git a/examples/math/README.md b/examples/math/README.md index 0b2735e..cc4a81f 100644 --- a/examples/math/README.md +++ b/examples/math/README.md @@ -15,3 +15,21 @@ Complete guidance: [`docs/en/recipes/math.md`](../../docs/en/recipes/math.md). Most math recipes default to one 8xH100 node. The `qwen3-1.7b-m2po-2gpus-*` recipes are smaller 2xH100 variants. 
+ +--- +**Attention kernel** + +The dense Qwen3 recipes (`qwen3-1.7b-m2po-2gpus-*`, `qwen3-8b-m2po-*`) set +`attn_impl: kernels-community/flash-attn2` — a prebuilt, ABI-matched +FlashAttention-2 kernel pulled from the Hugging Face `kernels` hub (fetched and +cached on first use; no source build). This is the working FA2 on the validated +stack (`torch 2.11+cu130`): the literal `attn_impl: flash_attention_2` would +instead load the local `flash-attn` wheel and crash with an `undefined symbol` +ABI error (`is_flash_attn_2_available()` is metadata-only, so it never catches +the broken import). It is also the same kernel as `cli_args.py`'s default, so +recipes that omit `attn_impl` get it too. + +`sdpa` and `eager` remain available; `sdpa` works but relies on per-sequence +`position_ids` resets for packed block-diagonal masking, whereas FA2 varlen +derives the block-diagonal mask from `cu_seqlens` directly. The Qwen3.5 recipes +use `sdpa` (hybrid Gated-DeltaNet + attention model). diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml index 659b199..f1a0a27 100644 --- a/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml +++ b/examples/math/qwen3-1.7b-m2po-2gpus-delta/yaml/experiment.yaml @@ -110,7 +110,7 @@ trainer_base: data_parallel_size: 1 actor: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -136,7 +136,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml b/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml index 9b00759..1a8cd02 100644 --- a/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml +++ b/examples/math/qwen3-1.7b-m2po-2gpus-full/yaml/experiment.yaml @@ -109,7 +109,7 @@ 
trainer_base: data_parallel_size: 1 actor: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -135,7 +135,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml b/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml index e25f7d6..1629c36 100644 --- a/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml +++ b/examples/math/qwen3-8b-m2po-delta/yaml/experiment.yaml @@ -110,7 +110,7 @@ trainer_base: data_parallel_size: 4 actor: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -136,7 +136,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 diff --git a/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml b/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml index 29984df..8fa6c90 100644 --- a/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml +++ b/examples/math/qwen3-8b-m2po-full/yaml/experiment.yaml @@ -109,7 +109,7 @@ trainer_base: data_parallel_size: 4 actor: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 gradient_checkpointing: true mb_spec: max_tokens_per_mb: 17408 @@ -135,7 +135,7 @@ trainer_base: adv_norm: { mean_level: batch, std_level: batch } ref: - attn_impl: sdpa + attn_impl: kernels-community/flash-attn2 mb_spec: max_tokens_per_mb: 17408 From 1e44ee7482b4743cbcb2e2bc0288cdd389a389ec Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Wed, 24 Jun 2026 13:45:16 -0400 Subject: [PATCH 10/11] deps: pin sglang to validated 0.5.13.post1 release (supersedes git build) Switch the [sglang] extra from the git build (373cadc9) to the published sglang==0.5.13.post1 -- the release the recipes were 
recently validated against, which ships qwen3_5 support (sglang/srt/models/qwen3_5.py) so it covers both the Qwen3.5 and Qwen3-dense recipes. Update the Qwen3.5 READMEs' validated-stack tables to match (sglang 0.5.13.post1, flashinfer 0.6.12). The flash-attn-4 pre-release / transformers 5.8.1 / kernels 0.14.x [tool.uv] overrides are unchanged (0.5.13.post1 still pulls flash-attn-4 4.0.0b15). Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/math/qwen3.5-4b-m2po-delta/README.md | 5 ++-- examples/math/qwen3.5-4b-m2po-full/README.md | 13 +++++----- pyproject.toml | 25 ++++++++++--------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/examples/math/qwen3.5-4b-m2po-delta/README.md b/examples/math/qwen3.5-4b-m2po-delta/README.md index f72e977..971a1eb 100644 --- a/examples/math/qwen3.5-4b-m2po-delta/README.md +++ b/examples/math/qwen3.5-4b-m2po-delta/README.md @@ -5,8 +5,9 @@ the trainer pushes **only changed weights** to the inference engine each sync (`weight_transfer_strategies: delta`) instead of the full model. See the [full recipe's README](../qwen3.5-4b-m2po-full/README.md) for the -validated environment (transformers 5.8.1 / kernels 0.14.1 / SGLang dev with -`qwen3_5`, `attention_backend: flashinfer` / `fla` 0.5.0 / torch 2.11.0+cu130), +validated environment (transformers 5.8.1 / kernels 0.14.1 / SGLang +`0.5.13.post1` with `qwen3_5`, `attention_backend: flashinfer` / `fla` 0.5.0 / +torch 2.11.0+cu130), GPU layout, install note, and validation results. 
## Run diff --git a/examples/math/qwen3.5-4b-m2po-full/README.md b/examples/math/qwen3.5-4b-m2po-full/README.md index 270e5c0..8af4642 100644 --- a/examples/math/qwen3.5-4b-m2po-full/README.md +++ b/examples/math/qwen3.5-4b-m2po-full/README.md @@ -29,17 +29,16 @@ hand-patched framework source: | `torch` | `2.11.0+cu130` | | `transformers` | `5.8.1` | | `kernels` | `0.14.1` | -| `sglang` | main/dev with `qwen3_5` support, served with `attention_backend: flashinfer` (validated build `0.5.6.post3.dev5643`) | +| `sglang` | `0.5.13.post1` (published release with `qwen3_5` support), served with `attention_backend: flashinfer` | | `flash-linear-attention` (`fla`) | `0.5.0` | -| `flashinfer-python` | `0.6.11.post1` | +| `flashinfer-python` | `0.6.12` (pulled by sglang) | | attention impl | `sdpa` (set in `experiment.yaml`) | > **Install note.** `pyproject.toml` pins the full validated stack: -> `transformers==5.8.1` (with `kernels>=0.14,<0.15`), `torch==2.11.0`, and SGLang -> pinned to the validated main-branch build (`sgl-project/sglang` @ `373cadc9`) — -> the published `0.5.12.post1` release predates `qwen3_5`, so the git build is -> required. It installs from source and pulls `flashinfer` in automatically, so -> `uv pip install -e ".[sglang]"` resolves the validated environment directly. +> `transformers==5.8.1` (with `kernels>=0.14,<0.15`), `torch==2.11.0`, and +> `sglang==0.5.13.post1` — the published release that ships `qwen3_5` support (the +> older `0.5.12.post1` predated it). It pulls `flashinfer 0.6.12` in automatically, +> so `uv pip install -e ".[sglang]"` resolves the validated environment directly. ## GPU layout (default, 8 GPUs) diff --git a/pyproject.toml b/pyproject.toml index 28f338d..35e599f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,13 +140,13 @@ te = [ ] sglang = [ - # Pinned to the validated main-branch build (commit 373cadc9): it ships - # qwen3_5 support + the TritonGDNKernel and requires transformers==5.8.1. 
The - # published 0.5.12.post1 release predates qwen3_5 and pins transformers 5.6.x, - # so it is not usable for the Qwen3.5 recipe. Installs from source - # (sgl-project/sglang, subdirectory=python); see - # examples/math/qwen3.5-4b-m2po-full/README.md. - "sglang @ git+https://github.com/sgl-project/sglang.git@373cadc92ea421710e32c395e8c0e931f000c707#subdirectory=python", + # Pinned to the validated published release 0.5.13.post1: it ships qwen3_5 + # support (sglang/srt/models/qwen3_5.py) + the TritonGDNKernel and requires + # transformers==5.8.1. (A git build was pinned earlier while the then-latest + # 0.5.12.post1 still predated qwen3_5; 0.5.13.post1 supersedes it and is the + # release the Qwen3.5 + Qwen3-dense math recipes are validated against -- see + # examples/math/qwen3.5-4b-m2po-full/README.md.) + "sglang==0.5.13.post1", # Fast Gated-DeltaNet kernels for Qwen3.5 GDN training. Optional (transformers # falls back to a slower pure-torch path when absent), but the validated runs # used it; pin to the validated version. @@ -202,18 +202,19 @@ include = ["astraflow*"] exclude = ["tests*", "docs*", "examples*"] [tool.uv] -# The pinned sglang build depends on flash-attn-4>=4.0.0b9 (a pre-release wheel, -# pulled in automatically as a sglang dependency). Without this, `uv pip install -# -e ".[sglang]"` fails to resolve with "pre-releases weren't enabled". +# The pinned sglang release (0.5.13.post1) depends on flash-attn-4>=4.0.0b9 (a +# pre-release wheel -- 4.0.0b15 in the validated env -- pulled in automatically as +# a sglang dependency). Without this, `uv pip install -e ".[sglang]"` fails to +# resolve with "pre-releases weren't enabled". 
prerelease = "allow" exclude-dependencies = ["flash-attn"] override-dependencies=[ "outlines-core==0.1.26", - # Pin transformers to 5.8.1, the version the pinned sglang build requires and + # Pin transformers to 5.8.1, the version the pinned sglang release requires and # the Qwen3.5 + Qwen3-dense math recipes are validated on # (see examples/math/qwen3.5-4b-m2po-full/README.md). "transformers==5.8.1", - # The pinned sglang build requires kernels<0.15; pin >=0.14 so uv selects the + # The pinned sglang release requires kernels<0.15; pin >=0.14 so uv selects the # validated 0.14.x rather than an older kernels release. "kernels>=0.14.0,<0.15", ] From 34b17d2a87cada35a31deaac5d640e4a134e1a0d Mon Sep 17 00:00:00 2001 From: Haizhong Zheng Date: Thu, 25 Jun 2026 15:43:05 -0400 Subject: [PATCH 11/11] docs: note Qwen3.5-4B re-validation on pinned sglang 0.5.13.post1 Both full and delta variants re-run end-to-end on the published-release pin (sglang==0.5.13.post1): training completes, full + delta weight transfer both work, eval holds at baseline (~49-51% overall avg@k) -- no regression vs the predecessor git build the step0->step80 table was produced on. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/math/qwen3.5-4b-m2po-full/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/math/qwen3.5-4b-m2po-full/README.md b/examples/math/qwen3.5-4b-m2po-full/README.md index 8af4642..3d0dd25 100644 --- a/examples/math/qwen3.5-4b-m2po-full/README.md +++ b/examples/math/qwen3.5-4b-m2po-full/README.md @@ -70,3 +70,10 @@ Qwen3.5-4B-full, overall metrics across the eval suite: |---|---|---|---| | overall avg@k | 47.8% | 57.4% | +9.6 | | overall pass@k | 56.5% | 67.4% | +10.9 | + +The table above was produced on the predecessor SGLang git build. 
Both variants +(`full` and `delta`) were subsequently re-validated end-to-end on the pinned +`sglang==0.5.13.post1` release: training completes with no crashes, full +(`shard_copy`) and delta (~7× compressed) weight transfer both function, and +eval holds at the baseline (overall avg@k ≈ 49–51% over a short run) — i.e. the +published-release pin introduces no regression versus the git build.