"""GRPO on the Qwen3-Omni-30B-A3B thinker (text MoE) with DAPO-math reward.

Off-policy until live weight-sync lands: point --sglang-router-ip/port at a standalone omni
server; the rollout half serves the frozen base while the trainer updates its own copy, with
TIS + --get-mismatch-metrics absorbing/measuring the gap.
"""

import os
from dataclasses import dataclass
from typing import Literal

import typer

import miles.utils.external_utils.command_utils as U

OMNI_MODEL = "Qwen3-Omni-30B-A3B-Instruct"
THINKER_MODEL = "Qwen3-Omni-30B-A3B-Thinker"
MEGATRON_MODEL_TYPE = "qwen3-omni-30B-A3B-thinker"


@dataclass
class ScriptArgs(U.ExecuteTrainConfig):
    """CLI options for this example, layered on top of ``U.ExecuteTrainConfig``."""

    # "debug_minimal" only shortens the rollout response length (see ``execute``).
    mode: Literal["normal", "debug_minimal"] = "normal"
    # Evaluated once, when the class body runs.
    run_id: str = U.create_run_id()
    num_gpus_per_node: int = 8
    data_dir: str = "/root/datasets"
    model_dir: str = "/root/models"
    megatron_path: str = "/root/Megatron-LM"
    # Address of the standalone omni server, forwarded as --sglang-router-ip/port.
    omni_router_ip: str = "127.0.0.1"
    omni_router_port: int = 30000
    # Appended verbatim after every other training flag.
    extra_args: str = ""


def _flag_group(*flags: str) -> str:
    """Join flags with single spaces and terminate the group with one trailing space."""
    return " ".join(flags) + " "


def prepare(args: ScriptArgs):
    """Download the omni model, extract and convert the thinker, and fetch the dataset."""
    omni_dir = f"{args.model_dir}/{OMNI_MODEL}"
    thinker_dir = f"{args.model_dir}/{THINKER_MODEL}"

    U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}")
    U.exec_command(f"hf download Qwen/{OMNI_MODEL} --local-dir {omni_dir}")

    # This file lives three directory levels below the repository root.
    repo = os.path.abspath(__file__)
    for _ in range(3):
        repo = os.path.dirname(repo)

    extract_cmd = f"python {repo}/tools/extract_qwen3_omni_thinker.py --src {omni_dir} --dst {thinker_dir}"
    U.exec_command(extract_cmd)

    U.convert_checkpoint(
        model_name=THINKER_MODEL,
        megatron_model_type=MEGATRON_MODEL_TYPE,
        num_gpus_per_node=args.num_gpus_per_node,
        dir_dst=args.model_dir,
        hf_checkpoint=thinker_dir,
        megatron_path=args.megatron_path,
    )
    U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir)


