Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions examples/omni_thinker/run_qwen3_omni_thinker_grpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""GRPO on the Qwen3-Omni-30B-A3B thinker (text MoE) with DAPO-math reward.

Off-policy until live weight-sync lands: point --sglang-router-ip/port at a standalone omni
server; the rollout half serves the frozen base while the trainer updates its own copy, with
TIS + --get-mismatch-metrics absorbing/measuring the gap.
"""

import os
from dataclasses import dataclass
from typing import Literal

import typer

import miles.utils.external_utils.command_utils as U

OMNI_MODEL = "Qwen3-Omni-30B-A3B-Instruct"
THINKER_MODEL = "Qwen3-Omni-30B-A3B-Thinker"
MEGATRON_MODEL_TYPE = "qwen3-omni-30B-A3B-thinker"


@dataclass
class ScriptArgs(U.ExecuteTrainConfig):
mode: Literal["normal", "debug_minimal"] = "normal"
run_id: str = U.create_run_id()
num_gpus_per_node: int = 8
data_dir: str = "/root/datasets"
model_dir: str = "/root/models"
megatron_path: str = "/root/Megatron-LM"
omni_router_ip: str = "127.0.0.1"
omni_router_port: int = 30000
extra_args: str = ""


def prepare(args: ScriptArgs):
U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}")
U.exec_command(f"hf download Qwen/{OMNI_MODEL} --local-dir {args.model_dir}/{OMNI_MODEL}")
repo = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
U.exec_command(
f"python {repo}/tools/extract_qwen3_omni_thinker.py "
f"--src {args.model_dir}/{OMNI_MODEL} --dst {args.model_dir}/{THINKER_MODEL}"
)
U.convert_checkpoint(
model_name=THINKER_MODEL,
megatron_model_type=MEGATRON_MODEL_TYPE,
num_gpus_per_node=args.num_gpus_per_node,
dir_dst=args.model_dir,
hf_checkpoint=f"{args.model_dir}/{THINKER_MODEL}",
megatron_path=args.megatron_path,
)
U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir)


def execute(args: ScriptArgs):
ref_load_path = f"{args.model_dir}/{THINKER_MODEL}_torch_dist"
load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints"

ckpt_args = (
f"--hf-checkpoint {args.model_dir}/{THINKER_MODEL}/ "
f"--ref-load {ref_load_path} "
f"--load {load_save_path} "
"--model-name qwen3omni_moe " # body.* broadcast naming
)

rollout_args = (
"--custom-generate-function-path miles.rollout.generate_hub.omni_thinker.generate "
f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl "
"--input-key prompt "
"--label-key label "
"--apply-chat-template "
"--rollout-shuffle "
"--rm-type dapo "
"--reward-key score "
"--num-rollout 3000 "
"--rollout-batch-size 32 "
"--n-samples-per-prompt 8 "
f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} "
"--rollout-temperature 1 "
"--global-batch-size 256 "
"--balance-data "
f"--sglang-router-ip {args.omni_router_ip} "
f"--sglang-router-port {args.omni_router_port} "
)

consistency_args = (
"--use-rollout-logprobs "
"--get-mismatch-metrics "
"--use-tis "
"--tis-clip 2.0 "
)

perf_args = (
"--tensor-model-parallel-size 8 "
"--sequence-parallel "
"--pipeline-model-parallel-size 1 "
"--context-parallel-size 1 "
"--expert-model-parallel-size 8 "
"--expert-tensor-parallel-size 1 "
"--recompute-granularity full "
"--recompute-method uniform "
"--recompute-num-layers 1 "
"--use-dynamic-batch-size "
"--max-tokens-per-gpu 9216 "
"--optimizer-cpu-offload "
"--overlap-cpu-optimizer-d2h-h2d "
"--use-precision-aware-optimizer "
)

grpo_args = (
"--advantage-estimator grpo "
"--use-kl-loss "
"--kl-loss-coef 0.00 "
"--kl-loss-type low_var_kl "
"--entropy-coef 0.00 "
"--eps-clip 0.2 "
"--eps-clip-high 0.28 "
)

optimizer_args = (
"--optimizer adam "
"--lr 1e-6 "
"--lr-decay-style constant "
"--weight-decay 0.1 "
"--adam-beta1 0.9 "
"--adam-beta2 0.98 "
)

misc_args = (
"--attention-dropout 0.0 "
"--hidden-dropout 0.0 "
"--accumulate-allreduce-grads-in-fp32 "
"--attention-softmax-in-fp32 "
"--attention-backend flash "
"--actor-num-nodes 1 "
f"--actor-num-gpus-per-node {args.num_gpus_per_node} "
f"--num-gpus-per-node {args.num_gpus_per_node} "
f"--rollout-num-gpus {args.num_gpus_per_node} "
)

train_args = (
f"{ckpt_args} "
f"{rollout_args} "
f"{consistency_args} "
f"{optimizer_args} "
f"{grpo_args} "
f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} "
f"{perf_args} "
f"{misc_args} "
f"{args.extra_args} "
)

U.execute_train(
train_args=train_args,
num_gpus_per_node=args.num_gpus_per_node,
megatron_model_type=MEGATRON_MODEL_TYPE,
train_script="train.py",
megatron_path=args.megatron_path,
extra_env_vars={
"FLASHINFER_DISABLE_VERSION_CHECK": "1",
"PYTHONPATH": f"{args.megatron_path}",
},
)


@U.dataclass_cli
def main(args: ScriptArgs):
prepare(args)
execute(args)


if __name__ == "__main__":
typer.run(main)
3 changes: 3 additions & 0 deletions miles/backends/megatron_utils/megatron_to_hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .qwen3_5 import convert_qwen3_5_to_hf
from .qwen3_next import convert_qwen3_next_to_hf
from .qwen3moe import convert_qwen3moe_to_hf
from .qwen3omni_moe import convert_qwen3omni_moe_to_hf


# TODO unify w/ `convert_to_hf`
Expand Down Expand Up @@ -42,6 +43,8 @@ def _convert_to_hf_core(args, model_name, name, param):
converted_named_tensors = convert_glm4moe_to_hf(args, name, param)
elif "glm4" in model_name:
converted_named_tensors = convert_glm4_to_hf(args, name, param)
elif "qwen3omni" in model_name: # body.* names; keep before generic qwen3
converted_named_tensors = convert_qwen3omni_moe_to_hf(args, name, param)
elif "qwen3moe" in model_name:
converted_named_tensors = convert_qwen3moe_to_hf(args, name, param)
elif "qwen3next" in model_name:
Expand Down
14 changes: 14 additions & 0 deletions miles/backends/megatron_utils/megatron_to_hf/qwen3omni_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Megatron->HF broadcast converter for the Qwen3-Omni thinker (Qwen3-MoE text).

Same conversion as qwen3moe, but prefixes every HF name with `body.` (the namespace
the sglang-omni weight-receive side demuxes on). Selected via `--model-name
qwen3omni_moe`. Nothing is dropped: the thinker is untied; `tied.*` is talker-side.
"""

