diff --git a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py index 84cfcca9e..f1af52e85 100644 --- a/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/QEfficient/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -761,7 +761,7 @@ def _forward_expert_blocked( packed_chunk_size=self.expert_blocking_packed_chunk_size, ) - return expert_out.sum(dim=0) + return torch.einsum("nth->th", expert_out) def forward(self, hidden_states): residuals = hidden_states diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index b7f42c0c5..66a094544 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -200,7 +200,8 @@ def forward(self, hidden: torch.Tensor): packed_chunk_size=packed_chunk_size, ) - return expert_out.sum(dim=0).view(B, S, H), router_logits + expert_out_sum = torch.einsum("nth->th", expert_out) + return expert_out_sum.view(B, S, H), router_logits class QEffPrefillOnlyGptOssMLP(GptOssMLP): @@ -460,7 +461,7 @@ def forward_weights_as_activation(self, hidden_states): # Apply routing weights AFTER expert computation (This is before on Llama4) experts_out = experts_out * router_top_value.unsqueeze(-1) - experts_out = experts_out.sum(dim=1) + experts_out = torch.einsum("bnd->bd", experts_out) return experts_out, router_logits diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 8d67c2863..93ce6a685 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -260,7 +260,8 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens act_fn=act_fn, packed_chunk_size=packed_chunk_size, ) - return expert_out.sum(dim=0).view(B, S, H), router_logits + expert_out_sum = torch.einsum("nth->th", expert_out) + return expert_out_sum.view(B, S, H), router_logits class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): diff --git a/examples/disagg_serving/glm4_moe_disagg_mode_with_chunking.py b/examples/disagg_serving/glm4_moe_disagg_mode_with_chunking.py index 8f38f1f59..13df887b6 100644 --- a/examples/disagg_serving/glm4_moe_disagg_mode_with_chunking.py +++ b/examples/disagg_serving/glm4_moe_disagg_mode_with_chunking.py @@ -50,7 +50,7 @@ mxfp6_matmul=True, mxint8_kv_cache=True, num_devices=1, - split_retained_state_io=True, + # split_retained_state_io=True, mos=1, aic_enable_depth_first=False, user_tiled=True, diff --git a/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py b/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py index 48de31241..92adbc69d 100644 --- a/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py +++ b/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py @@ -19,7 +19,7 @@ subfunc_npi_file_path = os.path.join(dir_path, "subfunction_120b_npi.yaml") non_subfunc_npi_file_path = os.path.join(dir_path, "non_subfunction_120b_npi.yaml") -model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 +model_id = "tiny-random/gpt-oss-bf16" # weights are not required to convert to fp32 prompt = """ Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. @@ -32,8 +32,8 @@ config = AutoConfig.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) PREFILL_SEQ_LEN = 512 -CTX_LEN = 8192 -NUM_CORES = 16 +CTX_LEN = 1024 +NUM_CORES = 4 MOE_PREFILL_PACKED_CHUNK_SIZE = 256 qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) diff --git a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py index a1d0cdbf6..2a51e89f8 100644 --- a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py +++ b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py @@ -60,7 +60,7 @@ num_speculative_tokens=None, prefill_only=True, enable_chunking=True, - use_onnx_subfunctions=False, + use_onnx_subfunctions=True, ) diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 4b7ed6f17..8711878f0 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -105,6 +105,51 @@ TINY_SEQ_CLASSIFICATION_MODEL_ID = "ydshieh/tiny-random-BertForSequenceClassification" TINY_AWQ_MODEL_ID = "optimum-intel-internal-testing/tiny-mixtral-AWQ-4bit" +TINY_MOE_PREFILL_SUBFUNCTION_CONFIGS = { + "glm4_moe": dict( + max_position_embeddings=128, + num_hidden_layers=1, + num_attention_heads=4, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + vocab_size=127, + num_key_value_heads=2, + n_routed_experts=4, + num_experts_per_tok=2, + first_k_dense_replace=0, + n_group=1, + topk_group=1, + head_dim=16, + ), + "qwen3_moe": dict( + max_position_embeddings=128, + num_hidden_layers=1, + num_attention_heads=4, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + vocab_size=127, + num_key_value_heads=2, + num_experts=4, + num_experts_per_tok=2, + ), + "gpt_oss": dict( + max_position_embeddings=128, + num_hidden_layers=2, + num_attention_heads=2, + hidden_size=32, + intermediate_size=32, + vocab_size=127, + num_key_value_heads=2, + num_local_experts=4, + num_experts_per_tok=2, + head_dim=16, + sliding_window=128, + rope_scaling=None, + ), +} + MODEL_KWARGS = {"attn_implementation": "eager"} PREFIX_CACHING_MODEL_ID = "hf-internal-testing/tiny-random-GPT2LMHeadModel" @@ -202,6 +247,30 @@ def _count_decoder_block_subfunctions(onnx_model, qeff_model) -> int: return sum(any(block_name in func.name for block_name in block_names) for func in onnx_model.functions) +def _decoder_block_subfunction_names(onnx_model, qeff_model) -> Set[str]: + get_submodules = getattr(qeff_model.model, "get_submodules_for_export", None) + assert callable(get_submodules) + + submodules = get_submodules() + assert submodules + + if not isinstance(submodules, (set, list, tuple)): + submodules = [submodules] + + block_names = {module.__name__ for module in submodules if hasattr(module, "__name__")} + assert block_names + return {func.name for func in onnx_model.functions if any(block_name in func.name for block_name in block_names)} + + +def _function_op_types(onnx_model, function_names: Set[str]) -> Set[str]: + return { + node.op_type + for function_proto in onnx_model.functions + if function_proto.name in function_names + for node in function_proto.node + } + + def _assert_has_retained_state_outputs(onnx_path: Path) -> None: onnx_model = onnx.load(onnx_path, load_external_data=False) retained_outputs = [output.name for output in onnx_model.graph.output if output.name.endswith("_RetainedState")] @@ -509,6 +578,42 @@ def test_causal_subfunction_export_smoke(tmp_path): assert not any("QEffGPT2Block" in name for name in without_names) +@pytest.mark.llm_model +@pytest.mark.parametrize( + ("model_type", "config_kwargs"), + sorted(TINY_MOE_PREFILL_SUBFUNCTION_CONFIGS.items()), + ids=sorted(TINY_MOE_PREFILL_SUBFUNCTION_CONFIGS), +) +def test_moe_prefill_subfunction_export_uses_einsum_reductions(model_type, config_kwargs, tmp_path): + config = AutoConfig.for_model(model_type, **config_kwargs) + model_hf = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + model_hf.eval() + qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=False) + + onnx_path = _exported_onnx_path( + qeff_model.export( + tmp_path / f"{model_type}-prefill-subfunctions", + prefill_only=True, + enable_chunking=True, + prefill_seq_len=64, + num_cores=2, + moe_prefill_packed_chunk_size=32, + use_onnx_subfunctions=True, + offload_pt_weights=False, + ) + ) + + onnx_model = onnx.load(onnx_path, load_external_data=False) + decoder_function_names = _decoder_block_subfunction_names(onnx_model, qeff_model) + decoder_op_types = _function_op_types(onnx_model, decoder_function_names) + + assert len(decoder_function_names) == config.num_hidden_layers + assert "Einsum" in decoder_op_types + assert "CtxGather3D" in decoder_op_types + assert "CtxScatter3D" in decoder_op_types + assert "CtxScatter3DInt" in decoder_op_types + + @pytest.mark.llm_model @pytest.mark.parametrize( ("model_type", "model_id"),