def execute(args: ScriptArgs):
    """Assemble the training command line and hand it to ``U.execute_train``."""
    gpus = args.num_gpus_per_node
    max_response_len = 100 if args.mode == "debug_minimal" else 8192

    checkpoint_flags = _flag_group(
        f"--hf-checkpoint {args.model_dir}/{THINKER_MODEL}/",
        f"--ref-load {args.model_dir}/{THINKER_MODEL}_torch_dist",
        f"--load {args.output_dir}/{args.run_id}/checkpoints",
        "--model-name qwen3omni_moe",  # body.* broadcast naming
    )

    rollout_flags = _flag_group(
        "--custom-generate-function-path miles.rollout.generate_hub.omni_thinker.generate",
        f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl",
        "--input-key prompt",
        "--label-key label",
        "--apply-chat-template",
        "--rollout-shuffle",
        "--rm-type dapo",
        "--reward-key score",
        "--num-rollout 3000",
        "--rollout-batch-size 32",
        "--n-samples-per-prompt 8",
        f"--rollout-max-response-len {max_response_len}",
        "--rollout-temperature 1",
        "--global-batch-size 256",
        "--balance-data",
        f"--sglang-router-ip {args.omni_router_ip}",
        f"--sglang-router-port {args.omni_router_port}",
    )

    consistency_flags = _flag_group(
        "--use-rollout-logprobs",
        "--get-mismatch-metrics",
        "--use-tis",
        "--tis-clip 2.0",
    )

    optimizer_flags = _flag_group(
        "--optimizer adam",
        "--lr 1e-6",
        "--lr-decay-style constant",
        "--weight-decay 0.1",
        "--adam-beta1 0.9",
        "--adam-beta2 0.98",
    )

    grpo_flags = _flag_group(
        "--advantage-estimator grpo",
        "--use-kl-loss",
        "--kl-loss-coef 0.00",
        "--kl-loss-type low_var_kl",
        "--entropy-coef 0.00",
        "--eps-clip 0.2",
        "--eps-clip-high 0.28",
    )

    perf_flags = _flag_group(
        "--tensor-model-parallel-size 8",
        "--sequence-parallel",
        "--pipeline-model-parallel-size 1",
        "--context-parallel-size 1",
        "--expert-model-parallel-size 8",
        "--expert-tensor-parallel-size 1",
        "--recompute-granularity full",
        "--recompute-method uniform",
        "--recompute-num-layers 1",
        "--use-dynamic-batch-size",
        "--max-tokens-per-gpu 9216",
        "--optimizer-cpu-offload",
        "--overlap-cpu-optimizer-d2h-h2d",
        "--use-precision-aware-optimizer",
    )

    misc_flags = _flag_group(
        "--attention-dropout 0.0",
        "--hidden-dropout 0.0",
        "--accumulate-allreduce-grads-in-fp32",
        "--attention-softmax-in-fp32",
        "--attention-backend flash",
        "--actor-num-nodes 1",
        f"--actor-num-gpus-per-node {gpus}",
        f"--num-gpus-per-node {gpus}",
        f"--rollout-num-gpus {gpus}",
    )

    # Each group is followed by one extra separating space; extra_args goes last.
    ordered_groups = (
        checkpoint_flags,
        rollout_flags,
        consistency_flags,
        optimizer_flags,
        grpo_flags,
        U.get_default_wandb_args(__file__, run_id=args.run_id),
        perf_flags,
        misc_flags,
        args.extra_args,
    )
    train_args = "".join(f"{group} " for group in ordered_groups)

    env_overrides = {
        "FLASHINFER_DISABLE_VERSION_CHECK": "1",
        "PYTHONPATH": f"{args.megatron_path}",
    }
    U.execute_train(
        train_args=train_args,
        num_gpus_per_node=gpus,
        megatron_model_type=MEGATRON_MODEL_TYPE,
        train_script="train.py",
        megatron_path=args.megatron_path,
        extra_env_vars=env_overrides,
    )


@U.dataclass_cli
def main(args: ScriptArgs):
    """Run data/model preparation, then launch training."""
    prepare(args)
    execute(args)


if __name__ == "__main__":
    typer.run(main)
"""Megatron->HF broadcast converter for the Qwen3-Omni thinker (Qwen3-MoE text).

Same conversion as qwen3moe, but prefixes every HF name with `body.` (the namespace
the sglang-omni weight-receive side demuxes on). Selected via `--model-name
qwen3omni_moe`. Nothing is dropped: the thinker is untied; `tied.*` is talker-side.
"""

from .qwen3moe import convert_qwen3moe_to_hf

BODY_PREFIX = "body."


def convert_qwen3omni_moe_to_hf(args, name, param):
    """Run the qwen3moe conversion and namespace each resulting HF name.

    Returns a list of ``(hf_name, tensor)`` pairs, in the order produced by
    ``convert_qwen3moe_to_hf``, with ``BODY_PREFIX`` prepended to every name.
    Tensors are passed through untouched.
    """
    namespaced = []
    for hf_name, tensor in convert_qwen3moe_to_hf(args, name, param):
        namespaced.append((f"{BODY_PREFIX}{hf_name}", tensor))
    return namespaced
"""Rollout adapter for the sglang-omni `/generate` endpoint (thinker / text).

Thin wrapper over the default single-turn generate: reuses miles' payload builder and
response parser, adds only the omni-specific fields. Wire via
`--custom-generate-function-path miles.rollout.generate_hub.omni_thinker.generate`.
"""

from miles.rollout.base_types import GenerateFnInput, GenerateFnOutput
from miles.rollout.generate_utils.generate_endpoint_utils import (
    compute_prompt_ids_from_sample,
    compute_request_payload,
    update_sample_from_response,
)
from miles.utils.http_utils import post
from miles.utils.types import Sample


