From 80142ca6b0bb7d8da531f07f34dbb85ecccdb5fb Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 22 Apr 2026 02:24:04 +0530 Subject: [PATCH 01/16] feat: NSP-blocked MoE prefill dispatch for Qwen3MOE and GPT-OSS Add expert-blocked NSP-parallel prefill forward to QEffPrefillChunkedQwen3MoeSparseMoeBlock and QEffPrefillOnlyChunkedGptOssMLP. Controlled via EXPERT_BLOCKING_NUM_NSP env var. Fix CtxScatterFunc3D/CtxGatherFunc3D eager forward for INT32_MAX sentinel handling. Add disagg-mode tests for both models with tiny configs. Signed-off-by: vbaddi --- QEfficient/customop/ctx_scatter_gather.py | 4 +- .../models/gpt_oss/modeling_gpt_oss.py | 89 ++++++++++++ .../models/qwen3_moe/modeling_qwen3_moe.py | 129 +++++++++++++++-- .../models/test_moe_prefill_blocked.py | 132 ++++++++++++++++++ 4 files changed, 340 insertions(+), 14 deletions(-) create mode 100644 tests/transformers/models/test_moe_prefill_blocked.py diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 59bfe6af03..bc87757079 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -78,8 +78,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates class CtxScatterFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() batch_idx = torch.arange(data.shape[0]).view(-1, 1) - ctx_idx = position_ids + ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids) data[batch_idx, ctx_idx] = updates return data @@ -103,6 +104,7 @@ class CtxGatherFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, ctx_indices] @staticmethod diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6f805bfd4c..5707dc2091 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -36,6 +36,7 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -50,7 +51,91 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) + + +def _ctx_scatter_gather_gptoss_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + b_g: torch.Tensor, + b_u: torch.Tensor, + b_d: torch.Tensor, + limit: float, + alpha: float, + T: int, +) -> torch.Tensor: + """Packed-prefix expert helper for GPT-OSS NSP-blocked dispatch.""" + batch_size, hidden_size = T2Ei.shape[0], x.shape[1] + scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) + invalid_mask = ~T2Ei + INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) + scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + + x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) + x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) + valid_output_rows = row_range < valid_rows + x_prime = torch.where(valid_output_rows.unsqueeze(-1), x_prime, torch.zeros_like(x_prime)) + + gate = (x_prime @ W_g) + b_g.unsqueeze(1) + up = (x_prime @ W_u) + b_u.unsqueeze(1) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + intermediate = (up + 1) * glu + down_prime = (intermediate @ W_d) + b_d.unsqueeze(1) + down_prime = torch.where(valid_output_rows.unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + + gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) + delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) + return delta_out + + class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + def __qeff_init__(self): + pass + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = EXPERT_BLOCKING_NUM_NSP + E = self.experts.num_experts + if E % num_nsp != 0: + raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") + local_experts = E // num_nsp + I = self.experts.gate_proj.shape[2] # noqa: E741 + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :].unsqueeze(-1) + T2Ei = routing_weight.squeeze(-1) > 0 + delta = _ctx_scatter_gather_gptoss_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + b_g=b_g[:, slot], + b_u=b_u[:, slot], + b_d=b_d[:, slot], + limit=self.experts.limit, + alpha=self.experts.alpha, + T=T, + ) + expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) + def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S @@ -69,6 +154,10 @@ def forward(self, hidden: torch.Tensor): # Routing weights for each expert [T, E] routing_weights = masked_logits + if self.experts.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + expert_out = self._forward_expert_blocked(x=hidden, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + # ────────────────── allocate the output tensor ───── expert_out = hidden.new_zeros((T, H)) # accumulation buffer diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index de92eae8f7..415c1c396f 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import os from typing import List, Optional, Tuple, Type import torch @@ -32,6 +33,7 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -100,8 +102,85 @@ def eager_attention_forward( return attn_output, attn_weights +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) + + +def _ctx_scatter_gather_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + act_fn, + T: int, +) -> torch.Tensor: + """Packed-prefix expert helper for NSP-blocked dispatch.""" + batch_size, hidden_size = T2Ei.shape[0], x.shape[1] + scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) + invalid_mask = ~T2Ei + INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) + scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + + x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) + x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + + gate_prime = x_prime @ W_g + up_prime = x_prime @ W_u + down_prime = (up_prime * act_fn(gate_prime)) @ W_d + + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) + down_prime = torch.where((row_range < valid_rows).unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + + gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) + delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) + return delta_out + + class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) + self.up_proj_w = torch.stack(self.up_proj_w) + self.down_proj_w = torch.stack(self.down_proj_w) + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = EXPERT_BLOCKING_NUM_NSP + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :].unsqueeze(-1) + T2Ei = routing_weight.squeeze(-1) > 0 + delta = _ctx_scatter_gather_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + act_fn=self.experts[0].act_fn, + T=T, + ) + expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) + + def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) @@ -113,20 +192,44 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = top_w.to(hidden_states.dtype) masked_logits = torch.zeros_like(router_logits) masked_logits.scatter_(1, top_i, top_w) - # Routing weights for each expert [T, E] routing_weights = masked_logits - # ────────────────── allocate the output tensor ───── - expert_out = x.new_zeros((T, H)) # accumulation buffer - # ───────────────────────── Expert computation loop ───────────────────────────── + expert_out = x.new_zeros((T, H)) + for e in range(self.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight + return expert_out.view(B, S, H), router_logits + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = self.gate(x) + prob = F.softmax(router_logits, -1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, -1) + if self.norm_topk_prob: + top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) + + if self.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + + expert_out = x.new_zeros((T, H)) for e in range(self.num_experts): - routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I] - W_d = self.experts[e].down_proj.weight.T # [I, H] - gate = x @ W_g # [T, I] - up = x @ W_u # [T, I] - down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H] - masked_down = down * routing_weight - expert_out += masked_down + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight return expert_out.view(B, S, H), router_logits diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py new file mode 100644 index 0000000000..ca22975433 --- /dev/null +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -0,0 +1,132 @@ +""" +Tests for NSP-blocked MoE prefill dispatch (Qwen3MOE + GPT-OSS). +Uses EXPERT_BLOCKING_NUM_NSP=2 so tests run fast on any num_experts. +Covers: parity, decode export, prefill+chunking export (disagg mode). +""" + +import os + +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +os.environ.setdefault("EXPERT_BLOCKING_NUM_NSP", "2") + +from QEfficient import QEFFAutoModelForCausalLM + +MODEL_KWARGS = {"attn_implementation": "eager"} + +QWEN3_MOE_CFG = dict( + max_position_embeddings=256, + num_hidden_layers=2, + num_attention_heads=4, + hidden_size=128, + intermediate_size=512, + vocab_size=127, + num_key_value_heads=2, +) +GPTOSS_CFG = dict( + max_position_embeddings=256, + num_hidden_layers=2, + num_attention_heads=2, + hidden_size=32, + intermediate_size=32, + vocab_size=127, + num_key_value_heads=2, +) + + +# ── Qwen3MOE ────────────────────────────────────────────────────────────────── + + +def test_qwen3moe_blocked_forward_parity(): + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( + QEffPrefillChunkedQwen3MoeSparseMoeBlock, + ) + + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + + blocks = [ + m + for _, m in model.named_modules() + if hasattr(m, "experts") and hasattr(m, "gate") and hasattr(m, "num_experts") + ] + assert blocks + + block = blocks[0] + chunked = QEffPrefillChunkedQwen3MoeSparseMoeBlock.__new__(QEffPrefillChunkedQwen3MoeSparseMoeBlock) + chunked.__dict__.update(block.__dict__) + chunked.__class__ = QEffPrefillChunkedQwen3MoeSparseMoeBlock + chunked.__qeff_init__() + + x = torch.randn(1, 8, config.hidden_size) + with torch.no_grad(): + orig, _ = chunked.orig_forward(x) + blocked, _ = chunked.forward(x) + + assert orig.shape == blocked.shape + assert (orig - blocked).abs().max().item() < 0.1, "Qwen3MOE parity failed" + + +def test_qwen3moe_decode_export(tmp_path): + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "decode") + assert qeff.onnx_path.is_file() + + +def test_qwen3moe_prefill_chunked_export(tmp_path): + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) + assert qeff.onnx_path.is_file() + + +# ── GPT-OSS ─────────────────────────────────────────────────────────────────── + + +def test_gptoss_blocked_forward_parity(): + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffPrefillOnlyChunkedGptOssMLP, + ) + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + + blocks_orig = [m for _, m in model.named_modules() if m.__class__.__name__ == "GptOssMLP"] + assert blocks_orig + + x = torch.randn(1, 8, config.hidden_size) + with torch.no_grad(): + orig, _ = blocks_orig[0].forward(x) + + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + PrefillOnlyChunkedTransform.apply(qeff.model) + + blocks_chunked = [m for _, m in qeff.model.named_modules() if isinstance(m, QEffPrefillOnlyChunkedGptOssMLP)] + assert blocks_chunked + + with torch.no_grad(): + blocked, _ = blocks_chunked[0].forward(x) + + assert orig.shape == blocked.shape + assert (orig - blocked).abs().max().item() < 0.1, "GPT-OSS parity failed" + + +def test_gptoss_decode_export(tmp_path): + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "decode") + assert qeff.onnx_path.is_file() + + +def test_gptoss_prefill_chunked_export(tmp_path): + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) + assert qeff.onnx_path.is_file() From a5bd93a48ba2bc55fad8538dfb6cae4b34fe9070 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 22 Apr 2026 03:09:20 +0530 Subject: [PATCH 02/16] nit: weights re-route fixes Signed-off-by: vbaddi --- .../models/gpt_oss/modeling_gpt_oss.py | 39 +++++++++------- .../models/qwen3_moe/modeling_qwen3_moe.py | 46 +++++++++++-------- .../models/test_moe_prefill_blocked.py | 7 +++ 3 files changed, 55 insertions(+), 37 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5707dc2091..5e4d1547aa 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -99,23 +99,28 @@ def _ctx_scatter_gather_gptoss_expert_blocked( class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): def __qeff_init__(self): - pass - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP E = self.experts.num_experts if E % num_nsp != 0: raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") local_experts = E // num_nsp - I = self.experts.gate_proj.shape[2] # noqa: E741 + H = self.experts.hidden_size + I = self.experts.expert_dim # noqa: E741 + with torch.no_grad(): + self._blocked_W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + self._blocked_b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + self._blocked_b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + self._blocked_b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + self._blocked_num_nsp = num_nsp + self._blocked_local_experts = local_experts + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = self._blocked_num_nsp + local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) @@ -123,12 +128,12 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_gptoss_expert_blocked( x=x, T2Ei=T2Ei, - W_g=W_g[:, slot], - W_u=W_u[:, slot], - W_d=W_d[:, slot], - b_g=b_g[:, slot], - b_u=b_u[:, slot], - b_d=b_d[:, slot], + W_g=self._blocked_W_g[:, slot], + W_u=self._blocked_W_u[:, slot], + W_d=self._blocked_W_d[:, slot], + b_g=self._blocked_b_g[:, slot], + b_u=self._blocked_b_u[:, slot], + b_d=self._blocked_b_d[:, slot], limit=self.experts.limit, alpha=self.experts.alpha, T=T, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 415c1c396f..13122dfb19 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -140,30 +140,36 @@ def _ctx_scatter_gather_expert_blocked( class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): def __qeff_init__(self): - self.gate_proj_w = [] - self.up_proj_w = [] - self.down_proj_w = [] - with torch.no_grad(): - for e in range(self.num_experts): - self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) - self.up_proj_w.append(self.experts[e].up_proj.weight.T) - self.down_proj_w.append(self.experts[e].down_proj.weight.T) - self.gate_proj_w = torch.stack(self.gate_proj_w) - self.up_proj_w = torch.stack(self.up_proj_w) - self.down_proj_w = torch.stack(self.down_proj_w) - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP if self.num_experts % num_nsp != 0: raise ValueError( f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" ) local_experts = self.num_experts // num_nsp + gate_proj_w = [] + up_proj_w = [] + down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + gate_proj_w.append(self.experts[e].gate_proj.weight.T) + up_proj_w.append(self.experts[e].up_proj.weight.T) + down_proj_w.append(self.experts[e].down_proj.weight.T) + stacked_g = torch.stack(gate_proj_w) # [E, H, I] + stacked_u = torch.stack(up_proj_w) + stacked_d = torch.stack(down_proj_w) # [E, I, H] + H = stacked_g.shape[1] + I = stacked_g.shape[2] # noqa: E741 + self._blocked_W_g = stacked_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_u = stacked_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_d = stacked_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + self._blocked_num_nsp = num_nsp + self._blocked_local_experts = local_experts + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = self._blocked_num_nsp + local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) @@ -171,9 +177,9 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_expert_blocked( x=x, T2Ei=T2Ei, - W_g=W_g[:, slot], - W_u=W_u[:, slot], - W_d=W_d[:, slot], + W_g=self._blocked_W_g[:, slot], + W_u=self._blocked_W_u[:, slot], + W_d=self._blocked_W_d[:, slot], act_fn=self.experts[0].act_fn, T=T, ) diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py index ca22975433..f7789707e8 100644 --- a/tests/transformers/models/test_moe_prefill_blocked.py +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + """ Tests for NSP-blocked MoE prefill dispatch (Qwen3MOE + GPT-OSS). Uses EXPERT_BLOCKING_NUM_NSP=2 so tests run fast on any num_experts. From c4ef4c847b37ee77cee8453bb4cbdbebc7149754 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 22 Apr 2026 03:21:45 +0530 Subject: [PATCH 03/16] nit: weights re-route fixes v1 Signed-off-by: vbaddi --- .../models/gpt_oss/modeling_gpt_oss.py | 36 ++++++--------- .../models/qwen3_moe/modeling_qwen3_moe.py | 46 ++++++++----------- 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5e4d1547aa..84eb4acac7 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -98,29 +98,21 @@ def _ctx_scatter_gather_gptoss_expert_blocked( class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): - def __qeff_init__(self): + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP E = self.experts.num_experts if E % num_nsp != 0: raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") local_experts = E // num_nsp - H = self.experts.hidden_size I = self.experts.expert_dim # noqa: E741 - with torch.no_grad(): - self._blocked_W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - self._blocked_b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - self._blocked_b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - self._blocked_b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() - self._blocked_num_nsp = num_nsp - self._blocked_local_experts = local_experts - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape - num_nsp = self._blocked_num_nsp - local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) @@ -128,12 +120,12 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_gptoss_expert_blocked( x=x, T2Ei=T2Ei, - W_g=self._blocked_W_g[:, slot], - W_u=self._blocked_W_u[:, slot], - W_d=self._blocked_W_d[:, slot], - b_g=self._blocked_b_g[:, slot], - b_u=self._blocked_b_u[:, slot], - b_d=self._blocked_b_d[:, slot], + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + b_g=b_g[:, slot], + b_u=b_u[:, slot], + b_d=b_d[:, slot], limit=self.experts.limit, alpha=self.experts.alpha, T=T, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 13122dfb19..e233e0e837 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -140,36 +140,30 @@ def _ctx_scatter_gather_expert_blocked( class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] + self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] + self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP if self.num_experts % num_nsp != 0: raise ValueError( f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" ) local_experts = self.num_experts // num_nsp - gate_proj_w = [] - up_proj_w = [] - down_proj_w = [] - with torch.no_grad(): - for e in range(self.num_experts): - gate_proj_w.append(self.experts[e].gate_proj.weight.T) - up_proj_w.append(self.experts[e].up_proj.weight.T) - down_proj_w.append(self.experts[e].down_proj.weight.T) - stacked_g = torch.stack(gate_proj_w) # [E, H, I] - stacked_u = torch.stack(up_proj_w) - stacked_d = torch.stack(down_proj_w) # [E, I, H] - H = stacked_g.shape[1] - I = stacked_g.shape[2] # noqa: E741 - self._blocked_W_g = stacked_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_u = stacked_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_d = stacked_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - self._blocked_num_nsp = num_nsp - self._blocked_local_experts = local_experts - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape - num_nsp = self._blocked_num_nsp - local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) @@ -177,9 +171,9 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_expert_blocked( x=x, T2Ei=T2Ei, - W_g=self._blocked_W_g[:, slot], - W_u=self._blocked_W_u[:, slot], - W_d=self._blocked_W_d[:, slot], + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], act_fn=self.experts[0].act_fn, T=T, ) From 290839e6855a52ba4f921eeff734b95615c11dfe Mon Sep 17 00:00:00 2001 From: vbaddi Date: Thu, 23 Apr 2026 22:25:50 +0530 Subject: [PATCH 04/16] nit(0423): gpt oss moe fixed and nit Signed-off-by: vbaddi --- QEfficient/base/modeling_qeff.py | 4 +- QEfficient/transformers/modeling_utils.py | 2 +- .../models/gpt_oss/modeling_gpt_oss.py | 50 ++++++++++--------- .../qwen3moe_disagg_mode_with_chunking.py | 9 ++-- 4 files changed, 35 insertions(+), 30 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e9213761d9..a091de7497 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -530,8 +530,8 @@ def _compile( onnx_path = Path( onnx_path if onnx_path - else self.onnx_path - if self.onnx_path + # else self.onnx_path + # if self.onnx_path else self.get_onnx_path( prefill_only, enable_chunking, diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index f9d7fe62cd..183c19f6f6 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -196,7 +196,7 @@ DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # This is for supporting different modelling classes specially written for prefill-only model -SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "kimi_k2", "kimi_k25"} +SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "kimi_k2", "kimi_k25"} _PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 84eb4acac7..1248e20bab 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -67,12 +67,11 @@ def _ctx_scatter_gather_gptoss_expert_blocked( alpha: float, T: int, ) -> torch.Tensor: - """Packed-prefix expert helper for GPT-OSS NSP-blocked dispatch.""" batch_size, hidden_size = T2Ei.shape[0], x.shape[1] scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) invalid_mask = ~T2Ei - INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) - scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + int32_max = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) + scatter_safe_idx = torch.where(invalid_mask, int32_max, scatter_idx) x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) @@ -91,7 +90,7 @@ def _ctx_scatter_gather_gptoss_expert_blocked( down_prime = (intermediate @ W_d) + b_d.unsqueeze(1) down_prime = torch.where(valid_output_rows.unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) - gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + gather_idx = torch.where(invalid_mask, int32_max, scatter_idx) delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) return delta_out @@ -101,36 +100,41 @@ class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP - E = self.experts.num_experts - if E % num_nsp != 0: - raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") - local_experts = E // num_nsp - I = self.experts.expert_dim # noqa: E741 - rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + num_experts = self.experts.num_experts + if num_experts % num_nsp != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") + + local_experts = num_experts // num_nsp + expert_dim = self.experts.expert_dim + routing_weights_by_expert = ( + routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + ) + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous() + W_d = self.experts.down_proj.view(local_experts, num_nsp, expert_dim, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) - for slot in range(local_experts): - routing_weight = rw[:, slot, :].unsqueeze(-1) + for local_slot in range(local_experts): + routing_weight = routing_weights_by_expert[:, local_slot, :].unsqueeze(-1) T2Ei = routing_weight.squeeze(-1) > 0 delta = _ctx_scatter_gather_gptoss_expert_blocked( x=x, T2Ei=T2Ei, - W_g=W_g[:, slot], - W_u=W_u[:, slot], - W_d=W_d[:, slot], - b_g=b_g[:, slot], - b_u=b_u[:, slot], - b_d=b_d[:, slot], + W_g=W_g[:, local_slot], + W_u=W_u[:, local_slot], + W_d=W_d[:, local_slot], + b_g=b_g[:, local_slot], + b_u=b_u[:, local_slot], + b_d=b_d[:, local_slot], limit=self.experts.limit, alpha=self.experts.alpha, T=T, ) expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) def forward(self, hidden: torch.Tensor): diff --git a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py index 655de4ef51..3bc9339091 100644 --- a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py +++ b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py @@ -14,14 +14,15 @@ from QEfficient import QEFFAutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession -model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 +# model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 +model_id = "yujiepan/qwen3-moe-tiny-random" prompt = """ Explain quantum computing in simple terms. """ config = AutoConfig.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) -PREFILL_SEQ_LEN = 128 -CTX_LEN = 128 * 3 +PREFILL_SEQ_LEN = 256 +CTX_LEN = PREFILL_SEQ_LEN * 3 qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) decode_qpc_path = qeff_model.compile( @@ -48,7 +49,7 @@ num_cores=16, mxfp6_matmul=True, mxint8_kv_cache=True, - num_devices=2, + num_devices=1, split_retained_state_io=True, mos=1, aic_enable_depth_first=True, From 28048519ae21c8fa106af11562a82322acda91aa Mon Sep 17 00:00:00 2001 From: vbaddi Date: Fri, 24 Apr 2026 19:27:10 +0530 Subject: [PATCH 05/16] nit(0424): ctx batch idx cast to int32 Signed-off-by: vbaddi --- QEfficient/customop/ctx_scatter_gather.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index bc87757079..4f46791af8 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates # Create indices batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + + # keep index tensor types aligned for backend that require exact dtype match + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) indices = ops.Concat(batch_idx, ctx_idx, axis=2) From 6b049bcd883c982a8638838194a41155fde3e716 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Thu, 30 Apr 2026 07:11:10 +0530 Subject: [PATCH 06/16] nit(0429): qwen3_moe, gpt_oss: port cumsum scatter-gather-update MoE prefill Signed-off-by: vbaddi --- QEfficient/customop/__init__.py | 6 + QEfficient/customop/ctx_scatter_gather.py | 93 +++++++++++++ .../models/gpt_oss/modeling_gpt_oss.py | 122 +++++++++++++----- .../models/qwen3_moe/modeling_qwen3_moe.py | 112 ++++++++++++---- 4 files changed, 276 insertions(+), 57 deletions(-) diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index 35830aa91e..4830e660c3 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -8,9 +8,12 @@ from QEfficient.customop.ctx_scatter_gather import ( CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFunc3DGeneralized, CtxGatherFuncBlockedKV, CtxScatterFunc, CtxScatterFunc3D, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherFuncBlockedKVCB, @@ -26,7 +29,10 @@ "CtxGatherFuncBlockedKV", "CtxScatterFunc", "CtxGatherFunc3D", + "CtxGatherFunc3DGeneralized", "CtxScatterFunc3D", + "CtxScatterFunc3DGeneralized", + "CtxScatterFunc3DInt", "CustomRMSNormAIC", "GemmaCustomRMSNormAIC", "CtxGatherFuncCB", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 4f46791af8..19f60886de 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -96,6 +96,74 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) +class CtxScatterFunc3DGeneralized(torch.autograd.Function): + """Scatter variant that preserves ``data`` at invalid (INT32_MAX) positions. + + Unlike :class:`CtxScatterFunc3D`, which writes updates for invalid rows to + ``data.shape[1]-1`` (potentially clobbering valid content), this version + masks out invalid rows before scattering so ``data`` is left untouched where + ``position_ids == INT32_MAX``. + """ + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatter3DInt( + data: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.INT32 +) -> onnxscript.INT32: + # Find dims + batch_size = ops.Gather(ops.Shape(data), [0]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, seq_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) + indices = ops.Concat(batch_idx, ctx_idx, axis=2) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFunc3DInt(torch.autograd.Function): + """Int32-typed scatter used to build a packed->original index table.""" + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3DInt, data, position_ids, updates).setTypeAs(data) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) @@ -119,6 +187,31 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor return g.onnxscript_op(CtxGather3D, data, ctx_indices).setTypeAs(data) +class CtxGatherFunc3DGeneralized(torch.autograd.Function): + """Gather variant that tolerates INT32_MAX indices (invalid rows read from 0). + + Semantically equivalent to :class:`CtxGatherFunc3D` on the PyTorch side but + exposed as a separate autograd op so callers using the packed/cumsum scatter + pipeline can be easily recognized and so the ONNX symbolic omits + ``setTypeAs`` (needed when the caller already has a matching dtype on + ``data`` and wants the op signature to flow through without dtype pinning). + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) + return data[batch_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGather3D, data, ctx_indices) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather( data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 1248e20bab..5e0270b7be 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -36,7 +36,11 @@ generic_blocked_attention_interface, past_key_value_update, ) -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -52,9 +56,36 @@ def __qeff_init__(self): EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table for an NSP-sliced expert mask. -def _ctx_scatter_gather_gptoss_expert_blocked( + Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to + an expert, produces an index tensor where ``matched_idx[b, j]`` is the + original token position in ``x`` that lands at packed position ``j`` for + NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). + """ + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this + # can be switched back to ``torch.full_like(token_idx, int32_max)``. + matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_gptoss_expert_blocked( x: torch.Tensor, T2Ei: torch.Tensor, W_g: torch.Tensor, @@ -63,37 +94,64 @@ def _ctx_scatter_gather_gptoss_expert_blocked( b_g: torch.Tensor, b_u: torch.Tensor, b_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, limit: float, alpha: float, T: int, + packed_chunk_size: int, ) -> torch.Tensor: - batch_size, hidden_size = T2Ei.shape[0], x.shape[1] - scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) - invalid_mask = ~T2Ei - int32_max = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) - scatter_safe_idx = torch.where(invalid_mask, int32_max, scatter_idx) - - x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) - x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + """Cumsum-scatter-gather-update expert helper for GPT-OSS NSP-blocked dispatch. + + Same algorithm as the Qwen3-MOE version but with GPT-OSS biases and GLU + activation (clamped gate/up, ``(up + 1) * gate * sigmoid(gate * alpha)``). + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + b_g, b_u : [num_nsp, I] + b_d : [num_nsp, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + matched_idx = _build_matched_idx_from_cumsum(T2Ei) valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) - row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) - valid_output_rows = row_range < valid_rows - x_prime = torch.where(valid_output_rows.unsqueeze(-1), x_prime, torch.zeros_like(x_prime)) + row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for packed_start in range(0, seq_len, packed_chunk_size): + packed_stop = packed_start + packed_chunk_size + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] - gate = (x_prime @ W_g) + b_g.unsqueeze(1) - up = (x_prime @ W_u) + b_u.unsqueeze(1) - gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate * alpha) - intermediate = (up + 1) * glu - down_prime = (intermediate @ W_d) + b_d.unsqueeze(1) - down_prime = torch.where(valid_output_rows.unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate = (x_chunk @ W_g) + b_g.unsqueeze(1) + up = (x_chunk @ W_u) + b_u.unsqueeze(1) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + intermediate = (up + 1) * glu + down_chunk = (intermediate @ W_d) + b_d.unsqueeze(1) + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) - gather_idx = torch.where(invalid_mask, int32_max, scatter_idx) - delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) - delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) - return delta_out + return expert_out class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): @@ -116,11 +174,11 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() - expert_out_partial = x.new_zeros((num_nsp, T, H)) + expert_out = x.new_zeros((num_nsp, T, H)) for local_slot in range(local_experts): - routing_weight = routing_weights_by_expert[:, local_slot, :].unsqueeze(-1) - T2Ei = routing_weight.squeeze(-1) > 0 - delta = _ctx_scatter_gather_gptoss_expert_blocked( + routing_weight = routing_weights_by_expert[:, local_slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_gptoss_expert_blocked( x=x, T2Ei=T2Ei, W_g=W_g[:, local_slot], @@ -129,13 +187,15 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor b_g=b_g[:, local_slot], b_u=b_u[:, local_slot], b_d=b_d[:, local_slot], + routing_weight=routing_weight, + expert_out=expert_out, limit=self.experts.limit, alpha=self.experts.alpha, T=T, + packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) - expert_out_partial = expert_out_partial + (delta * routing_weight) - return expert_out_partial.sum(dim=0) + return expert_out.sum(dim=0) def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index e233e0e837..939d8faa93 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -33,7 +33,11 @@ generic_blocked_attention_interface, past_key_value_update, ) -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -103,39 +107,93 @@ def eager_attention_forward( EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table for an NSP-sliced expert mask. -def _ctx_scatter_gather_expert_blocked( + Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to + an expert, produces an index tensor where ``matched_idx[b, j]`` is the + original token position in ``x`` that lands at packed position ``j`` for + NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). + """ + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this + # can be switched back to ``torch.full_like(token_idx, int32_max)``. + matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_expert_blocked( x: torch.Tensor, T2Ei: torch.Tensor, W_g: torch.Tensor, W_u: torch.Tensor, W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, act_fn, T: int, + packed_chunk_size: int, ) -> torch.Tensor: - """Packed-prefix expert helper for NSP-blocked dispatch.""" - batch_size, hidden_size = T2Ei.shape[0], x.shape[1] - scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) - invalid_mask = ~T2Ei - INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) - scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. + + Accumulates one local expert's contribution in-place onto ``expert_out``. + Uses a packed/cumsum layout so the MLP runs only over active rows, then + scatters the weighted output back to original token positions. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) - x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) - x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) - gate_prime = x_prime @ W_g - up_prime = x_prime @ W_u - down_prime = (up_prime * act_fn(gate_prime)) @ W_d + for packed_start in range(0, seq_len, packed_chunk_size): + packed_stop = packed_start + packed_chunk_size + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) - row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) - down_prime = torch.where((row_range < valid_rows).unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate_prime = x_chunk @ W_g + up_prime = x_chunk @ W_u + down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) - gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) - delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) - delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) - return delta_out + return expert_out class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -164,21 +222,23 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() - expert_out_partial = x.new_zeros((num_nsp, T, H)) + expert_out = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): - routing_weight = rw[:, slot, :].unsqueeze(-1) - T2Ei = routing_weight.squeeze(-1) > 0 - delta = _ctx_scatter_gather_expert_blocked( + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_expert_blocked( x=x, T2Ei=T2Ei, W_g=W_g[:, slot], W_u=W_u[:, slot], W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, act_fn=self.experts[0].act_fn, T=T, + packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) - expert_out_partial = expert_out_partial + (delta * routing_weight) - return expert_out_partial.sum(dim=0) + return expert_out.sum(dim=0) def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape From 1ae7b239ed471740df783e37a77a7aeac3e0e86a Mon Sep 17 00:00:00 2001 From: vbaddi Date: Thu, 30 Apr 2026 07:36:31 +0530 Subject: [PATCH 07/16] nit(0429): update modeling files Signed-off-by: vbaddi --- .../transformers/models/gpt_oss/modeling_gpt_oss.py | 8 +------- .../transformers/models/qwen3_moe/modeling_qwen3_moe.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5e0270b7be..53dd72193e 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -60,13 +60,7 @@ def __qeff_init__(self): def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: - """Build packed->original token index table for an NSP-sliced expert mask. - - Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to - an expert, produces an index tensor where ``matched_idx[b, j]`` is the - original token position in ``x`` that lands at packed position ``j`` for - NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). - """ + """Build packed->original token index""" batch_size, seq_len = T2Ei.shape int32_max = torch.iinfo(torch.int32).max int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 939d8faa93..942ebdc738 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -111,13 +111,7 @@ def eager_attention_forward( def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: - """Build packed->original token index table for an NSP-sliced expert mask. - - Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to - an expert, produces an index tensor where ``matched_idx[b, j]`` is the - original token position in ``x`` that lands at packed position ``j`` for - NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). - """ + """Build packed->original token index""" batch_size, seq_len = T2Ei.shape int32_max = torch.iinfo(torch.int32).max int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) From 96df4926a8ab2e55e84f344b2e96fc2789fd0790 Mon Sep 17 00:00:00 2001 From: vtirumal Date: Wed, 13 May 2026 19:27:45 +0530 Subject: [PATCH 08/16] Fix CtxGather3D packed-chunk shape expansion - Root cause: CtxGather3D ONNX symbolic expanded ctx_indices to Shape(data)[:2] ([batch, seq_len]), which is wrong for packed dispatch. - In expert-blocked MoE prefill, ctx_indices is intentionally [batch, packed_chunk_size] (e.g. [16, 256]) while data stays [batch, seq_len, ...] (e.g. [16, 512, ...]). - This caused invalid Expand attempts ([16,256] -> [16,512]) and QAIC compile/runtime failure on /model/layers.0/mlp/CtxGather3D/.... Fix: - Update CtxGather3D expand target to: - batch dim from data - index-seq dim from ctx_indices - New expand shape is [batch_size(data), idx_seq_len(ctx_indices)], preserving packed chunk length. Signed-off-by: vtirumal --- QEfficient/customop/ctx_scatter_gather.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 19f60886de..b1f322c606 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -166,7 +166,10 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) + batch_size = ops.Slice(ops.Shape(data), starts=[0], ends=[1], axes=[0]) + idx_seq_len = ops.Slice(ops.Shape(ctx_indices), starts=[1], ends=[2], axes=[0]) + expand_shape = ops.Concat(batch_size, idx_seq_len, axis=0) + ctx_indices = ops.Expand(ctx_indices, expand_shape) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=1) From a6191751bad69b5781921d3b3f8e91534ba150b4 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 13 May 2026 22:33:33 +0530 Subject: [PATCH 09/16] nit(0513): fix: register moe prefill 3d custom ops for subfunction export Add missing CustomOpTransform mappings for CtxScatterFunc3DInt and generalized 3D scatter/gather ops, plus a prefill-only subfunction export regression test to verify the ONNX graph includes the required CtxScatter3DInt/CtxScatter3D/CtxGather3D ops. Signed-off-by: vbaddi --- QEfficient/base/onnx_transforms.py | 7 +++++ .../models/test_moe_prefill_blocked.py | 28 +++++++++++++++++++ .../transforms/test_onnx_transforms.py | 18 ++++++++++++ 3 files changed, 53 insertions(+) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index c27e3cc704..dbbb51e5b2 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -21,11 +21,15 @@ CtxGatherBlockedKV, CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFunc3DGeneralized, CtxGatherFuncBlockedKV, CtxScatter, CtxScatter3D, + CtxScatter3DInt, CtxScatterFunc, CtxScatterFunc3D, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherBlockedKVCB, @@ -92,8 +96,11 @@ class CustomOpTransform(BaseOnnxTransform): "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm), "CtxScatterFunc": (CtxScatterFunc, CtxScatter), "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), + "CtxScatterFunc3DInt": (CtxScatterFunc3DInt, CtxScatter3DInt), + "CtxScatterFunc3DGeneralized": (CtxScatterFunc3DGeneralized, CtxScatter3D), "CtxGatherFunc": (CtxGatherFunc, CtxGather), "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), + "CtxGatherFunc3DGeneralized": (CtxGatherFunc3DGeneralized, CtxGather3D), "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV), diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py index f7789707e8..579393ad4a 100644 --- a/tests/transformers/models/test_moe_prefill_blocked.py +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -91,6 +91,34 @@ def test_qwen3moe_prefill_chunked_export(tmp_path): assert qeff.onnx_path.is_file() +def test_qwen3moe_prefill_chunked_subfunction_export_contains_cumsum_custom_ops(tmp_path): + import onnx + + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + onnx_path = qeff.export( + tmp_path / "prefill-subfunction", + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, + offload_pt_weights=False, + ) + + onnx_model = onnx.load(str(onnx_path), load_external_data=False) + function_names = {func.name for func in onnx_model.functions} + used_op_types = {node.op_type for node in onnx_model.graph.node} + for function_proto in onnx_model.functions: + used_op_types.update(node.op_type for node in function_proto.node) + + assert "CtxScatter3DInt" in function_names + assert "CtxScatter3D" in function_names + assert "CtxGather3D" in function_names + assert "CtxScatter3DInt" in used_op_types + assert "CtxScatter3D" in used_op_types + assert "CtxGather3D" in used_op_types + + # ── GPT-OSS ─────────────────────────────────────────────────────────────────── diff --git a/tests/unit_test/transforms/test_onnx_transforms.py b/tests/unit_test/transforms/test_onnx_transforms.py index 5a43b345d6..d8897364aa 100644 --- a/tests/unit_test/transforms/test_onnx_transforms.py +++ b/tests/unit_test/transforms/test_onnx_transforms.py @@ -531,6 +531,24 @@ def test_custom_op_transform_contains_ctx_gather(self): assert "CtxGatherFunc" in CustomOpTransform._custom_ops + def test_custom_op_transform_contains_ctx_scatter_3d_int(self): + """CustomOpTransform._custom_ops must contain 'CtxScatterFunc3DInt'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxScatterFunc3DInt" in CustomOpTransform._custom_ops + + def test_custom_op_transform_contains_ctx_scatter_3d_generalized(self): + """CustomOpTransform._custom_ops must contain 'CtxScatterFunc3DGeneralized'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxScatterFunc3DGeneralized" in CustomOpTransform._custom_ops + + def test_custom_op_transform_contains_ctx_gather_3d_generalized(self): + """CustomOpTransform._custom_ops must contain 'CtxGatherFunc3DGeneralized'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxGatherFunc3DGeneralized" in CustomOpTransform._custom_ops + def test_custom_op_transform_rms_norm_maps_to_custom_rms_norm(self): """CustomRMSNormFunc must map to CustomRMSNorm class.""" from QEfficient.base.onnx_transforms import CustomOpTransform From 27c0d288de3366bc2f94a8ce8dbde18120eec7bb Mon Sep 17 00:00:00 2001 From: vbaddi Date: Thu, 14 May 2026 23:11:25 +0530 Subject: [PATCH 10/16] fix(0415): fix: avoid unsupported prefill MoE reductions in subfunction export Replace MoE prefill sum reductions with equivalent einsum forms and rewrite int32 clamp bounds using where to avoid QAIC subfunction compile failures for GPT-OSS and Qwen3-MoE. Signed-off-by: vbaddi --- .../transformers/models/gpt_oss/modeling_gpt_oss.py | 12 +++++++++--- .../models/qwen3_moe/modeling_qwen3_moe.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 53dd72193e..33f7a2d707 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -114,7 +114,7 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) matched_idx = _build_matched_idx_from_cumsum(T2Ei) - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) rw_expanded = routing_weight.unsqueeze(-1) @@ -139,7 +139,13 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > packed_chunk_size, + torch.ones_like(chunk_valid_rows) * packed_chunk_size, + chunk_valid_rows, + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) @@ -189,7 +195,7 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) - return expert_out.sum(dim=0) + return torch.einsum("ijk->jk", expert_out) def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 942ebdc738..5c22061256 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -160,7 +160,7 @@ def _cumsum_scatter_gather_update_expert_blocked( packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) matched_idx = _build_matched_idx_from_cumsum(T2Ei) - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) rw_expanded = routing_weight.unsqueeze(-1) @@ -181,7 +181,13 @@ def _cumsum_scatter_gather_update_expert_blocked( expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > packed_chunk_size, + torch.ones_like(chunk_valid_rows) * packed_chunk_size, + chunk_valid_rows, + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) @@ -232,7 +238,7 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor T=T, packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) - return expert_out.sum(dim=0) + return torch.einsum("ijk->jk", expert_out) def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape From e605480521dffca7df5f45afff1386daedde98dc Mon Sep 17 00:00:00 2001 From: vbaddi Date: Fri, 15 May 2026 21:28:45 +0530 Subject: [PATCH 11/16] fix(0415): align prefill MoE chunk export with packed dispatch Trace chunked prefill exports with the requested prefill_seq_len so packed MoE dispatch unrolls all packed chunks, restore torch.full_like index init, and add ONNX coverage for the second packed chunk slice. Signed-off-by: vbaddi --- .../models/gpt_oss/modeling_gpt_oss.py | 4 +-- .../transformers/models/modeling_auto.py | 5 ++- .../models/qwen3_moe/modeling_qwen3_moe.py | 4 +-- .../models/test_moe_prefill_blocked.py | 35 +++++++++++++++++++ 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 33f7a2d707..c13b9c822a 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -68,9 +68,7 @@ def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) valid_dest = valid_prefix - 1 scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) - # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this - # can be switched back to ``torch.full_like(token_idx, int32_max)``. - matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = torch.full_like(token_idx, int32_max) matched_idx = CtxScatterFunc3DInt.apply( matched_idx.unsqueeze(-1), scatter_pos, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c5c50f1c7d..673d929331 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2991,7 +2991,9 @@ def get_seq_len_and_handle_specialized_prefill_model( self.hash_params["prefill_only"] = True if enable_chunking: self.hash_params["chunking"] = True - return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + seq_len = max(prefill_seq_len or 0, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + self.hash_params["chunking_seq_len"] = seq_len + return seq_len num_q_blocks = ( self.hash_params["blocking_config"].num_q_blocks if self.hash_params.get("blocking_kwargs", None) else None @@ -3102,6 +3104,7 @@ def export( self.hash_params.pop("NUM_FFN_BLOCKS", None) self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) + self.hash_params.pop("chunking_seq_len", None) if kwargs.get("retain_full_kv", False): kv_cache_shape[2] = seq_len + ( self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 5c22061256..ee3273d081 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -119,9 +119,7 @@ def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) valid_dest = valid_prefix - 1 scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) - # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this - # can be switched back to ``torch.full_like(token_idx, int32_max)``. - matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = torch.full_like(token_idx, int32_max) matched_idx = CtxScatterFunc3DInt.apply( matched_idx.unsqueeze(-1), scatter_pos, diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py index 579393ad4a..1b115d79e1 100644 --- a/tests/transformers/models/test_moe_prefill_blocked.py +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -165,3 +165,38 @@ def test_gptoss_prefill_chunked_export(tmp_path): qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) assert qeff.onnx_path.is_file() + + +def test_gptoss_prefill_chunked_export_traces_packed_chunks(tmp_path): + import onnx + from onnx import numpy_helper + + config = AutoConfig.for_model("gpt_oss", **{**GPTOSS_CFG, "max_position_embeddings": 1024}) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + onnx_path = qeff.export( + tmp_path / "prefill-subfunction-512", + prefill_only=True, + enable_chunking=True, + prefill_seq_len=512, + use_onnx_subfunctions=True, + offload_pt_weights=False, + ) + + onnx_model = onnx.load(str(onnx_path), load_external_data=False) + slice_starts = [] + op_types = [] + for nodes in [onnx_model.graph.node] + [function.node for function in onnx_model.functions]: + constants = {} + for node in nodes: + op_types.append(node.op_type) + if node.op_type == "Constant": + for attr in node.attribute: + if attr.name == "value": + constants[node.output[0]] = numpy_helper.to_array(attr.t).flatten().tolist() + for node in nodes: + if node.op_type == "Slice" and len(node.input) > 1 and node.input[1] in constants: + slice_starts.append(constants[node.input[1]]) + + assert [256] in slice_starts + assert op_types.count("CtxGather3D") >= 2 * op_types.count("CtxScatter3DInt") From c76082e062f2c18db92fcb9f8b29db9d18bd6296 Mon Sep 17 00:00:00 2001 From: divytrip Date: Thu, 21 May 2026 14:08:51 +0530 Subject: [PATCH 12/16] feat: NSP-blocked MoE prefill for GPT-OSS, Qwen3-MoE and GraniteMoE - gpt_oss/modeling_gpt_oss.py: add num_expert_chunks dynamic-loop mechanism to _cumsum_scatter_gather_update_gptoss_expert_blocked and _forward_expert_blocked; read _num_expert_chunks from module in forward - qwen3_moe/modeling_qwen3_moe.py: same num_expert_chunks dynamic-loop mechanism; replace fallback inline loop with orig_forward call - granitemoe/modeling_granitemoe.py: add QEffPrefillChunkedGraniteMoeAttention, QEffPrefillChunkedGraniteMoeMoE with full ONNX-friendly cumsum-scatter-gather dispatch; update get_submodules_for_export to return set() when chunked prefill MoE is active - pytorch_transforms.py: register GraniteMoe forward/reverse mappings in PrefillOnlyChunkedTransform and RevertPrefillKeepAttentionTransform - modeling_auto.py: compute num_expert_chunks from prefill_seq_len and EXPERT_BLOCKING_PACKED_CHUNK_SIZE; setattr _num_expert_chunks on every MoE layer when enable_chunking=True --- .../models/gpt_oss/modeling_gpt_oss.py | 38 ++- .../models/granitemoe/modeling_granitemoe.py | 242 ++++++++++++++++++ .../transformers/models/modeling_auto.py | 19 +- .../transformers/models/pytorch_transforms.py | 8 + .../models/qwen3_moe/modeling_qwen3_moe.py | 49 ++-- 5 files changed, 317 insertions(+), 39 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index c13b9c822a..eaae3f5af1 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -92,6 +92,7 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( alpha: float, T: int, packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, ) -> torch.Tensor: """Cumsum-scatter-gather-update expert helper for GPT-OSS NSP-blocked dispatch. @@ -109,16 +110,26 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( expert_out : [num_nsp, T, H] (accumulator, in-out) """ batch_size, seq_len = T2Ei.shape - packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops → DYNAMIC at runtime (scales with actual seq_len). + # Trace time: seq_len=32 → pcs=16 (valid slice) + # Runtime: seq_len=512 → pcs=256 (correct size) ✅ + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size matched_idx = _build_matched_idx_from_cumsum(T2Ei) valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) - row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) rw_expanded = routing_weight.unsqueeze(-1) - for packed_start in range(0, seq_len, packed_chunk_size): - packed_stop = packed_start + packed_chunk_size + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start chunk_matched_idx = matched_idx[:, packed_start:packed_stop] x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) @@ -137,13 +148,8 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - rows_remaining = valid_rows - packed_start - chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) - chunk_valid_rows = torch.where( - chunk_valid_rows > packed_chunk_size, - torch.ones_like(chunk_valid_rows) * packed_chunk_size, - chunk_valid_rows, - ) + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=chunk_size) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) @@ -153,7 +159,9 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP num_experts = self.experts.num_experts @@ -191,6 +199,7 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor alpha=self.experts.alpha, T=T, packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, + num_expert_chunks=num_expert_chunks, ) return torch.einsum("ijk->jk", expert_out) @@ -214,7 +223,10 @@ def forward(self, hidden: torch.Tensor): routing_weights = masked_logits if self.experts.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - expert_out = self._forward_expert_blocked(x=hidden, routing_weights=routing_weights) + num_expert_chunks = getattr(self, "_num_expert_chunks", None) + expert_out = self._forward_expert_blocked( + x=hidden, routing_weights=routing_weights, num_expert_chunks=num_expert_chunks + ) return expert_out.view(B, S, H), router_logits # ────────────────── allocate the output tensor ───── diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8728b4d3e4..0b0c775d4e 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- from typing import List, Optional, Tuple, Type, Union +import os import torch from torch import nn @@ -32,6 +33,11 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -89,6 +95,62 @@ def qeff_apply_rotary_pos_emb( return q_embed.to(q.dtype), k_embed.to(k.dtype) +class QEffPrefillChunkedGraniteMoeAttention(GraniteMoeAttention): + """Prefill-chunked attention for GraniteMoE — no sliding window.""" + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) + + key_states, value_states, _ = past_key_value_update( + module=self, + key=key_states, + value=value_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids, + ) + + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + class QEffGraniteMoeAttention(GraniteMoeAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -541,6 +603,180 @@ def forward(self, layer_input): return final_hidden_states, router_logits +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) + + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table via cumsum scatter.""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + matched_idx = torch.full_like(token_idx, int32_max) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_granitemoe_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + activation, + packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, +) -> torch.Tensor: + """Cumsum-scatter-gather-update for one local GraniteMoE expert slot. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] bool + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] accumulator (in-out) + """ + batch_size, seq_len = T2Ei.shape + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops → DYNAMIC at runtime (scales with actual seq_len). + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size + + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] + + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate = x_chunk @ W_g + up = x_chunk @ W_u + down_chunk = (activation(gate) * up) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=chunk_size) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + +class QEffPrefillChunkedGraniteMoeMoE(GraniteMoeMoE): + """NSP-blocked prefill dispatch for GraniteMoE. + + Replaces the per-expert loop in QEffGraniteMoeMoE with a + cumsum-scatter-gather-update strategy that only runs the MLP on active + tokens, mirroring the Qwen3-MoE implementation. + """ + + def __qeff_init__(self): + W_gate_up = self.input_linear.weight # [E, 2I, H] + I = W_gate_up.shape[1] // 2 + self._W_g = nn.Parameter(W_gate_up[:, :I, :].transpose(1, 2).contiguous()) # [E, H, I] + self._W_u = nn.Parameter(W_gate_up[:, I:, :].transpose(1, 2).contiguous()) # [E, H, I] + self._W_d = nn.Parameter(self.output_linear.weight.transpose(1, 2).contiguous()) # [E, I, H] + + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: + T, H = x.shape + num_experts = self.router.num_experts + num_nsp = EXPERT_BLOCKING_NUM_NSP + if num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})." + ) + local_experts = num_experts // num_nsp + I = self._W_g.shape[2] + + rw = routing_weights.transpose(0, 1).contiguous().view(num_nsp, local_experts, T) + W_g = self._W_g.view(num_nsp, local_experts, H, I) + W_u = self._W_u.view(num_nsp, local_experts, H, I) + W_d = self._W_d.view(num_nsp, local_experts, I, H) + + expert_out = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_granitemoe_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, + activation=self.activation, + packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, + num_expert_chunks=num_expert_chunks, + ) + return torch.einsum("ijk->jk", expert_out) + + def orig_forward(self, layer_input: torch.Tensor): + """Original per-expert loop — kept for parity testing.""" + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + topk_gates, expert_mask, router_logits, num_experts = self.router(layer_input) + final_hidden_states = torch.zeros_like(layer_input) + for expert_idx in range(num_experts): + mask = expert_mask[expert_idx].transpose(0, 1).to(layer_input.dtype) + mask_weight = torch.einsum("be,be->b", topk_gates, mask.to(topk_gates.dtype))[:, None] + hidden_states = self.input_linear(layer_input, expert_idx) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_idx) + current_hidden_states = torch.where(mask_weight > 0, expert_outputs * mask_weight, 0.0) + final_hidden_states += current_hidden_states + final_hidden_states = final_hidden_states.view(bsz, length, self.input_size) + return final_hidden_states, router_logits + + def forward(self, layer_input: torch.Tensor): + bsz, length, emb_size = layer_input.size() + x = layer_input.reshape(-1, emb_size) + topk_gates, expert_mask, router_logits, num_experts = self.router(x) + + # Convert [E, top_k, T] + [T, top_k] -> flat [T, E] routing weights + routing_weights = torch.einsum("tk,ekt->te", topk_gates, expert_mask.float()) + + if num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + num_expert_chunks = getattr(self, "_num_expert_chunks", None) + expert_out = self._forward_expert_blocked( + x=x, routing_weights=routing_weights, num_expert_chunks=num_expert_chunks + ) + return expert_out.view(bsz, length, self.input_size), router_logits + + return self.orig_forward(layer_input) + + class QEffGraniteMoeParallelExperts(GraniteMoeParallelExperts): def forward(self, inputs, expert_size): """ @@ -570,6 +806,12 @@ def get_submodules_for_export(self) -> Type[nn.Module]: This method should return the *class object* (not an instance). Downstream code can use this to find/build subfunctions for repeated blocks. """ + # When the chunked-prefill MoE block is active it emits CtxScatter3DInt, + # which the compiler does not support inside ONNX subfunctions. + # Return an empty set so -sub-functions is not passed to qaic-compile. + first_layer_moe = self.model.layers[0].block_sparse_moe + if isinstance(first_layer_moe, QEffPrefillChunkedGraniteMoeMoE): + return set() return {self.model.layers[0].__class__} def forward( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 673d929331..00d821947d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2991,9 +2991,22 @@ def get_seq_len_and_handle_specialized_prefill_model( self.hash_params["prefill_only"] = True if enable_chunking: self.hash_params["chunking"] = True - seq_len = max(prefill_seq_len or 0, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) - self.hash_params["chunking_seq_len"] = seq_len - return seq_len + self.hash_params["EXPERT_BLOCKING_NUM_NSP"] = os.environ.get("EXPERT_BLOCKING_NUM_NSP", None) + self.hash_params["EXPERT_BLOCKING_PACKED_CHUNK_SIZE"] = os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", None) + # Compute num_expert_chunks and set on model so the packed chunk + # loop unrolls the correct number of times during ONNX export + # even when tracing with the small default seq_len (32). + if prefill_seq_len is not None: + pcs = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", 256)) + num_expert_chunks = max(1, prefill_seq_len // pcs) + self.hash_params["num_expert_chunks"] = num_expert_chunks + # Set directly on each MoE module so it never travels via **kwargs + if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): + for layer in self.model.model.layers: + moe = getattr(layer, "block_sparse_moe", None) or getattr(layer, "mlp", None) + if moe is not None: + setattr(moe, "_num_expert_chunks", num_expert_chunks) + return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN num_q_blocks = ( self.hash_params["blocking_config"].num_q_blocks if self.hash_params.get("blocking_kwargs", None) else None diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5ff06e6443..14c7bbf140 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -324,9 +324,11 @@ ) from QEfficient.transformers.models.granitemoe.modeling_granitemoe import ( QEffGraniteMoeAttention, + QEffPrefillChunkedGraniteMoeAttention, QEffGraniteMoeForCausalLM, QEffGraniteMoeModel, QEffGraniteMoeMoE, + QEffPrefillChunkedGraniteMoeMoE, QEffGraniteMoeParallelExperts, QEffGraniteMoeRotaryEmbedding, QEffGraniteMoeTopKGating, @@ -754,6 +756,9 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, # Qwen3 VL Moe QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, + # GraniteMoe + QEffGraniteMoeMoE: QEffPrefillChunkedGraniteMoeMoE, + QEffGraniteMoeAttention: QEffPrefillChunkedGraniteMoeAttention, } @@ -767,6 +772,9 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, # Qwen3Moe QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + # GraniteMoe + QEffPrefillChunkedGraniteMoeMoE: QEffGraniteMoeMoE, + QEffPrefillChunkedGraniteMoeAttention: QEffGraniteMoeAttention, } diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index ee3273d081..735eabefcc 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -139,6 +139,7 @@ def _cumsum_scatter_gather_update_expert_blocked( act_fn, T: int, packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, ) -> torch.Tensor: """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. @@ -155,16 +156,26 @@ def _cumsum_scatter_gather_update_expert_blocked( expert_out : [num_nsp, T, H] (accumulator, in-out) """ batch_size, seq_len = T2Ei.shape - packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops → DYNAMIC at runtime (scales with actual seq_len). + # Trace time: seq_len=32 → pcs=16 (valid slice) + # Runtime: seq_len=512 → pcs=256 (correct size) ✅ + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size matched_idx = _build_matched_idx_from_cumsum(T2Ei) valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) - row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) rw_expanded = routing_weight.unsqueeze(-1) - for packed_start in range(0, seq_len, packed_chunk_size): - packed_stop = packed_start + packed_chunk_size + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start chunk_matched_idx = matched_idx[:, packed_start:packed_stop] x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) @@ -179,13 +190,8 @@ def _cumsum_scatter_gather_update_expert_blocked( expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - rows_remaining = valid_rows - packed_start - chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) - chunk_valid_rows = torch.where( - chunk_valid_rows > packed_chunk_size, - torch.ones_like(chunk_valid_rows) * packed_chunk_size, - chunk_valid_rows, - ) + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=chunk_size) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) @@ -208,7 +214,9 @@ def __qeff_init__(self): self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP if self.num_experts % num_nsp != 0: @@ -235,6 +243,7 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor act_fn=self.experts[0].act_fn, T=T, packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, + num_expert_chunks=num_expert_chunks, ) return torch.einsum("ijk->jk", expert_out) @@ -276,19 +285,13 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights.scatter_(1, top_i, top_w) if self.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + num_expert_chunks = getattr(self, "_num_expert_chunks", None) + expert_out = self._forward_expert_blocked( + x=x, routing_weights=routing_weights, num_expert_chunks=num_expert_chunks + ) return expert_out.view(B, S, H), router_logits - expert_out = x.new_zeros((T, H)) - for e in range(self.num_experts): - routing_weight = routing_weights[:, e].unsqueeze(-1) - W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T - W_d = self.experts[e].down_proj.weight.T - gate = x @ W_g - up = x @ W_u - down = (up * self.experts[e].act_fn(gate)) @ W_d - expert_out += down * routing_weight - return expert_out.view(B, S, H), router_logits + return self.orig_forward(hidden_states) class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): From 3a1873b029aae2a96d0735b18077a3ab2f856590 Mon Sep 17 00:00:00 2001 From: divytrip Date: Thu, 21 May 2026 15:13:13 +0530 Subject: [PATCH 13/16] fix: replace torch.clamp with torch.where for int32 chunk_valid_rows torch.clamp on int32 tensors exports to ONNX Clip op which QAIC compiler does not support (Unhandled ElemKind in Clip operation). Replace with torch.where in all three models: gpt_oss, qwen3_moe, granitemoe. --- .../transformers/models/gpt_oss/modeling_gpt_oss.py | 8 +++++++- .../transformers/models/granitemoe/modeling_granitemoe.py | 8 +++++++- .../transformers/models/qwen3_moe/modeling_qwen3_moe.py | 8 +++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index eaae3f5af1..d2964eb1eb 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -149,7 +149,13 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( updated_chunk = expert_out_chunk + down_chunk row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=chunk_size) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 0b0c775d4e..529f7c50d7 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -681,7 +681,13 @@ def _cumsum_scatter_gather_update_granitemoe_expert_blocked( updated_chunk = expert_out_chunk + down_chunk row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=chunk_size) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 735eabefcc..8f1cd27413 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -191,7 +191,13 @@ def _cumsum_scatter_gather_update_expert_blocked( updated_chunk = expert_out_chunk + down_chunk row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=chunk_size) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) From 0ebf018104201720843f3f045bb5195e7348fdd2 Mon Sep 17 00:00:00 2001 From: divytrip Date: Tue, 26 May 2026 14:19:37 +0530 Subject: [PATCH 14/16] feat: API-driven NSP blocking for Qwen3-VL-MoE, Qwen3-MoE, GPT-OSS, GraniteMoE - modeling_qwen3_vl_moe.py: Add QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock with NSP-blocked cumsum-scatter-gather dispatch; supports_moe_prefill_blocking=True; use expert_blocking_num_nsp/packed_chunk_size/num_packed_chunks instance attrs - modeling_qwen3_moe.py, modeling_gpt_oss.py, modeling_granitemoe.py: Replace EXPERT_BLOCKING_NUM_NSP/EXPERT_BLOCKING_PACKED_CHUNK_SIZE env vars with API-driven instance attributes set via compile() params - pytorch_transforms.py: Register QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock in RevertPrefillKeepAttentionTransform - modeling_auto.py: QEFFAutoModelForCausalLM.get_seq_len_and_handle_specialized_prefill_model iterates modules with supports_moe_prefill_blocking=True and sets instance attrs; QEffCausalLMForTextImageToTextModel.export() uses same API-driven pattern for VLM; QEFFAutoModelForImageTextToText.compile() accepts moe_prefill_packed_chunk_size param - modeling_qeff.py: Uncomment self.onnx_path fallback in _compile so pre-exported ONNX is reused without hitting get_onnx_path; pass moe_prefill_packed_chunk_size through get_onnx_path and _compile - constants.py: Add MOE_PREFILL_PACKED_CHUNK_SIZE = 256 --- QEfficient/base/modeling_qeff.py | 11 +- .../models/gpt_oss/modeling_gpt_oss.py | 20 +- .../models/granitemoe/modeling_granitemoe.py | 20 +- .../transformers/models/modeling_auto.py | 60 +- .../transformers/models/pytorch_transforms.py | 2 + .../models/qwen3_moe/modeling_qwen3_moe.py | 20 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 197 +++++- QEfficient/utils/constants.py | 637 +++++++++--------- 8 files changed, 570 insertions(+), 397 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a091de7497..b69d2d0a5c 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -394,6 +394,7 @@ def get_onnx_path( retain_full_kv: Optional[bool] = False, mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, + moe_prefill_packed_chunk_size: Optional[int] = None, **compiler_options, ): kwargs = { @@ -409,6 +410,10 @@ def get_onnx_path( "prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len"), "enable_chunking": enable_chunking, + "num_cores": compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), + "moe_prefill_packed_chunk_size": constants.MOE_PREFILL_PACKED_CHUNK_SIZE + if moe_prefill_packed_chunk_size is None + else moe_prefill_packed_chunk_size, } ) @@ -527,11 +532,12 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ + moe_prefill_packed_chunk_size = compiler_options.pop("moe_prefill_packed_chunk_size", None) onnx_path = Path( onnx_path if onnx_path - # else self.onnx_path - # if self.onnx_path + else self.onnx_path + if self.onnx_path else self.get_onnx_path( prefill_only, enable_chunking, @@ -542,6 +548,7 @@ def _compile( mla_absorption, num_devices=mdp_ts_num_devices, qaic_config=qaic_config, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, **compiler_options, ) ) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index d2964eb1eb..67f1a485ad 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -55,9 +55,6 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) -EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) -EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) - def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: """Build packed->original token index""" @@ -165,14 +162,16 @@ def _cumsum_scatter_gather_update_gptoss_expert_blocked( class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + supports_moe_prefill_blocking = True + def _forward_expert_blocked( self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None ) -> torch.Tensor: T, H = x.shape - num_nsp = EXPERT_BLOCKING_NUM_NSP + num_nsp = self.expert_blocking_num_nsp num_experts = self.experts.num_experts if num_experts % num_nsp != 0: - raise ValueError(f"num_experts ({num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") + raise ValueError(f"num_experts ({num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})") local_experts = num_experts // num_nsp expert_dim = self.experts.expert_dim @@ -204,8 +203,8 @@ def _forward_expert_blocked( limit=self.experts.limit, alpha=self.experts.alpha, T=T, - packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, - num_expert_chunks=num_expert_chunks, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, ) return torch.einsum("ijk->jk", expert_out) @@ -228,11 +227,8 @@ def forward(self, hidden: torch.Tensor): # Routing weights for each expert [T, E] routing_weights = masked_logits - if self.experts.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - num_expert_chunks = getattr(self, "_num_expert_chunks", None) - expert_out = self._forward_expert_blocked( - x=hidden, routing_weights=routing_weights, num_expert_chunks=num_expert_chunks - ) + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and self.experts.num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=hidden, routing_weights=routing_weights) return expert_out.view(B, S, H), router_logits # ────────────────── allocate the output tensor ───── diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 529f7c50d7..48436e9838 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -603,9 +603,6 @@ def forward(self, layer_input): return final_hidden_states, router_logits -EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) -EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) - def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: """Build packed->original token index table via cumsum scatter.""" @@ -699,6 +696,8 @@ def _cumsum_scatter_gather_update_granitemoe_expert_blocked( class QEffPrefillChunkedGraniteMoeMoE(GraniteMoeMoE): """NSP-blocked prefill dispatch for GraniteMoE. + supports_moe_prefill_blocking = True + Replaces the per-expert loop in QEffGraniteMoeMoE with a cumsum-scatter-gather-update strategy that only runs the MLP on active tokens, mirroring the Qwen3-MoE implementation. @@ -716,10 +715,10 @@ def _forward_expert_blocked( ) -> torch.Tensor: T, H = x.shape num_experts = self.router.num_experts - num_nsp = EXPERT_BLOCKING_NUM_NSP + num_nsp = self.expert_blocking_num_nsp if num_experts % num_nsp != 0: raise ValueError( - f"num_experts ({num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})." + f"num_experts ({num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})." ) local_experts = num_experts // num_nsp I = self._W_g.shape[2] @@ -742,8 +741,8 @@ def _forward_expert_blocked( routing_weight=routing_weight, expert_out=expert_out, activation=self.activation, - packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, - num_expert_chunks=num_expert_chunks, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, ) return torch.einsum("ijk->jk", expert_out) @@ -773,11 +772,8 @@ def forward(self, layer_input: torch.Tensor): # Convert [E, top_k, T] + [T, top_k] -> flat [T, E] routing weights routing_weights = torch.einsum("tk,ekt->te", topk_gates, expert_mask.float()) - if num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - num_expert_chunks = getattr(self, "_num_expert_chunks", None) - expert_out = self._forward_expert_blocked( - x=x, routing_weights=routing_weights, num_expert_chunks=num_expert_chunks - ) + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) return expert_out.view(bsz, length, self.input_size), router_logits return self.orig_forward(layer_input) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 00d821947d..3601d287c6 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1077,6 +1077,8 @@ def export( prefill_seq_len: Optional[int] = None, prefill_only: bool = False, enable_chunking: bool = False, + num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, **kwargs, ): """ @@ -1110,6 +1112,19 @@ def export( ) self.hash_params["prefill_only"] = True self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) + # Set API-driven MoE blocking params on VLM MoE layers AFTER __update_prefill_transform + # so they are set on the already-swapped chunked class instances. + if enable_chunking and prefill_seq_len is not None: + compile_seq_len = prefill_seq_len + num_packed_chunks = max(1, -(-compile_seq_len // moe_prefill_packed_chunk_size)) + for module in self.model.modules(): + if getattr(module, "supports_moe_prefill_blocking", False): + module.expert_blocking_num_nsp = num_cores + module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size + module.expert_blocking_num_packed_chunks = num_packed_chunks + self.hash_params["moe_prefill_num_nsp"] = num_cores + self.hash_params["moe_prefill_packed_chunk_size"] = moe_prefill_packed_chunk_size + self.hash_params["moe_prefill_num_packed_chunks"] = num_packed_chunks else: self.hash_params["prefill_only"] = False self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) @@ -1383,6 +1398,8 @@ def export( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + num_cores=kwargs.get("num_cores", constants.DEFAULT_AIC_NUM_CORES), + moe_prefill_packed_chunk_size=kwargs.get("moe_prefill_packed_chunk_size", constants.MOE_PREFILL_PACKED_CHUNK_SIZE), ) return self.onnx_path @@ -1436,6 +1453,7 @@ def compile( use_onnx_subfunctions: bool = False, prefill_only=None, enable_chunking=False, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, qaic_config: Optional[dict] = None, **compiler_options, ) -> str: @@ -1574,6 +1592,8 @@ def compile( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + num_cores=num_cores, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, ) # TODO this hould be removed once the continous batching is supported for all the models. @@ -2986,26 +3006,25 @@ def get_model_config(self) -> dict: return self.model.config.__dict__ def get_seq_len_and_handle_specialized_prefill_model( - self, prefill_seq_len: Optional[int] = None, enable_chunking=False + self, + prefill_seq_len: Optional[int] = None, + enable_chunking=False, + num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, ) -> int: self.hash_params["prefill_only"] = True if enable_chunking: self.hash_params["chunking"] = True - self.hash_params["EXPERT_BLOCKING_NUM_NSP"] = os.environ.get("EXPERT_BLOCKING_NUM_NSP", None) - self.hash_params["EXPERT_BLOCKING_PACKED_CHUNK_SIZE"] = os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", None) - # Compute num_expert_chunks and set on model so the packed chunk - # loop unrolls the correct number of times during ONNX export - # even when tracing with the small default seq_len (32). - if prefill_seq_len is not None: - pcs = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", 256)) - num_expert_chunks = max(1, prefill_seq_len // pcs) - self.hash_params["num_expert_chunks"] = num_expert_chunks - # Set directly on each MoE module so it never travels via **kwargs - if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): - for layer in self.model.model.layers: - moe = getattr(layer, "block_sparse_moe", None) or getattr(layer, "mlp", None) - if moe is not None: - setattr(moe, "_num_expert_chunks", num_expert_chunks) + compile_seq_len = prefill_seq_len or constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + num_packed_chunks = max(1, -(-compile_seq_len // moe_prefill_packed_chunk_size)) + for module in self.model.modules(): + if getattr(module, "supports_moe_prefill_blocking", False): + module.expert_blocking_num_nsp = num_cores + module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size + module.expert_blocking_num_packed_chunks = num_packed_chunks + self.hash_params["moe_prefill_num_nsp"] = num_cores + self.hash_params["moe_prefill_packed_chunk_size"] = moe_prefill_packed_chunk_size + self.hash_params["moe_prefill_num_packed_chunks"] = num_packed_chunks return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN num_q_blocks = ( @@ -3102,7 +3121,10 @@ def export( self.hash_params.pop("retain_full_kv", None) if "DeepseekV3ForCausalLM" not in (getattr(self.model.config, "architectures", None) or []): seq_len = self.get_seq_len_and_handle_specialized_prefill_model( - prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking + prefill_seq_len=prefill_seq_len, + enable_chunking=enable_chunking, + num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), + moe_prefill_packed_chunk_size=kwargs.get("moe_prefill_packed_chunk_size", constants.MOE_PREFILL_PACKED_CHUNK_SIZE), ) kv_cache_shape[2] = ( seq_len @@ -3117,7 +3139,9 @@ def export( self.hash_params.pop("NUM_FFN_BLOCKS", None) self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) - self.hash_params.pop("chunking_seq_len", None) + self.hash_params.pop("moe_prefill_num_nsp", None) + self.hash_params.pop("moe_prefill_packed_chunk_size", None) + self.hash_params.pop("moe_prefill_num_packed_chunks", None) if kwargs.get("retain_full_kv", False): kv_cache_shape[2] = seq_len + ( self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 14c7bbf140..899e6acbc5 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -775,6 +775,8 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): # GraniteMoe QEffPrefillChunkedGraniteMoeMoE: QEffGraniteMoeMoE, QEffPrefillChunkedGraniteMoeAttention: QEffGraniteMoeAttention, + # Qwen3 VL Moe + QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, } diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 8f1cd27413..52386d82af 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -106,9 +106,6 @@ def eager_attention_forward( return attn_output, attn_weights -EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) -EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) - def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: """Build packed->original token index""" @@ -207,6 +204,8 @@ def _cumsum_scatter_gather_update_expert_blocked( class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + supports_moe_prefill_blocking = True + def __qeff_init__(self): self.gate_proj_w = [] self.up_proj_w = [] @@ -224,10 +223,10 @@ def _forward_expert_blocked( self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None ) -> torch.Tensor: T, H = x.shape - num_nsp = EXPERT_BLOCKING_NUM_NSP + num_nsp = self.expert_blocking_num_nsp if self.num_experts % num_nsp != 0: raise ValueError( - f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + f"num_experts ({self.num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})" ) local_experts = self.num_experts // num_nsp rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() @@ -248,8 +247,8 @@ def _forward_expert_blocked( expert_out=expert_out, act_fn=self.experts[0].act_fn, T=T, - packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, - num_expert_chunks=num_expert_chunks, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, ) return torch.einsum("ijk->jk", expert_out) @@ -290,11 +289,8 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights = torch.zeros_like(router_logits) routing_weights.scatter_(1, top_i, top_w) - if self.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - num_expert_chunks = getattr(self, "_num_expert_chunks", None) - expert_out = self._forward_expert_blocked( - x=x, routing_weights=routing_weights, num_expert_chunks=num_expert_chunks - ) + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and self.num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) return expert_out.view(B, S, H), router_logits return self.orig_forward(hidden_states) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 976c80919b..075f958ddc 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- import math +import os from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch @@ -37,6 +38,11 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils import constants @@ -633,41 +639,186 @@ def _deepstack_process( return local_this + + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table via cumsum scatter.""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + matched_idx = torch.full_like(token_idx, int32_max) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_qwen3vlmoe_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + act_fn, + packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, +) -> torch.Tensor: + """Cumsum-scatter-gather-update expert helper for Qwen3-VL-MoE NSP-blocked dispatch. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops -> DYNAMIC at runtime (scales with actual seq_len). + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size + + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] + + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate_prime = x_chunk @ W_g + up_prime = x_chunk @ W_u + down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """NSP-blocked prefill dispatch for Qwen3-VL-MoE text sparse MoE block.""" + + supports_moe_prefill_blocking = True + + def __qeff_init__(self): + # Split fused gate_up_proj [E, H, 2I] into separate gate and up [E, H, I] + W_gate_up = self.experts.gate_up_proj # [E, H, 2I] + I = W_gate_up.shape[2] // 2 + self._W_g = nn.Parameter(W_gate_up[:, :, :I].contiguous()) # [E, H, I] + self._W_u = nn.Parameter(W_gate_up[:, :, I:].contiguous()) # [E, H, I] + self._W_d = nn.Parameter(self.experts.down_proj.contiguous()) # [E, I, H] + + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: + T, H = x.shape + num_nsp = self.expert_blocking_num_nsp + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + I = self._W_g.shape[2] + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self._W_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self._W_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self._W_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_qwen3vlmoe_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, + act_fn=self.experts.act_fn, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, + ) + return torch.einsum("ijk->jk", expert_out) + + def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Original per-expert loop — kept for parity testing.""" B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) - act = getattr(self.experts, "act_fn", F.silu) - - router_logits = self.gate(x) # [T, E] - prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + act = self.experts.act_fn + router_logits = self.gate(x) + prob = F.softmax(router_logits, dim=-1, dtype=torch.float) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) - top_w = top_w / torch.einsum("bi->b", top_w)[:, None] - top_w = top_w.to(hidden_states.dtype) - routing_weights = torch.zeros((T, self.num_experts), dtype=x.dtype) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] # norm_topk_prob always True + top_w = top_w.to(x.dtype) + routing_weights = torch.zeros_like(router_logits) routing_weights.scatter_(1, top_i, top_w) - - expert_out = torch.zeros_like(x, dtype=x.dtype) - + expert_out = x.new_zeros((T, H)) for e in range(self.num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) - W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] - W_dn_e = self.experts.down_proj[e] # [I, H] - gate_up = x @ W_gate_up_e # [T, 2I] - + W_dn_e = self.experts.down_proj[e] # [I, H] + gate_up = x @ W_gate_up_e I2 = gate_up.shape[-1] // 2 - gate = gate_up[:, :I2] # [T, I] - up = gate_up[:, I2:] # [T, I] + gate, up = gate_up[:, :I2], gate_up[:, I2:] intermediate = up * act(gate) down = intermediate @ W_dn_e - masked_down = torch.where( - routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype) - ) # TODO: verify and remove - expert_out += masked_down - expert_out = expert_out.to(x.dtype).view(B, S, H) - return expert_out, router_logits + expert_out += torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out)) + return expert_out.view(B, S, H), router_logits + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = self.gate(x) + prob = F.softmax(router_logits, dim=-1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] # norm_topk_prob always True + top_w = top_w.to(x.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) + + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and self.num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + + return self.orig_forward(hidden_states) class QEffQwen3VLMoeModel(Qwen3VLMoeModel): diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 339e4f4dac..22fad9a611 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -1,318 +1,319 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import os -from dataclasses import dataclass - -UTILS_DIR = os.path.dirname(os.path.abspath(__file__)) -QEFF_DIR = os.path.dirname(UTILS_DIR) -ROOT_DIR = os.path.dirname(QEFF_DIR) -QEFF_CACHE_DIR_NAME = "qeff_cache" - -ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 -ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 -ONNX_EXPORT_EXAMPLE_FBS = 4 -ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_MAX_NUM_IMAGES = 1 -ONNX_EXPORT_MAX_IMAGE_TILES = 4 -ONNX_EXPORT_IMAGE_WIDTH = 560 -ONNX_EXPORT_IMAGE_LENGHT = 560 -ONNX_EXPORT_IMAGE_DEPTH = 3 -ONNX_EXPORT_CTX_LEN = 1024 - -NPI_MAPPING = { - "google/gemma-3-4b-it": os.path.join( - QEFF_DIR, "transformers", "models", "gemma3", "configs", "fp32_nodes_gemma3_4b.yaml" - ), - "google/gemma-3-27b-it": os.path.join( - QEFF_DIR, "transformers", "models", "gemma3", "configs", "gemma_updated_npi.yaml" - ), -} - -# Blocking defaults -VTCM_SIZE_THRESHOLD = 8 * 1024 * 1024 * 0.75 - -# Compiler defaults -DEFAULT_AIC_NUM_CORES = 16 -DEFAULT_AIC_MXPF6_MATMUL = False -# Hashing defaults -HASH_HEXDIGEST_STR_LEN = 16 -KWARGS_INCLUSION_LIST = [ - "state_dict", - "revision", - "key_mapping", - "commit_hash", - "adapter_kwargs", - "adapter_name", - "gguf_file", - "pretrained_model_name_or_path", - "attn_implementation", - "_attn_implementation", - "qaic_config", -] - -# Minimum value for causal mask -MIN_MASKED_ATTENTION_VALUE = float("-inf") - - -# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable. -def get_models_dir(): - """ - Determine the directory for storing QEFF models. - Priority: - 1. Use $XDG_CACHE_HOME/qeff_models if XDG_CACHE_HOME is set. - 2. Use QEFF_HOME if set in environment. - 3. Default to ~/.cache/qeff_models. - Sets QEFF_MODELS_DIR environment variable if not already set. - Returns: - str: Path to the QEFF models directory. - """ - qeff_cache_home = os.environ.get("QEFF_HOME") - # Check if XDG_CACHE_HOME is set - xdg_cache_home = os.environ.get("XDG_CACHE_HOME") - if qeff_cache_home: - qeff_models_dir = os.path.join(qeff_cache_home, QEFF_CACHE_DIR_NAME) - # Check if QEFF_MODELS_DIR is set - elif xdg_cache_home: - qeff_models_dir = os.path.join(xdg_cache_home, QEFF_CACHE_DIR_NAME) - else: - # Use ~/.cache/qeff_models as the default - qeff_models_dir = os.path.join(os.path.expanduser("~"), ".cache", QEFF_CACHE_DIR_NAME) - - # Set QEFF_MODELS_DIR environment variable - return qeff_models_dir - - -QEFF_MODELS_DIR = get_models_dir() - -ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5 -ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5 -ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80 -ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 -ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 -ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 17 -FILE_CHUNK_SIZE_DEFAULT = 10 * 2**30 # 10 GB -SIZE_THRESHOLD_DEFAULT = 1024 - - -COMPILER = ["/opt/qti-aic/exec/qaic-compile", "-aic-hw"] -DEFAULT_AIC_HW_VERSION = "ai100" -ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 - -# InternVL constants -# Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B -INTERN_FEATURE_SIZE = 256 -INTERN_NUM_PATCHES = 13 -INTERN_IMG_SIZE = 448 -INTERN_CTX_LEN = 4096 -INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256 # 4096-256 -INTERN_NUM_CHANNELS = 3 -INTERN_IMAGE_HEIGHT = 1000 -INTERN_IMAGE_WIDTH = 747 - -INTERN_IMG_CONTEXT_TOKEN = 151667 -# Specific to InternVL3_5 series, same token won't work for InternVL2_5 series -INTERN_3_5_IMG_CONTEXT_TOKEN = 151671 - -# Granite Vision Constants -# Fixing the feature size with reference to ibm-granite/granite-vision-3.2-2b -GRANITEVISION_FEATURE_SIZE = 5239 -GRANITEVISION_NUM_PATCHES = 10 -GRANITEVISION_IMG_SIZE = 384 -GRANITEVISION_IMG_SIZE_HEIGHT = 1109 -GRANITEVISION_IMG_SIZE_WIDTH = 1610 -GRANITEVISION_PIXEL_VALUE_DIM = 5 -GRANITEVISION_PREFIL_SEQ_LEN = GRANITEVISION_SEQ_LEN = 5500 -GRANITEVISION_CTX_LEN = 6000 -GRANITEVISION_NUM_CHANNELS = 3 - -VISION_MXFP6_MATMUL = False -# Llama4 Constants -LLAMA4_ATTENTION_CHUNK_SIZE = 8192 -LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 - -# DeepSeek Kimi-k2 Constant -MAX_POSITION_EMBEDDINGS = 32768 -FP16_BYTES = 2 -DEFAULT_NUM_HEADS = 64 -KV_LORA_RANK = 512 -ROPE_DIM = 64 - -# Wav2Vec2 Constant -WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) - -# Qwen2_5_vl Constants -QWEN2_5_VL_HEIGHT = 354 -QWEN2_5_VL_WIDTH = 536 - -# Qwen3_vl Constanst -QWEN3_VL_HEIGHT = 354 -QWEN3_VL_WIDTH = 536 - -# Modules to cache while clearing the pytorch weights -CACHE_MODULES = ["get_output_names", "get_dummy_inputs", "get_onnx_dynamic_axes", "get_specializations"] - -# Mistral3 Constants -MISTRAL3_IMAGE_HEIGHT = 1540 -MISTRAL3_IMAGE_WIDTH = 1540 - -# Molmo Constants -MOLMO_IMAGE_HEIGHT = 536 -MOLMO_IMAGE_WIDTH = 354 -# Flux Transformer Constants -FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 -FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 -FLUX_ADALN_HIDDEN_DIM = 3072 -FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context -FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 -FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM - -# Wan Transformer Constants -WAN_TEXT_EMBED_DIM = 5120 -WAN_PROJECTION_DIM = 6 -WAN_ONNX_EXPORT_BATCH_SIZE = 1 -WAN_ONNX_EXPORT_FRAMES = 81 -WAN_ONNX_EXPORT_LATENT_FRAMES = 21 -WAN_ONNX_EXPORT_SEQ_LEN = 512 -WAN_ONNX_EXPORT_ROTARY_DIM = 128 -WAN_DIT_OUT_CHANNELS = 64 -# Wan dims for 45p -WAN_ONNX_EXPORT_CL_45P = 252 -WAN_ONNX_EXPORT_LATENT_HEIGHT_45P = 6 -WAN_ONNX_EXPORT_LATENT_WIDTH_45P = 8 -WAN_ONNX_EXPORT_HEIGHT_45P = 48 -WAN_ONNX_EXPORT_WIDTH_45P = 64 - -# WAN I2V -WAN_DIT_I2V_IMG_LATENT_CHANNELS = 32 - -# For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length -CCL_START_MAP = { - 32768: (4096, 4000), - 65536: (8192, 8000), - float("inf"): (16384, 16000), -} -# Limitation in the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic lists generation process. -CCL_MAX_ELEMENTS_LISTS = 5 -CCL_START_CTX_LEN = 4096 -CCL_MIN_CTX_LEN = 1024 -CCL_UNIQNE_STEP = 32 - -# used for gpt-oss prefill-only model Q-blocking -GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256 - - -class Constants: - # Export Constants. - SEQ_LEN = 32 - CTX_LEN = 32 - PROMPT_LEN = 8 - INPUT_STR = ["My name is"] - GB = 2**30 - MAX_QPC_LIMIT = 30 - MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download - NUM_SPECULATIVE_TOKENS = 2 - MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS - SAMPLER_OPS = { - "repetition_penalties", - "presence_penalties", - "temperatures", - "top_ks", - "top_ps", - "min_ps", - "random_numbers", - } - SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"} - SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. - SDK_PLATFORM_XML = ( - "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. - ) - - -@dataclass -class QnnConstants: - # QNN PATH to be read from environment variable. - QNN_SDK_PATH_ENV_VAR_NAME = "QNN_SDK_ROOT" - QNN_SDK_YAML = "sdk.yaml" - - # QNN Compilation tools - QAIRT_CONVERTER = "{}/bin/{}/qairt-converter" - QNN_CONTEXT_BIN = "{}/bin/{}/qnn-context-binary-generator" - - # QNN Libraries required for compilation - QNN_CONTEXT_LIB_BACKEND = "{}/lib/{}/libQnnAic.so" - QNN_CONTEXT_LIB_NET_RUN_EXTENSIONS = "{}/lib/{}/libQnnAicNetRunExtensions.so" - - # QNN Compilation target names - MODEL_NAME = "model" - QNN_DATA_FORMAT_CONFIG_NAME = "qnn_data_format_config.json" - CONTEXT_BIN_NAME = "qnngraph.serialized" - CONTEXT_BIN_QPC_NAME = "programqpc.bin" - - # TARGET System Architecture - TARGET = "x86_64-linux-clang" # TODO add support in infer to be override - - # Converter Arguments - FLOAT_BITWIDTH = 16 - FLOAT_BIAS_BITWIDTH = 32 - CONVERTER_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification --target_backend AIC " - - # Context-Binary-Generator Arguments - LOG_LEVEL = "error" - - # qnn_compilation_backend default Arguments - COMPILER_COMPILATION_TARGET = "hardware" - COMPILER_CONVERT_TO_FP16 = True - COMPILER_DO_DDR_TO_MULTICAST = True - COMPILER_HARDWARE_VERSION = "2.0" - COMPILER_PERF_WARNINGS = False - COMPILER_PRINT_DDR_STATS = False - COMPILER_PRINT_PERF_METRICS = False - COMPILER_RETAINED_STATE = True - COMPILER_STAT_LEVEL = 10 - COMPILER_STATS_BATCH_SIZE = 1 - COMPILER_TIME_PASSES = False - GRAPH_NAMES = [f"{MODEL_NAME}_configuration_1", f"{MODEL_NAME}_configuration_2"] - GRAPH_NAMES_PREFILL_ONLY = [f"{MODEL_NAME}"] - - # qnn_config JSON file supported Keys - CONVERTER_ARGS_EXTENSION_STR = "converter_args_extension" - CONTEXT_BIN_ARGS_EXTENSION_STR = "context_binary_generator_args_extension" - QNN_COMPILATION_BACKEND_STR = "qnn_compilation_backend" - SKIP_QNN_CONVERTER_STEP_STR = "SKIP_QNN_CONVERTER_STEP" - - IMMUTABLE_CONVERTER_ARGS = [ - "--input_network ", - "--output_path ", - "--config ", - "--float_bias_bitwidth ", - "--float_bitwidth ", - "--preserve_io_datatype", - "--onnx_skip_simplification", - ] - - IMMUTABLE_CONTEXT_BIN_GEN_ARGS = [ - "--binary_file ", - "--backend_binary ", - "--output_dir ", - "--backend ", - "--model ", - "--dlc_path ", - "--config_file ", - ] - - QNN_SAMPLE_CONFIG = { - "converter_args_extension": "--onnx_defer_loading", - "context_binary_generator_args_extension": "--log_level debug", - "qnn_compilation_backend": { - "compiler_enable_depth_first": True, - "compiler_printDDRStats": False, - "compiler_printPerfMetrics": False, - }, - "SKIP_QNN_CONVERTER_STEP": False, - } +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +from dataclasses import dataclass + +UTILS_DIR = os.path.dirname(os.path.abspath(__file__)) +QEFF_DIR = os.path.dirname(UTILS_DIR) +ROOT_DIR = os.path.dirname(QEFF_DIR) +QEFF_CACHE_DIR_NAME = "qeff_cache" + +ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 +ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 +ONNX_EXPORT_EXAMPLE_FBS = 4 +ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep +ONNX_EXPORT_MAX_NUM_IMAGES = 1 +ONNX_EXPORT_MAX_IMAGE_TILES = 4 +ONNX_EXPORT_IMAGE_WIDTH = 560 +ONNX_EXPORT_IMAGE_LENGHT = 560 +ONNX_EXPORT_IMAGE_DEPTH = 3 +ONNX_EXPORT_CTX_LEN = 1024 + +NPI_MAPPING = { + "google/gemma-3-4b-it": os.path.join( + QEFF_DIR, "transformers", "models", "gemma3", "configs", "fp32_nodes_gemma3_4b.yaml" + ), + "google/gemma-3-27b-it": os.path.join( + QEFF_DIR, "transformers", "models", "gemma3", "configs", "gemma_updated_npi.yaml" + ), +} + +# Blocking defaults +VTCM_SIZE_THRESHOLD = 8 * 1024 * 1024 * 0.75 + +# Compiler defaults +DEFAULT_AIC_NUM_CORES = 16 +DEFAULT_AIC_MXPF6_MATMUL = False +MOE_PREFILL_PACKED_CHUNK_SIZE = 256 +# Hashing defaults +HASH_HEXDIGEST_STR_LEN = 16 +KWARGS_INCLUSION_LIST = [ + "state_dict", + "revision", + "key_mapping", + "commit_hash", + "adapter_kwargs", + "adapter_name", + "gguf_file", + "pretrained_model_name_or_path", + "attn_implementation", + "_attn_implementation", + "qaic_config", +] + +# Minimum value for causal mask +MIN_MASKED_ATTENTION_VALUE = float("-inf") + + +# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable. +def get_models_dir(): + """ + Determine the directory for storing QEFF models. + Priority: + 1. Use $XDG_CACHE_HOME/qeff_models if XDG_CACHE_HOME is set. + 2. Use QEFF_HOME if set in environment. + 3. Default to ~/.cache/qeff_models. + Sets QEFF_MODELS_DIR environment variable if not already set. + Returns: + str: Path to the QEFF models directory. + """ + qeff_cache_home = os.environ.get("QEFF_HOME") + # Check if XDG_CACHE_HOME is set + xdg_cache_home = os.environ.get("XDG_CACHE_HOME") + if qeff_cache_home: + qeff_models_dir = os.path.join(qeff_cache_home, QEFF_CACHE_DIR_NAME) + # Check if QEFF_MODELS_DIR is set + elif xdg_cache_home: + qeff_models_dir = os.path.join(xdg_cache_home, QEFF_CACHE_DIR_NAME) + else: + # Use ~/.cache/qeff_models as the default + qeff_models_dir = os.path.join(os.path.expanduser("~"), ".cache", QEFF_CACHE_DIR_NAME) + + # Set QEFF_MODELS_DIR environment variable + return qeff_models_dir + + +QEFF_MODELS_DIR = get_models_dir() + +ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5 +ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5 +ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80 +ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 +ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 +ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 +ONNX_EXPORT_OPSET = 17 +FILE_CHUNK_SIZE_DEFAULT = 10 * 2**30 # 10 GB +SIZE_THRESHOLD_DEFAULT = 1024 + + +COMPILER = ["/opt/qti-aic/exec/qaic-compile", "-aic-hw"] +DEFAULT_AIC_HW_VERSION = "ai100" +ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 + +# InternVL constants +# Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B +INTERN_FEATURE_SIZE = 256 +INTERN_NUM_PATCHES = 13 +INTERN_IMG_SIZE = 448 +INTERN_CTX_LEN = 4096 +INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256 # 4096-256 +INTERN_NUM_CHANNELS = 3 +INTERN_IMAGE_HEIGHT = 1000 +INTERN_IMAGE_WIDTH = 747 + +INTERN_IMG_CONTEXT_TOKEN = 151667 +# Specific to InternVL3_5 series, same token won't work for InternVL2_5 series +INTERN_3_5_IMG_CONTEXT_TOKEN = 151671 + +# Granite Vision Constants +# Fixing the feature size with reference to ibm-granite/granite-vision-3.2-2b +GRANITEVISION_FEATURE_SIZE = 5239 +GRANITEVISION_NUM_PATCHES = 10 +GRANITEVISION_IMG_SIZE = 384 +GRANITEVISION_IMG_SIZE_HEIGHT = 1109 +GRANITEVISION_IMG_SIZE_WIDTH = 1610 +GRANITEVISION_PIXEL_VALUE_DIM = 5 +GRANITEVISION_PREFIL_SEQ_LEN = GRANITEVISION_SEQ_LEN = 5500 +GRANITEVISION_CTX_LEN = 6000 +GRANITEVISION_NUM_CHANNELS = 3 + +VISION_MXFP6_MATMUL = False +# Llama4 Constants +LLAMA4_ATTENTION_CHUNK_SIZE = 8192 +LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 + +# DeepSeek Kimi-k2 Constant +MAX_POSITION_EMBEDDINGS = 32768 +FP16_BYTES = 2 +DEFAULT_NUM_HEADS = 64 +KV_LORA_RANK = 512 +ROPE_DIM = 64 + +# Wav2Vec2 Constant +WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) + +# Qwen2_5_vl Constants +QWEN2_5_VL_HEIGHT = 354 +QWEN2_5_VL_WIDTH = 536 + +# Qwen3_vl Constanst +QWEN3_VL_HEIGHT = 354 +QWEN3_VL_WIDTH = 536 + +# Modules to cache while clearing the pytorch weights +CACHE_MODULES = ["get_output_names", "get_dummy_inputs", "get_onnx_dynamic_axes", "get_specializations"] + +# Mistral3 Constants +MISTRAL3_IMAGE_HEIGHT = 1540 +MISTRAL3_IMAGE_WIDTH = 1540 + +# Molmo Constants +MOLMO_IMAGE_HEIGHT = 536 +MOLMO_IMAGE_WIDTH = 354 +# Flux Transformer Constants +FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 +FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 +FLUX_ADALN_HIDDEN_DIM = 3072 +FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context +FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 +FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM + +# Wan Transformer Constants +WAN_TEXT_EMBED_DIM = 5120 +WAN_PROJECTION_DIM = 6 +WAN_ONNX_EXPORT_BATCH_SIZE = 1 +WAN_ONNX_EXPORT_FRAMES = 81 +WAN_ONNX_EXPORT_LATENT_FRAMES = 21 +WAN_ONNX_EXPORT_SEQ_LEN = 512 +WAN_ONNX_EXPORT_ROTARY_DIM = 128 +WAN_DIT_OUT_CHANNELS = 64 +# Wan dims for 45p +WAN_ONNX_EXPORT_CL_45P = 252 +WAN_ONNX_EXPORT_LATENT_HEIGHT_45P = 6 +WAN_ONNX_EXPORT_LATENT_WIDTH_45P = 8 +WAN_ONNX_EXPORT_HEIGHT_45P = 48 +WAN_ONNX_EXPORT_WIDTH_45P = 64 + +# WAN I2V +WAN_DIT_I2V_IMG_LATENT_CHANNELS = 32 + +# For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length +CCL_START_MAP = { + 32768: (4096, 4000), + 65536: (8192, 8000), + float("inf"): (16384, 16000), +} +# Limitation in the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic lists generation process. +CCL_MAX_ELEMENTS_LISTS = 5 +CCL_START_CTX_LEN = 4096 +CCL_MIN_CTX_LEN = 1024 +CCL_UNIQNE_STEP = 32 + +# used for gpt-oss prefill-only model Q-blocking +GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256 + + +class Constants: + # Export Constants. + SEQ_LEN = 32 + CTX_LEN = 32 + PROMPT_LEN = 8 + INPUT_STR = ["My name is"] + GB = 2**30 + MAX_QPC_LIMIT = 30 + MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download + NUM_SPECULATIVE_TOKENS = 2 + MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS + SAMPLER_OPS = { + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + } + SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"} + SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. + SDK_PLATFORM_XML = ( + "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. + ) + + +@dataclass +class QnnConstants: + # QNN PATH to be read from environment variable. + QNN_SDK_PATH_ENV_VAR_NAME = "QNN_SDK_ROOT" + QNN_SDK_YAML = "sdk.yaml" + + # QNN Compilation tools + QAIRT_CONVERTER = "{}/bin/{}/qairt-converter" + QNN_CONTEXT_BIN = "{}/bin/{}/qnn-context-binary-generator" + + # QNN Libraries required for compilation + QNN_CONTEXT_LIB_BACKEND = "{}/lib/{}/libQnnAic.so" + QNN_CONTEXT_LIB_NET_RUN_EXTENSIONS = "{}/lib/{}/libQnnAicNetRunExtensions.so" + + # QNN Compilation target names + MODEL_NAME = "model" + QNN_DATA_FORMAT_CONFIG_NAME = "qnn_data_format_config.json" + CONTEXT_BIN_NAME = "qnngraph.serialized" + CONTEXT_BIN_QPC_NAME = "programqpc.bin" + + # TARGET System Architecture + TARGET = "x86_64-linux-clang" # TODO add support in infer to be override + + # Converter Arguments + FLOAT_BITWIDTH = 16 + FLOAT_BIAS_BITWIDTH = 32 + CONVERTER_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification --target_backend AIC " + + # Context-Binary-Generator Arguments + LOG_LEVEL = "error" + + # qnn_compilation_backend default Arguments + COMPILER_COMPILATION_TARGET = "hardware" + COMPILER_CONVERT_TO_FP16 = True + COMPILER_DO_DDR_TO_MULTICAST = True + COMPILER_HARDWARE_VERSION = "2.0" + COMPILER_PERF_WARNINGS = False + COMPILER_PRINT_DDR_STATS = False + COMPILER_PRINT_PERF_METRICS = False + COMPILER_RETAINED_STATE = True + COMPILER_STAT_LEVEL = 10 + COMPILER_STATS_BATCH_SIZE = 1 + COMPILER_TIME_PASSES = False + GRAPH_NAMES = [f"{MODEL_NAME}_configuration_1", f"{MODEL_NAME}_configuration_2"] + GRAPH_NAMES_PREFILL_ONLY = [f"{MODEL_NAME}"] + + # qnn_config JSON file supported Keys + CONVERTER_ARGS_EXTENSION_STR = "converter_args_extension" + CONTEXT_BIN_ARGS_EXTENSION_STR = "context_binary_generator_args_extension" + QNN_COMPILATION_BACKEND_STR = "qnn_compilation_backend" + SKIP_QNN_CONVERTER_STEP_STR = "SKIP_QNN_CONVERTER_STEP" + + IMMUTABLE_CONVERTER_ARGS = [ + "--input_network ", + "--output_path ", + "--config ", + "--float_bias_bitwidth ", + "--float_bitwidth ", + "--preserve_io_datatype", + "--onnx_skip_simplification", + ] + + IMMUTABLE_CONTEXT_BIN_GEN_ARGS = [ + "--binary_file ", + "--backend_binary ", + "--output_dir ", + "--backend ", + "--model ", + "--dlc_path ", + "--config_file ", + ] + + QNN_SAMPLE_CONFIG = { + "converter_args_extension": "--onnx_defer_loading", + "context_binary_generator_args_extension": "--log_level debug", + "qnn_compilation_backend": { + "compiler_enable_depth_first": True, + "compiler_printDDRStats": False, + "compiler_printPerfMetrics": False, + }, + "SKIP_QNN_CONVERTER_STEP": False, + } From 8c54aa1a48f1c238fd46d8d316f80ad011e38e64 Mon Sep 17 00:00:00 2001 From: divytrip Date: Wed, 27 May 2026 10:34:13 +0530 Subject: [PATCH 15/16] fix: add num_cores and moe_prefill_packed_chunk_size as explicit params to QEFFAutoModelForCausalLM.export() compiler_options is only available in compile(), not export(). Add num_cores and moe_prefill_packed_chunk_size as explicit named params to export() so they are directly accessible, matching the pattern in vbaddi/feat/prefill_moe. --- QEfficient/transformers/models/modeling_auto.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 3601d287c6..87a7990306 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3071,6 +3071,8 @@ def export( export_dir: Optional[str] = None, prefill_only: Optional[bool] = False, prefill_seq_len: Optional[int] = None, + num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, **kwargs, ) -> str: """ @@ -3123,8 +3125,8 @@ def export( seq_len = self.get_seq_len_and_handle_specialized_prefill_model( prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking, - num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), - moe_prefill_packed_chunk_size=kwargs.get("moe_prefill_packed_chunk_size", constants.MOE_PREFILL_PACKED_CHUNK_SIZE), + num_cores=num_cores, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, ) kv_cache_shape[2] = ( seq_len From 1744f906a6ff7851de3ba8d0ad8c77a45201c5ac Mon Sep 17 00:00:00 2001 From: divytrip Date: Mon, 1 Jun 2026 11:00:24 +0530 Subject: [PATCH 16/16] fix: moe_prefill_num_nsp API param + GraniteMoE NSP blocking fixes - modeling_qeff.py: save moe_prefill_num_nsp from compiler_options in _compile and pass through get_onnx_path to export() - modeling_utils.py: add granitemoe to SPECIALIZED_DISAGG_SERVING_MODEL_ARCH - modeling_granitemoe.py: fix supports_moe_prefill_blocking moved out of docstring into class body; fix reshape order to match Qwen3-MoE/GPT-OSS - modeling_auto.py: add moe_prefill_num_nsp param to compile()/export()/ get_seq_len; pass moe_prefill_num_nsp to lang_model.export() in VLM path; fix sliding_window AttributeError for models without sliding_window attr --- QEfficient/base/modeling_qeff.py | 4 +++ QEfficient/transformers/modeling_utils.py | 2 +- .../models/granitemoe/modeling_granitemoe.py | 12 +++---- .../transformers/models/modeling_auto.py | 33 +++++++++++++------ 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b69d2d0a5c..51681846f5 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -395,6 +395,7 @@ def get_onnx_path( mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, moe_prefill_packed_chunk_size: Optional[int] = None, + moe_prefill_num_nsp: Optional[int] = None, **compiler_options, ): kwargs = { @@ -411,6 +412,7 @@ def get_onnx_path( "prefill_seq_len": specializations[0].get("seq_len"), "enable_chunking": enable_chunking, "num_cores": compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), + "moe_prefill_num_nsp": moe_prefill_num_nsp, "moe_prefill_packed_chunk_size": constants.MOE_PREFILL_PACKED_CHUNK_SIZE if moe_prefill_packed_chunk_size is None else moe_prefill_packed_chunk_size, @@ -533,6 +535,7 @@ def _compile( """ moe_prefill_packed_chunk_size = compiler_options.pop("moe_prefill_packed_chunk_size", None) + moe_prefill_num_nsp = compiler_options.pop("moe_prefill_num_nsp", None) onnx_path = Path( onnx_path if onnx_path @@ -549,6 +552,7 @@ def _compile( num_devices=mdp_ts_num_devices, qaic_config=qaic_config, moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + moe_prefill_num_nsp=moe_prefill_num_nsp, **compiler_options, ) ) diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 183c19f6f6..648baf2b66 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -196,7 +196,7 @@ DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # This is for supporting different modelling classes specially written for prefill-only model -SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "kimi_k2", "kimi_k25"} +SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "granitemoe", "kimi_k2", "kimi_k25"} _PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 48436e9838..8ca3930714 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -696,13 +696,13 @@ def _cumsum_scatter_gather_update_granitemoe_expert_blocked( class QEffPrefillChunkedGraniteMoeMoE(GraniteMoeMoE): """NSP-blocked prefill dispatch for GraniteMoE. - supports_moe_prefill_blocking = True - Replaces the per-expert loop in QEffGraniteMoeMoE with a cumsum-scatter-gather-update strategy that only runs the MLP on active tokens, mirroring the Qwen3-MoE implementation. """ + supports_moe_prefill_blocking = True + def __qeff_init__(self): W_gate_up = self.input_linear.weight # [E, 2I, H] I = W_gate_up.shape[1] // 2 @@ -723,10 +723,10 @@ def _forward_expert_blocked( local_experts = num_experts // num_nsp I = self._W_g.shape[2] - rw = routing_weights.transpose(0, 1).contiguous().view(num_nsp, local_experts, T) - W_g = self._W_g.view(num_nsp, local_experts, H, I) - W_u = self._W_u.view(num_nsp, local_experts, H, I) - W_d = self._W_d.view(num_nsp, local_experts, I, H) + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self._W_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self._W_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self._W_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() expert_out = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 87a7990306..a2bfa42a2f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1078,6 +1078,7 @@ def export( prefill_only: bool = False, enable_chunking: bool = False, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_num_nsp: Optional[int] = None, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, **kwargs, ): @@ -1119,10 +1120,11 @@ def export( num_packed_chunks = max(1, -(-compile_seq_len // moe_prefill_packed_chunk_size)) for module in self.model.modules(): if getattr(module, "supports_moe_prefill_blocking", False): - module.expert_blocking_num_nsp = num_cores - module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size - module.expert_blocking_num_packed_chunks = num_packed_chunks - self.hash_params["moe_prefill_num_nsp"] = num_cores + if moe_prefill_num_nsp is not None: + module.expert_blocking_num_nsp = moe_prefill_num_nsp + module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size + module.expert_blocking_num_packed_chunks = num_packed_chunks + self.hash_params["moe_prefill_num_nsp"] = moe_prefill_num_nsp self.hash_params["moe_prefill_packed_chunk_size"] = moe_prefill_packed_chunk_size self.hash_params["moe_prefill_num_packed_chunks"] = num_packed_chunks else: @@ -1399,6 +1401,7 @@ def export( enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, num_cores=kwargs.get("num_cores", constants.DEFAULT_AIC_NUM_CORES), + moe_prefill_num_nsp=kwargs.get("moe_prefill_num_nsp", None), moe_prefill_packed_chunk_size=kwargs.get("moe_prefill_packed_chunk_size", constants.MOE_PREFILL_PACKED_CHUNK_SIZE), ) return self.onnx_path @@ -1453,6 +1456,7 @@ def compile( use_onnx_subfunctions: bool = False, prefill_only=None, enable_chunking=False, + moe_prefill_num_nsp: Optional[int] = None, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, qaic_config: Optional[dict] = None, **compiler_options, @@ -1593,6 +1597,7 @@ def compile( enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, num_cores=num_cores, + moe_prefill_num_nsp=moe_prefill_num_nsp, moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, ) @@ -3010,6 +3015,7 @@ def get_seq_len_and_handle_specialized_prefill_model( prefill_seq_len: Optional[int] = None, enable_chunking=False, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_num_nsp: Optional[int] = None, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, ) -> int: self.hash_params["prefill_only"] = True @@ -3019,10 +3025,11 @@ def get_seq_len_and_handle_specialized_prefill_model( num_packed_chunks = max(1, -(-compile_seq_len // moe_prefill_packed_chunk_size)) for module in self.model.modules(): if getattr(module, "supports_moe_prefill_blocking", False): - module.expert_blocking_num_nsp = num_cores - module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size - module.expert_blocking_num_packed_chunks = num_packed_chunks - self.hash_params["moe_prefill_num_nsp"] = num_cores + if moe_prefill_num_nsp is not None: + module.expert_blocking_num_nsp = moe_prefill_num_nsp + module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size + module.expert_blocking_num_packed_chunks = num_packed_chunks + self.hash_params["moe_prefill_num_nsp"] = moe_prefill_num_nsp self.hash_params["moe_prefill_packed_chunk_size"] = moe_prefill_packed_chunk_size self.hash_params["moe_prefill_num_packed_chunks"] = num_packed_chunks return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN @@ -3072,6 +3079,7 @@ def export( prefill_only: Optional[bool] = False, prefill_seq_len: Optional[int] = None, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_num_nsp: Optional[int] = None, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, **kwargs, ) -> str: @@ -3126,11 +3134,12 @@ def export( prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking, num_cores=num_cores, + moe_prefill_num_nsp=moe_prefill_num_nsp, moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, ) kv_cache_shape[2] = ( seq_len - + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + + (getattr(self.model.config, 'sliding_window', None) or 0) if enable_chunking else seq_len ) @@ -3146,7 +3155,7 @@ def export( self.hash_params.pop("moe_prefill_num_packed_chunks", None) if kwargs.get("retain_full_kv", False): kv_cache_shape[2] = seq_len + ( - self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 + (getattr(self.model.config, 'sliding_window', None) or 0) ) self.hash_params["retain_full_kv"] = True @@ -3432,6 +3441,8 @@ def compile( use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, + moe_prefill_num_nsp: Optional[int] = None, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: @@ -3681,6 +3692,8 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, + moe_prefill_num_nsp=moe_prefill_num_nsp, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, **compiler_options, )