from .qwen3moe import convert_qwen3moe_to_hf

BODY_PREFIX = "body."


def convert_qwen3omni_moe_to_hf(args, name, param):
return [(f"{BODY_PREFIX}{n}", t) for n, t in convert_qwen3moe_to_hf(args, name, param)]
62 changes: 62 additions & 0 deletions miles/rollout/generate_hub/omni_thinker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Rollout adapter for the sglang-omni `/generate` endpoint (thinker / text).

Thin wrapper over the default single-turn generate: reuses miles' payload builder and
response parser, adds only the omni-specific fields. Wire via
`--custom-generate-function-path miles.rollout.generate_hub.omni_thinker.generate`.
"""

from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
from miles.rollout.generate_utils.generate_endpoint_utils import (
compute_prompt_ids_from_sample,
compute_request_payload,
update_sample_from_response,
)
from miles.utils.http_utils import post
from miles.utils.types import Sample


async def generate(input: GenerateFnInput) -> GenerateFnOutput:
args = input.args
sample = input.sample
sampling_params = input.sampling_params
assert sample.status in {Sample.Status.PENDING, Sample.Status.ABORTED}, f"{sample.status=}"
# omni /generate emits temp-1 (pre-temperature) full-vocab logprobs; the trainer recompute
# divides logits by rollout_temperature, so they agree only at temp=1.
assert args.rollout_temperature == 1.0, (
f"omni rollout logprob is temp-1; rollout_temperature must be 1.0, got {args.rollout_temperature}"
)
# text-only path with no MoE/indexer replay yet (server forward-declares both)
assert not (args.use_rollout_routing_replay or args.use_rollout_indexer_replay), (
"omni rollout has no routing/indexer replay; unset --use-rollout-routing-replay / --use-rollout-indexer-replay"
)
assert not (sample.multimodal_inputs or {}).get("images"), "omni thinker rollout is text-only"
url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

prompt_ids = compute_prompt_ids_from_sample(input.state, sample)
if len(sample.response) > 0: # partial rollout resume
input_ids = sample.tokens
sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids)
assert sampling_params["max_new_tokens"] >= 0
if sampling_params["max_new_tokens"] == 0:
sample.status = Sample.Status.TRUNCATED
return GenerateFnOutput(samples=sample)
else:
input_ids = prompt_ids

payload, halt_status = compute_request_payload(
args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs
)
if payload is None:
sample.status = halt_status
return GenerateFnOutput(samples=sample)

payload["output_modalities"] = ["text"]
payload["return_omni_rollout"] = False
# rep_penalty=1: the trainer recompute can't replay a repetition penalty (logprobs would diverge)
payload["sampling_params"]["repetition_penalty"] = 1.0
if sample.metadata:
payload["metadata"] = sample.metadata

output = await post(url, payload)
await update_sample_from_response(args, sample, payload=payload, output=output)
return GenerateFnOutput(samples=sample)
52 changes: 52 additions & 0 deletions scripts/models/qwen3-omni-30B-A3B-thinker.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Megatron model args for the Qwen3-Omni-30B-A3B thinker (text MoE).
# Architecturally Qwen3-30B-A3B; only vocab differs (152064 vs 151936, from the real config.json).

NLAYERS="${MODEL_ARGS_NUM_LAYERS:-48}"
FIRST_K_DENSE_REPLACE=0

arr=()
for ((i=0; i<NLAYERS; i++)); do
if (( i < FIRST_K_DENSE_REPLACE )); then
arr+=(0)
else
arr+=(1)
fi
done

printf -v MOE_LAYER_FREQ "[%s]" "$(IFS=', '; echo "${arr[*]}")"


MODEL_ARGS=(
--disable-bias-linear
--qk-layernorm
--group-query-attention
--num-attention-heads 32
--num-query-groups 4
--kv-channels 128
--num-layers $NLAYERS
--hidden-size 2048
--ffn-hidden-size 6144

--normalization RMSNorm
--position-embedding-type rope
--norm-epsilon 1e-6
--rotary-percent 1.0
--swiglu
--untie-embeddings-and-output-weights
--vocab-size "${MODEL_ARGS_VOCAB_SIZE:-152064}"

--rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}"

# moe
--moe-ffn-hidden-size 768
--moe-router-score-function softmax
--moe-token-dispatcher-type alltoall
--moe-router-topk 8
--moe-layer-freq "$MOE_LAYER_FREQ"
--num-experts 128
--moe-grouped-gemm
--moe-token-drop-policy probs
--moe-router-dtype fp32
--moe-permute-fusion
--moe-aux-loss-coeff 0
)
Loading