def _check_omni_preconditions(args, sample):
    """Assert the configuration this text-only omni rollout path supports."""
    startable = {Sample.Status.PENDING, Sample.Status.ABORTED}
    assert sample.status in startable, f"{sample.status=}"
    # omni /generate emits temp-1 (pre-temperature) full-vocab logprobs; the trainer recompute
    # divides logits by rollout_temperature, so they agree only at temp=1.
    assert args.rollout_temperature == 1.0, (
        f"omni rollout logprob is temp-1; rollout_temperature must be 1.0, got {args.rollout_temperature}"
    )
    # text-only path with no MoE/indexer replay yet (server forward-declares both)
    assert not args.use_rollout_routing_replay and not args.use_rollout_indexer_replay, (
        "omni rollout has no routing/indexer replay; unset --use-rollout-routing-replay / --use-rollout-indexer-replay"
    )
    multimodal = sample.multimodal_inputs or {}
    assert not multimodal.get("images"), "omni thinker rollout is text-only"


async def generate(input: GenerateFnInput) -> GenerateFnOutput:
    """Generate one text response for ``input.sample`` via the omni server.

    Returns a ``GenerateFnOutput`` wrapping the same sample object, updated in
    place. Returns early, without contacting the server, when a resumed sample
    has no token budget left or when the payload builder yields no payload.

    Raises:
        AssertionError: if the sample status, rollout temperature, replay
            flags, or image inputs are outside what this path supports.
    """
    args = input.args
    sample = input.sample
    sampling_params = input.sampling_params

    _check_omni_preconditions(args, sample)
    endpoint = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

    prompt_ids = compute_prompt_ids_from_sample(input.state, sample)
    input_ids = prompt_ids
    if len(sample.response) > 0:  # partial rollout resume
        input_ids = sample.tokens
        # Tokens generated in earlier attempts count against the budget.
        sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids)
        remaining = sampling_params["max_new_tokens"]
        assert remaining >= 0
        if remaining == 0:
            sample.status = Sample.Status.TRUNCATED
            return GenerateFnOutput(samples=sample)

    payload, halt_status = compute_request_payload(
        args, input_ids=input_ids, sampling_params=sampling_params, multimodal_inputs=sample.multimodal_inputs
    )
    if payload is None:
        sample.status = halt_status
        return GenerateFnOutput(samples=sample)

    # omni-specific request fields
    payload["output_modalities"] = ["text"]
    payload["return_omni_rollout"] = False
    # rep_penalty=1: the trainer recompute can't replay a repetition penalty (logprobs would diverge)
    payload["sampling_params"]["repetition_penalty"] = 1.0
    if sample.metadata:
        payload["metadata"] = sample.metadata

    response = await post(endpoint, payload)
    await update_sample_from_response(args, sample, payload=payload, output=response)
    return GenerateFnOutput(samples=sample)
def test_body_converter_dispatch_routes_qwen3omni():
    """`_convert_to_hf_core` sends a qwen3omni model name to the body.* converter."""
    torch = pytest.importorskip("torch")
    to_hf = pytest.importorskip("miles.backends.megatron_utils.megatron_to_hf")
    converted = to_hf._convert_to_hf_core(
        _args(), "qwen3omni_moe", "module.module.decoder.layers.0.mlp.router.weight", torch.zeros(128, 8)
    )
    names = [hf_name for hf_name, _ in converted]
    assert names == ["body.model.layers.0.mlp.gate.weight"]


# --- rollout adapter (miles.rollout.generate_hub.omni_thinker.generate) ---


def _omni_mod():
    """Import the adapter module, skipping the test if it is unavailable."""
    return pytest.importorskip("miles.rollout.generate_hub.omni_thinker")


def _omni_input(monkeypatch, omni, *, replay=False, indexer_replay=False, multimodal=None, metadata=None, temperature=1.0):
    """A generate() input with post/prompt-id mocked; returns (input, sample, captured)."""
    sample_cls = pytest.importorskip("miles.utils.types").Sample
    captured = {}

    async def fake_post(url, payload):
        # Record the outgoing request and answer with a canned two-token completion.
        captured.update(url=url, payload=payload)
        meta_info = {
            "finish_reason": {"type": "stop"},
            "output_token_logprobs": [[-0.1, 11], [-0.2, 22]],
            "prompt_tokens": 3,
            "completion_tokens": 2,
            "cached_tokens": 0,
        }
        return {"text": "hello", "meta_info": meta_info}

    def fake_prompt_ids(state, sample):
        return [1, 2, 3]

    monkeypatch.setattr(omni, "post", fake_post)
    monkeypatch.setattr(omni, "compute_prompt_ids_from_sample", fake_prompt_ids)

    args = SimpleNamespace(
        sglang_router_ip="127.0.0.1",
        sglang_router_port=30000,
        rollout_temperature=temperature,
        use_rollout_routing_replay=replay,
        use_rollout_indexer_replay=indexer_replay,
        rollout_max_response_len=128,
        rollout_max_context_len=0,
        sglang_speculative_algorithm=None,
    )
    sample = sample_cls(status=sample_cls.Status.PENDING, metadata=metadata or {}, multimodal_inputs=multimodal)
    sampling_params = {"temperature": 1.0, "max_new_tokens": 64}
    inp = SimpleNamespace(args=args, sample=sample, sampling_params=sampling_params, state=None)
    return inp, sample, captured


def _assert_generate_rejects(monkeypatch, match, **overrides):
    """Run generate() with the given `_omni_input` overrides and expect an AssertionError."""
    omni = _omni_mod()
    inp, _, _ = _omni_input(monkeypatch, omni, **overrides)
    with pytest.raises(AssertionError, match=match):
        asyncio.run(omni.generate(inp))


def test_generate_builds_text_only_omni_payload_and_parses_response(monkeypatch):
    omni = _omni_mod()
    inp, sample, captured = _omni_input(monkeypatch, omni, metadata={"task": "math"})

    asyncio.run(omni.generate(inp))

    assert captured["url"].endswith("/generate")
    payload = captured["payload"]
    assert payload["input_ids"] == [1, 2, 3]
    assert payload["output_modalities"] == ["text"]
    assert payload["return_omni_rollout"] is False
    assert payload["return_routed_experts"] is False
    assert payload["return_indexer_topk"] is False
    assert payload["sampling_params"]["repetition_penalty"] == 1.0
    assert payload["metadata"] == {"task": "math"}
    # response parse: tokens = input_ids + decoded ids, logprobs aligned, status from finish_reason
    assert sample.tokens == [1, 2, 3, 11, 22]
    assert sample.rollout_log_probs == [-0.1, -0.2]
    assert sample.response == "hello"
    assert sample.status.name == "COMPLETED"


def test_generate_rejects_moe_replay_for_omni(monkeypatch):
    _assert_generate_rejects(monkeypatch, "replay", replay=True)


def test_generate_rejects_indexer_replay_for_omni(monkeypatch):
    _assert_generate_rejects(monkeypatch, "replay", indexer_replay=True)


def test_generate_rejects_multimodal_on_text_only_path(monkeypatch):
    _assert_generate_rejects(monkeypatch, "text-only", multimodal={"images": ["img"]})


def test_generate_rejects_nonunit_temperature(monkeypatch):
    # omni emits temp-1 logprobs; the trainer recompute divides by rollout_temperature, so temp!=1 desyncs
    _assert_generate_rejects(monkeypatch, "temp-1", temperature=2.0)
"""Extract a standalone Qwen3-MoE thinker checkpoint from a full Qwen3-Omni model.

miles loads HF via AutoBridge keyed on model_type; the composite omni checkpoint has
no bridge, but the thinker text backbone is a plain Qwen3-MoE. This writes a
self-contained HF dir (thinker text + lm_head, renamed, model_type=qwen3_moe) for the
existing Qwen3MoEBridge.

    python tools/extract_qwen3_omni_thinker.py --src SRC_DIR --dst DST_DIR
"""

from __future__ import annotations

_THINKER_PREFIX = "thinker."
_NON_TEXT_THINKER_SUBMODULES = ("audio_tower.", "visual.", "model.audio_tower.", "model.visual.")
# Auxiliary files copied verbatim from the source directory when present.
_AUX_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
    "merges.txt",
    "special_tokens_map.json",
    "generation_config.json",
    "chat_template.json",
    "chat_template.jinja",
)


def map_thinker_param_name(name: str) -> str | None:
    """Full omni param name -> standalone Qwen3-MoE name, or None to drop.

    A parameter is kept only if it lives under ``thinker.`` and is not part of
    the audio/vision towers. The ``thinker.`` prefix is stripped and
    ``model.language_model.`` is collapsed to ``model.``.
    """
    if not name.startswith(_THINKER_PREFIX):
        return None
    rest = name[len(_THINKER_PREFIX):]
    if rest.startswith(_NON_TEXT_THINKER_SUBMODULES):
        return None
    return rest.replace("model.language_model.", "model.")


def synthesize_thinker_config(omni_config: dict) -> dict:
    """Plain Qwen3-MoE config dict from the full omni config.

    Uses ``thinker_config.text_config`` when those keys exist, falling back to
    the enclosing dict at each level. The input is not modified: the result is
    a shallow copy with ``model_type``/``architectures`` overridden,
    ``tie_word_embeddings`` defaulted to False, and any special-token ids
    missing from the text config filled in from the thinker config.
    """
    thinker = omni_config.get("thinker_config", omni_config)
    cfg = dict(thinker.get("text_config", thinker))
    cfg["model_type"] = "qwen3_moe"
    cfg["architectures"] = ["Qwen3MoeForCausalLM"]
    cfg.setdefault("tie_word_embeddings", False)
    for key in ("bos_token_id", "eos_token_id", "pad_token_id"):
        if key not in cfg and key in thinker:
            cfg[key] = thinker[key]
    return cfg


def main() -> None:
    """CLI entry point: write the thinker-only checkpoint from --src into --dst.

    Output shards are written as soon as they reach ``--shard-size-gb`` so peak
    memory is bounded by one shard rather than the whole thinker. If everything
    fits in one shard it is published as ``model.safetensors``; otherwise the
    numbered shards are kept and ``model.safetensors.index.json`` is written.

    Raises:
        SystemExit: if --src and --dst are the same directory, or if no
            ``thinker.*`` tensor is found in the source checkpoint.
    """
    import argparse
    import json
    import shutil
    from pathlib import Path

    # Imported lazily so the pure helpers above stay importable without safetensors.
    from safetensors import safe_open
    from safetensors.torch import save_file

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--src", required=True)
    parser.add_argument("--dst", required=True)
    parser.add_argument("--shard-size-gb", type=float, default=5.0)
    args = parser.parse_args()

    src, dst = Path(args.src), Path(args.dst)
    # Shards are written while the source is still being read, so extracting
    # in place could overwrite inputs (and would clobber the omni config.json).
    if src.resolve() == dst.resolve():
        parser.error("--src and --dst must be different directories")
    dst.mkdir(parents=True, exist_ok=True)

    with open(src / "config.json", encoding="utf-8") as f:
        omni_config = json.load(f)
    with open(dst / "config.json", "w", encoding="utf-8") as f:
        json.dump(synthesize_thinker_config(omni_config), f, indent=2)

    for fname in _AUX_FILES:
        if (src / fname).exists():
            shutil.copy2(src / fname, dst / fname)

    index_path = src / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            shard_files = sorted(set(json.load(f)["weight_map"].values()))
    else:
        shard_files = ["model.safetensors"]

    shard_size_bytes = int(args.shard_size_gb * (1024**3))
    weight_map: dict[str, str] = {}
    written_shards: list[str] = []
    pending: dict = {}
    pending_bytes = 0
    total_bytes = 0
    total_kept = 0

    def flush() -> None:
        """Write the pending tensors as the next numbered shard and release them."""
        nonlocal pending, pending_bytes
        if not pending:
            return
        shard_name = f"model-{len(written_shards) + 1:05d}.safetensors"
        save_file(pending, dst / shard_name, metadata={"format": "pt"})
        written_shards.append(shard_name)
        weight_map.update(dict.fromkeys(pending, shard_name))
        pending, pending_bytes = {}, 0

    for shard_file in shard_files:
        with safe_open(src / shard_file, framework="pt") as reader:
            for key in reader.keys():
                new_key = map_thinker_param_name(key)
                if new_key is None:
                    continue
                tensor = reader.get_tensor(key)
                nbytes = tensor.numel() * tensor.element_size()
                pending[new_key] = tensor
                pending_bytes += nbytes
                total_bytes += nbytes
                total_kept += 1
                if pending_bytes >= shard_size_bytes:
                    flush()
    flush()

    if not written_shards:
        raise SystemExit(f"[error] no `{_THINKER_PREFIX}*` tensors found under {src}; nothing extracted")

    if len(written_shards) == 1:
        # Single-shard checkpoints use the un-numbered file name and need no index.
        (dst / written_shards[0]).replace(dst / "model.safetensors")
    else:
        with open(dst / "model.safetensors.index.json", "w", encoding="utf-8") as f:
            json.dump({"metadata": {"total_size": total_bytes}, "weight_map": weight_map}, f, indent=2)

    print(f"[done] {total_kept} thinker tensors -> {dst}")


if __name__ == "__main__":
    main()