diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e9213761d9..51681846f5 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -394,6 +394,8 @@ def get_onnx_path( retain_full_kv: Optional[bool] = False, mla_absorption: Optional[Dict[str, bool]] = None, qaic_config: Optional[dict] = None, + moe_prefill_packed_chunk_size: Optional[int] = None, + moe_prefill_num_nsp: Optional[int] = None, **compiler_options, ): kwargs = { @@ -409,6 +411,11 @@ def get_onnx_path( "prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len"), "enable_chunking": enable_chunking, + "num_cores": compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES), + "moe_prefill_num_nsp": moe_prefill_num_nsp, + "moe_prefill_packed_chunk_size": constants.MOE_PREFILL_PACKED_CHUNK_SIZE + if moe_prefill_packed_chunk_size is None + else moe_prefill_packed_chunk_size, } ) @@ -527,6 +534,8 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" + moe_prefill_packed_chunk_size = compiler_options.pop("moe_prefill_packed_chunk_size", None) + moe_prefill_num_nsp = compiler_options.pop("moe_prefill_num_nsp", None) onnx_path = Path( onnx_path if onnx_path @@ -542,6 +551,8 @@ def _compile( mla_absorption, num_devices=mdp_ts_num_devices, qaic_config=qaic_config, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + moe_prefill_num_nsp=moe_prefill_num_nsp, **compiler_options, ) ) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index c27e3cc704..dbbb51e5b2 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -21,11 +21,15 @@ CtxGatherBlockedKV, CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFunc3DGeneralized, CtxGatherFuncBlockedKV, CtxScatter, CtxScatter3D, + CtxScatter3DInt, CtxScatterFunc, CtxScatterFunc3D, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherBlockedKVCB, @@ -92,8 +96,11 @@ class CustomOpTransform(BaseOnnxTransform): "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm), "CtxScatterFunc": (CtxScatterFunc, CtxScatter), "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), + "CtxScatterFunc3DInt": (CtxScatterFunc3DInt, CtxScatter3DInt), + "CtxScatterFunc3DGeneralized": (CtxScatterFunc3DGeneralized, CtxScatter3D), "CtxGatherFunc": (CtxGatherFunc, CtxGather), "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), + "CtxGatherFunc3DGeneralized": (CtxGatherFunc3DGeneralized, CtxGather3D), "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV), diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index 35830aa91e..4830e660c3 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -8,9 +8,12 @@ from QEfficient.customop.ctx_scatter_gather import ( CtxGatherFunc, 
CtxGatherFunc3D, + CtxGatherFunc3DGeneralized, CtxGatherFuncBlockedKV, CtxScatterFunc, CtxScatterFunc3D, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherFuncBlockedKVCB, @@ -26,7 +29,10 @@ "CtxGatherFuncBlockedKV", "CtxScatterFunc", "CtxGatherFunc3D", + "CtxGatherFunc3DGeneralized", "CtxScatterFunc3D", + "CtxScatterFunc3DGeneralized", + "CtxScatterFunc3DInt", "CustomRMSNormAIC", "GemmaCustomRMSNormAIC", "CtxGatherFuncCB", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 59bfe6af03..b1f322c606 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates # Create indices batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + + # keep index tensor types aligned for backend that require exact dtype match + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) indices = ops.Concat(batch_idx, ctx_idx, axis=2) @@ -78,8 +81,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates class CtxScatterFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() batch_idx = torch.arange(data.shape[0]).view(-1, 1) - ctx_idx = position_ids + ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids) data[batch_idx, ctx_idx] = updates return data @@ -92,9 +96,80 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) +class CtxScatterFunc3DGeneralized(torch.autograd.Function): + """Scatter variant that preserves ``data`` at invalid (INT32_MAX) 
positions. + + Unlike :class:`CtxScatterFunc3D`, which writes updates for invalid rows to + ``data.shape[1]-1`` (potentially clobbering valid content), this version + masks out invalid rows before scattering so ``data`` is left untouched where + ``position_ids == INT32_MAX``. + """ + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatter3DInt( + data: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.INT32 +) -> onnxscript.INT32: + # Find dims + batch_size = ops.Gather(ops.Shape(data), [0]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, seq_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) + indices = ops.Concat(batch_idx, ctx_idx, axis=2) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFunc3DInt(torch.autograd.Function): + """Int32-typed scatter used to build a packed->original index table.""" + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: 
torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3DInt, data, position_ids, updates).setTypeAs(data) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) + batch_size = ops.Slice(ops.Shape(data), starts=[0], ends=[1], axes=[0]) + idx_seq_len = ops.Slice(ops.Shape(ctx_indices), starts=[1], ends=[2], axes=[0]) + expand_shape = ops.Concat(batch_size, idx_seq_len, axis=0) + ctx_indices = ops.Expand(ctx_indices, expand_shape) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=1) @@ -103,6 +178,7 @@ class CtxGatherFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, ctx_indices] @staticmethod @@ -114,6 +190,31 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor return g.onnxscript_op(CtxGather3D, data, ctx_indices).setTypeAs(data) +class CtxGatherFunc3DGeneralized(torch.autograd.Function): + """Gather variant that tolerates INT32_MAX indices (invalid rows read from 0). 
+ + Semantically equivalent to :class:`CtxGatherFunc3D` on the PyTorch side but + exposed as a separate autograd op so callers using the packed/cumsum scatter + pipeline can be easily recognized and so the ONNX symbolic omits + ``setTypeAs`` (needed when the caller already has a matching dtype on + ``data`` and wants the op signature to flow through without dtype pinning). + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) + return data[batch_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGather3D, data, ctx_indices) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather( data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index f9d7fe62cd..648baf2b66 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -196,7 +196,7 @@ DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # This is for supporting different modelling classes specially written for prefill-only model -SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "kimi_k2", "kimi_k25"} +SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "granitemoe", "kimi_k2", "kimi_k25"} _PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6f805bfd4c..67f1a485ad 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ 
b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -36,6 +36,11 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -50,7 +55,160 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + matched_idx = torch.full_like(token_idx, int32_max) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_gptoss_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + b_g: torch.Tensor, + b_u: torch.Tensor, + b_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + limit: float, + alpha: float, + T: int, + packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, +) -> torch.Tensor: + """Cumsum-scatter-gather-update expert helper for GPT-OSS NSP-blocked dispatch. 
+ + Same algorithm as the Qwen3-MOE version but with GPT-OSS biases and GLU + activation (clamped gate/up, ``(up + 1) * gate * sigmoid(gate * alpha)``). + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + b_g, b_u : [num_nsp, I] + b_d : [num_nsp, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops → DYNAMIC at runtime (scales with actual seq_len). + # Trace time: seq_len=32 → pcs=16 (valid slice) + # Runtime: seq_len=512 → pcs=256 (correct size) ✅ + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size + + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] + + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate = (x_chunk @ W_g) + b_g.unsqueeze(1) + up = (x_chunk @ W_u) + b_u.unsqueeze(1) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + intermediate = (up + 1) * glu + down_chunk = (intermediate @ W_d) + b_d.unsqueeze(1) + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk 
* rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + supports_moe_prefill_blocking = True + + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: + T, H = x.shape + num_nsp = self.expert_blocking_num_nsp + num_experts = self.experts.num_experts + if num_experts % num_nsp != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})") + + local_experts = num_experts // num_nsp + expert_dim = self.experts.expert_dim + routing_weights_by_expert = ( + routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + ) + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous() + W_d = self.experts.down_proj.view(local_experts, num_nsp, expert_dim, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 
1).contiguous() + b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + + expert_out = x.new_zeros((num_nsp, T, H)) + for local_slot in range(local_experts): + routing_weight = routing_weights_by_expert[:, local_slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_gptoss_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, local_slot], + W_u=W_u[:, local_slot], + W_d=W_d[:, local_slot], + b_g=b_g[:, local_slot], + b_u=b_u[:, local_slot], + b_d=b_d[:, local_slot], + routing_weight=routing_weight, + expert_out=expert_out, + limit=self.experts.limit, + alpha=self.experts.alpha, + T=T, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, + ) + + return torch.einsum("ijk->jk", expert_out) + def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S @@ -69,6 +227,10 @@ def forward(self, hidden: torch.Tensor): # Routing weights for each expert [T, E] routing_weights = masked_logits + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and self.experts.num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=hidden, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + # ────────────────── allocate the output tensor ───── expert_out = hidden.new_zeros((T, H)) # accumulation buffer diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8728b4d3e4..8ca3930714 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- from typing import List, Optional, Tuple, Type, Union +import os import torch from torch import nn @@ -32,6 +33,11 @@ 
generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -89,6 +95,62 @@ def qeff_apply_rotary_pos_emb( return q_embed.to(q.dtype), k_embed.to(k.dtype) +class QEffPrefillChunkedGraniteMoeAttention(GraniteMoeAttention): + """Prefill-chunked attention for GraniteMoE — no sliding window.""" + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) + + key_states, value_states, _ = past_key_value_update( + module=self, + key=key_states, + value=value_states, + 
attention_mask=attention_mask, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids, + ) + + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + class QEffGraniteMoeAttention(GraniteMoeAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -541,6 +603,182 @@ def forward(self, layer_input): return final_hidden_states, router_logits + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table via cumsum scatter.""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + matched_idx = torch.full_like(token_idx, int32_max) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_granitemoe_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + activation, + packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, +) -> torch.Tensor: + """Cumsum-scatter-gather-update for one local GraniteMoE expert slot. 
+ + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] bool + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] accumulator (in-out) + """ + batch_size, seq_len = T2Ei.shape + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops → DYNAMIC at runtime (scales with actual seq_len). + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size + + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] + + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate = x_chunk @ W_g + up = x_chunk @ W_u + down_chunk = (activation(gate) * up) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) + 
updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + +class QEffPrefillChunkedGraniteMoeMoE(GraniteMoeMoE): + """NSP-blocked prefill dispatch for GraniteMoE. + + Replaces the per-expert loop in QEffGraniteMoeMoE with a + cumsum-scatter-gather-update strategy that only runs the MLP on active + tokens, mirroring the Qwen3-MoE implementation. + """ + + supports_moe_prefill_blocking = True + + def __qeff_init__(self): + W_gate_up = self.input_linear.weight # [E, 2I, H] + I = W_gate_up.shape[1] // 2 + self._W_g = nn.Parameter(W_gate_up[:, :I, :].transpose(1, 2).contiguous()) # [E, H, I] + self._W_u = nn.Parameter(W_gate_up[:, I:, :].transpose(1, 2).contiguous()) # [E, H, I] + self._W_d = nn.Parameter(self.output_linear.weight.transpose(1, 2).contiguous()) # [E, I, H] + + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: + T, H = x.shape + num_experts = self.router.num_experts + num_nsp = self.expert_blocking_num_nsp + if num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})." 
+ ) + local_experts = num_experts // num_nsp + I = self._W_g.shape[2] + + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self._W_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self._W_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self._W_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + + expert_out = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_granitemoe_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, + activation=self.activation, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, + ) + return torch.einsum("ijk->jk", expert_out) + + def orig_forward(self, layer_input: torch.Tensor): + """Original per-expert loop — kept for parity testing.""" + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + topk_gates, expert_mask, router_logits, num_experts = self.router(layer_input) + final_hidden_states = torch.zeros_like(layer_input) + for expert_idx in range(num_experts): + mask = expert_mask[expert_idx].transpose(0, 1).to(layer_input.dtype) + mask_weight = torch.einsum("be,be->b", topk_gates, mask.to(topk_gates.dtype))[:, None] + hidden_states = self.input_linear(layer_input, expert_idx) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_idx) + current_hidden_states = torch.where(mask_weight > 0, expert_outputs * mask_weight, 0.0) + final_hidden_states += current_hidden_states + final_hidden_states = final_hidden_states.view(bsz, length, 
self.input_size) + return final_hidden_states, router_logits + + def forward(self, layer_input: torch.Tensor): + bsz, length, emb_size = layer_input.size() + x = layer_input.reshape(-1, emb_size) + topk_gates, expert_mask, router_logits, num_experts = self.router(x) + + # Convert [E, top_k, T] + [T, top_k] -> flat [T, E] routing weights + routing_weights = torch.einsum("tk,ekt->te", topk_gates, expert_mask.float()) + + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(bsz, length, self.input_size), router_logits + + return self.orig_forward(layer_input) + + class QEffGraniteMoeParallelExperts(GraniteMoeParallelExperts): def forward(self, inputs, expert_size): """ @@ -570,6 +808,12 @@ def get_submodules_for_export(self) -> Type[nn.Module]: This method should return the *class object* (not an instance). Downstream code can use this to find/build subfunctions for repeated blocks. """ + # When the chunked-prefill MoE block is active it emits CtxScatter3DInt, + # which the compiler does not support inside ONNX subfunctions. + # Return an empty set so -sub-functions is not passed to qaic-compile. 
+ first_layer_moe = self.model.layers[0].block_sparse_moe + if isinstance(first_layer_moe, QEffPrefillChunkedGraniteMoeMoE): + return set() return {self.model.layers[0].__class__} def forward( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c5c50f1c7d..a2bfa42a2f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1077,6 +1077,9 @@ def export( prefill_seq_len: Optional[int] = None, prefill_only: bool = False, enable_chunking: bool = False, + num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_num_nsp: Optional[int] = None, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, **kwargs, ): """ @@ -1110,6 +1113,20 @@ def export( ) self.hash_params["prefill_only"] = True self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) + # Set API-driven MoE blocking params on VLM MoE layers AFTER __update_prefill_transform + # so they are set on the already-swapped chunked class instances. 
+ if enable_chunking and prefill_seq_len is not None: + compile_seq_len = prefill_seq_len + num_packed_chunks = max(1, -(-compile_seq_len // moe_prefill_packed_chunk_size)) + for module in self.model.modules(): + if getattr(module, "supports_moe_prefill_blocking", False): + if moe_prefill_num_nsp is not None: + module.expert_blocking_num_nsp = moe_prefill_num_nsp + module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size + module.expert_blocking_num_packed_chunks = num_packed_chunks + self.hash_params["moe_prefill_num_nsp"] = moe_prefill_num_nsp + self.hash_params["moe_prefill_packed_chunk_size"] = moe_prefill_packed_chunk_size + self.hash_params["moe_prefill_num_packed_chunks"] = num_packed_chunks else: self.hash_params["prefill_only"] = False self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) @@ -1383,6 +1400,9 @@ def export( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + num_cores=kwargs.get("num_cores", constants.DEFAULT_AIC_NUM_CORES), + moe_prefill_num_nsp=kwargs.get("moe_prefill_num_nsp", None), + moe_prefill_packed_chunk_size=kwargs.get("moe_prefill_packed_chunk_size", constants.MOE_PREFILL_PACKED_CHUNK_SIZE), ) return self.onnx_path @@ -1436,6 +1456,8 @@ def compile( use_onnx_subfunctions: bool = False, prefill_only=None, enable_chunking=False, + moe_prefill_num_nsp: Optional[int] = None, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, qaic_config: Optional[dict] = None, **compiler_options, ) -> str: @@ -1574,6 +1596,9 @@ def compile( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + num_cores=num_cores, + moe_prefill_num_nsp=moe_prefill_num_nsp, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, ) # TODO this hould be removed once the continous batching is supported for all the models. 
@@ -2986,11 +3011,27 @@ def get_model_config(self) -> dict: return self.model.config.__dict__ def get_seq_len_and_handle_specialized_prefill_model( - self, prefill_seq_len: Optional[int] = None, enable_chunking=False + self, + prefill_seq_len: Optional[int] = None, + enable_chunking=False, + num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_num_nsp: Optional[int] = None, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, ) -> int: self.hash_params["prefill_only"] = True if enable_chunking: self.hash_params["chunking"] = True + compile_seq_len = prefill_seq_len or constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + num_packed_chunks = max(1, -(-compile_seq_len // moe_prefill_packed_chunk_size)) + for module in self.model.modules(): + if getattr(module, "supports_moe_prefill_blocking", False): + if moe_prefill_num_nsp is not None: + module.expert_blocking_num_nsp = moe_prefill_num_nsp + module.expert_blocking_packed_chunk_size = moe_prefill_packed_chunk_size + module.expert_blocking_num_packed_chunks = num_packed_chunks + self.hash_params["moe_prefill_num_nsp"] = moe_prefill_num_nsp + self.hash_params["moe_prefill_packed_chunk_size"] = moe_prefill_packed_chunk_size + self.hash_params["moe_prefill_num_packed_chunks"] = num_packed_chunks return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN num_q_blocks = ( @@ -3037,6 +3078,9 @@ def export( export_dir: Optional[str] = None, prefill_only: Optional[bool] = False, prefill_seq_len: Optional[int] = None, + num_cores: int = constants.DEFAULT_AIC_NUM_CORES, + moe_prefill_num_nsp: Optional[int] = None, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, **kwargs, ) -> str: """ @@ -3087,11 +3131,15 @@ def export( self.hash_params.pop("retain_full_kv", None) if "DeepseekV3ForCausalLM" not in (getattr(self.model.config, "architectures", None) or []): seq_len = self.get_seq_len_and_handle_specialized_prefill_model( - prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking 
+ prefill_seq_len=prefill_seq_len, + enable_chunking=enable_chunking, + num_cores=num_cores, + moe_prefill_num_nsp=moe_prefill_num_nsp, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, ) kv_cache_shape[2] = ( seq_len - + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + + (getattr(self.model.config, 'sliding_window', None) or 0) if enable_chunking else seq_len ) @@ -3102,9 +3150,12 @@ def export( self.hash_params.pop("NUM_FFN_BLOCKS", None) self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) + self.hash_params.pop("moe_prefill_num_nsp", None) + self.hash_params.pop("moe_prefill_packed_chunk_size", None) + self.hash_params.pop("moe_prefill_num_packed_chunks", None) if kwargs.get("retain_full_kv", False): kv_cache_shape[2] = seq_len + ( - self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 + (getattr(self.model.config, 'sliding_window', None) or 0) ) self.hash_params["retain_full_kv"] = True @@ -3390,6 +3441,8 @@ def compile( use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, + moe_prefill_num_nsp: Optional[int] = None, + moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: @@ -3639,6 +3692,8 @@ def compile( offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, + moe_prefill_num_nsp=moe_prefill_num_nsp, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5ff06e6443..899e6acbc5 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -324,9 +324,11 @@ ) from 
QEfficient.transformers.models.granitemoe.modeling_granitemoe import ( QEffGraniteMoeAttention, + QEffPrefillChunkedGraniteMoeAttention, QEffGraniteMoeForCausalLM, QEffGraniteMoeModel, QEffGraniteMoeMoE, + QEffPrefillChunkedGraniteMoeMoE, QEffGraniteMoeParallelExperts, QEffGraniteMoeRotaryEmbedding, QEffGraniteMoeTopKGating, @@ -754,6 +756,9 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, # Qwen3 VL Moe QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, + # GraniteMoe + QEffGraniteMoeMoE: QEffPrefillChunkedGraniteMoeMoE, + QEffGraniteMoeAttention: QEffPrefillChunkedGraniteMoeAttention, } @@ -767,6 +772,11 @@ class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, # Qwen3Moe QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + # GraniteMoe + QEffPrefillChunkedGraniteMoeMoE: QEffGraniteMoeMoE, + QEffPrefillChunkedGraniteMoeAttention: QEffGraniteMoeAttention, + # Qwen3 VL Moe + QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, } diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index de92eae8f7..52386d82af 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import os from typing import List, Optional, Tuple, Type import torch @@ -32,6 +33,11 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import 
_create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -100,8 +106,153 @@ def eager_attention_forward( return attn_output, attn_weights + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + matched_idx = torch.full_like(token_idx, int32_max) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + act_fn, + T: int, + packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, +) -> torch.Tensor: + """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. + + Accumulates one local expert's contribution in-place onto ``expert_out``. + Uses a packed/cumsum layout so the MLP runs only over active rows, then + scatters the weighted output back to original token positions. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops → DYNAMIC at runtime (scales with actual seq_len). 
+ # Trace time: seq_len=32 → pcs=16 (valid slice) + # Runtime: seq_len=512 → pcs=256 (correct size) ✅ + if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size + + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] + + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate_prime = x_chunk @ W_g + up_prime = x_chunk @ W_u + down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): - def forward(self, 
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + supports_moe_prefill_blocking = True + + def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] + self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] + self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] + + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: + T, H = x.shape + num_nsp = self.expert_blocking_num_nsp + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, + act_fn=self.experts[0].act_fn, + T=T, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, + ) + return torch.einsum("ijk->jk", expert_out) + + def 
orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) @@ -113,22 +264,37 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = top_w.to(hidden_states.dtype) masked_logits = torch.zeros_like(router_logits) masked_logits.scatter_(1, top_i, top_w) - # Routing weights for each expert [T, E] routing_weights = masked_logits - # ────────────────── allocate the output tensor ───── - expert_out = x.new_zeros((T, H)) # accumulation buffer - # ───────────────────────── Expert computation loop ───────────────────────────── + expert_out = x.new_zeros((T, H)) for e in range(self.num_experts): - routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I] - W_d = self.experts[e].down_proj.weight.T # [I, H] - gate = x @ W_g # [T, I] - up = x @ W_u # [T, I] - down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H] - masked_down = down * routing_weight - expert_out += masked_down + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight return expert_out.view(B, S, H), router_logits + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = self.gate(x) + prob = F.softmax(router_logits, -1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, -1) + if self.norm_topk_prob: + top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) + + if getattr(self, 
"supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and self.num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + + return self.orig_forward(hidden_states) + class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): def __qeff_init__(self): diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 976c80919b..075f958ddc 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- import math +import os from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch @@ -37,6 +38,11 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils import constants @@ -633,41 +639,186 @@ def _deepstack_process( return local_this + + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table via cumsum scatter.""" + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + 
matched_idx = torch.full_like(token_idx, int32_max) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_qwen3vlmoe_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, + act_fn, + packed_chunk_size: int, + num_expert_chunks: Optional[int] = None, +) -> torch.Tensor: + """Cumsum-scatter-gather-update expert helper for Qwen3-VL-MoE NSP-blocked dispatch. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + # num_expert_chunks controls loop iteration count (unrolled at trace time). + # packed_chunk_size = seq_len // num_expert_chunks is computed via ONNX + # Shape+Div ops -> DYNAMIC at runtime (scales with actual seq_len). 
+ if num_expert_chunks is not None: + packed_chunk_size = seq_len // num_expert_chunks + else: + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + num_expert_chunks = seq_len // packed_chunk_size + + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for chunk_idx in range(num_expert_chunks): + packed_start = chunk_idx * packed_chunk_size + packed_stop = packed_start + packed_chunk_size if chunk_idx < num_expert_chunks - 1 else seq_len + chunk_size = packed_stop - packed_start + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] + + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate_prime = x_chunk @ W_g + up_prime = x_chunk @ W_u + down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + row_range = torch.arange(chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + rows_remaining = valid_rows - packed_start + chunk_valid_rows = torch.where(rows_remaining < 0, torch.zeros_like(rows_remaining), rows_remaining) + chunk_valid_rows = torch.where( + chunk_valid_rows > chunk_size, + torch.ones_like(chunk_valid_rows) * chunk_size, + chunk_valid_rows, + ) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) + + return expert_out + + class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """NSP-blocked 
prefill dispatch for Qwen3-VL-MoE text sparse MoE block.""" + + supports_moe_prefill_blocking = True + + def __qeff_init__(self): + # Split fused gate_up_proj [E, H, 2I] into separate gate and up [E, H, I] + W_gate_up = self.experts.gate_up_proj # [E, H, 2I] + I = W_gate_up.shape[2] // 2 + self._W_g = nn.Parameter(W_gate_up[:, :, :I].contiguous()) # [E, H, I] + self._W_u = nn.Parameter(W_gate_up[:, :, I:].contiguous()) # [E, H, I] + self._W_d = nn.Parameter(self.experts.down_proj.contiguous()) # [E, I, H] + + def _forward_expert_blocked( + self, x: torch.Tensor, routing_weights: torch.Tensor, num_expert_chunks: Optional[int] = None + ) -> torch.Tensor: + T, H = x.shape + num_nsp = self.expert_blocking_num_nsp + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + I = self._W_g.shape[2] + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self._W_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self._W_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self._W_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_qwen3vlmoe_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, + act_fn=self.experts.act_fn, + packed_chunk_size=self.expert_blocking_packed_chunk_size, + num_expert_chunks=self.expert_blocking_num_packed_chunks, + ) + return torch.einsum("ijk->jk", expert_out) + + def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Original per-expert loop — kept for parity 
testing.""" B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) - act = getattr(self.experts, "act_fn", F.silu) - - router_logits = self.gate(x) # [T, E] - prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + act = self.experts.act_fn + router_logits = self.gate(x) + prob = F.softmax(router_logits, dim=-1, dtype=torch.float) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) - top_w = top_w / torch.einsum("bi->b", top_w)[:, None] - top_w = top_w.to(hidden_states.dtype) - routing_weights = torch.zeros((T, self.num_experts), dtype=x.dtype) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] # norm_topk_prob always True + top_w = top_w.to(x.dtype) + routing_weights = torch.zeros_like(router_logits) routing_weights.scatter_(1, top_i, top_w) - - expert_out = torch.zeros_like(x, dtype=x.dtype) - + expert_out = x.new_zeros((T, H)) for e in range(self.num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) - W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] - W_dn_e = self.experts.down_proj[e] # [I, H] - gate_up = x @ W_gate_up_e # [T, 2I] - + W_dn_e = self.experts.down_proj[e] # [I, H] + gate_up = x @ W_gate_up_e I2 = gate_up.shape[-1] // 2 - gate = gate_up[:, :I2] # [T, I] - up = gate_up[:, I2:] # [T, I] + gate, up = gate_up[:, :I2], gate_up[:, I2:] intermediate = up * act(gate) down = intermediate @ W_dn_e - masked_down = torch.where( - routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype) - ) # TODO: verify and remove - expert_out += masked_down - expert_out = expert_out.to(x.dtype).view(B, S, H) - return expert_out, router_logits + expert_out += torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out)) + return expert_out.view(B, S, H), router_logits + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = self.gate(x) + prob = 
F.softmax(router_logits, dim=-1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] # norm_topk_prob always True + top_w = top_w.to(x.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) + + if getattr(self, "supports_moe_prefill_blocking", False) and hasattr(self, "expert_blocking_num_nsp") and self.num_experts % self.expert_blocking_num_nsp == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + + return self.orig_forward(hidden_states) class QEffQwen3VLMoeModel(Qwen3VLMoeModel): diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 339e4f4dac..22fad9a611 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -1,318 +1,319 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import os -from dataclasses import dataclass - -UTILS_DIR = os.path.dirname(os.path.abspath(__file__)) -QEFF_DIR = os.path.dirname(UTILS_DIR) -ROOT_DIR = os.path.dirname(QEFF_DIR) -QEFF_CACHE_DIR_NAME = "qeff_cache" - -ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 -ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 -ONNX_EXPORT_EXAMPLE_FBS = 4 -ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_MAX_NUM_IMAGES = 1 -ONNX_EXPORT_MAX_IMAGE_TILES = 4 -ONNX_EXPORT_IMAGE_WIDTH = 560 -ONNX_EXPORT_IMAGE_LENGHT = 560 -ONNX_EXPORT_IMAGE_DEPTH = 3 -ONNX_EXPORT_CTX_LEN = 1024 - -NPI_MAPPING = { - "google/gemma-3-4b-it": os.path.join( - QEFF_DIR, "transformers", "models", "gemma3", "configs", "fp32_nodes_gemma3_4b.yaml" - ), - "google/gemma-3-27b-it": os.path.join( - QEFF_DIR, "transformers", "models", "gemma3", "configs", "gemma_updated_npi.yaml" - ), -} - -# Blocking defaults -VTCM_SIZE_THRESHOLD = 8 * 1024 * 1024 * 0.75 - -# Compiler defaults -DEFAULT_AIC_NUM_CORES = 16 -DEFAULT_AIC_MXPF6_MATMUL = False -# Hashing defaults -HASH_HEXDIGEST_STR_LEN = 16 -KWARGS_INCLUSION_LIST = [ - "state_dict", - "revision", - "key_mapping", - "commit_hash", - "adapter_kwargs", - "adapter_name", - "gguf_file", - "pretrained_model_name_or_path", - "attn_implementation", - "_attn_implementation", - "qaic_config", -] - -# Minimum value for causal mask -MIN_MASKED_ATTENTION_VALUE = float("-inf") - - -# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable. -def get_models_dir(): - """ - Determine the directory for storing QEFF models. - Priority: - 1. Use $XDG_CACHE_HOME/qeff_models if XDG_CACHE_HOME is set. - 2. Use QEFF_HOME if set in environment. - 3. Default to ~/.cache/qeff_models. - Sets QEFF_MODELS_DIR environment variable if not already set. - Returns: - str: Path to the QEFF models directory. 
- """ - qeff_cache_home = os.environ.get("QEFF_HOME") - # Check if XDG_CACHE_HOME is set - xdg_cache_home = os.environ.get("XDG_CACHE_HOME") - if qeff_cache_home: - qeff_models_dir = os.path.join(qeff_cache_home, QEFF_CACHE_DIR_NAME) - # Check if QEFF_MODELS_DIR is set - elif xdg_cache_home: - qeff_models_dir = os.path.join(xdg_cache_home, QEFF_CACHE_DIR_NAME) - else: - # Use ~/.cache/qeff_models as the default - qeff_models_dir = os.path.join(os.path.expanduser("~"), ".cache", QEFF_CACHE_DIR_NAME) - - # Set QEFF_MODELS_DIR environment variable - return qeff_models_dir - - -QEFF_MODELS_DIR = get_models_dir() - -ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5 -ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5 -ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80 -ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 -ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 -ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 17 -FILE_CHUNK_SIZE_DEFAULT = 10 * 2**30 # 10 GB -SIZE_THRESHOLD_DEFAULT = 1024 - - -COMPILER = ["/opt/qti-aic/exec/qaic-compile", "-aic-hw"] -DEFAULT_AIC_HW_VERSION = "ai100" -ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 - -# InternVL constants -# Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B -INTERN_FEATURE_SIZE = 256 -INTERN_NUM_PATCHES = 13 -INTERN_IMG_SIZE = 448 -INTERN_CTX_LEN = 4096 -INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256 # 4096-256 -INTERN_NUM_CHANNELS = 3 -INTERN_IMAGE_HEIGHT = 1000 -INTERN_IMAGE_WIDTH = 747 - -INTERN_IMG_CONTEXT_TOKEN = 151667 -# Specific to InternVL3_5 series, same token won't work for InternVL2_5 series -INTERN_3_5_IMG_CONTEXT_TOKEN = 151671 - -# Granite Vision Constants -# Fixing the feature size with reference to ibm-granite/granite-vision-3.2-2b -GRANITEVISION_FEATURE_SIZE = 5239 -GRANITEVISION_NUM_PATCHES = 10 -GRANITEVISION_IMG_SIZE = 384 -GRANITEVISION_IMG_SIZE_HEIGHT = 1109 -GRANITEVISION_IMG_SIZE_WIDTH = 1610 -GRANITEVISION_PIXEL_VALUE_DIM = 5 -GRANITEVISION_PREFIL_SEQ_LEN = 
GRANITEVISION_SEQ_LEN = 5500 -GRANITEVISION_CTX_LEN = 6000 -GRANITEVISION_NUM_CHANNELS = 3 - -VISION_MXFP6_MATMUL = False -# Llama4 Constants -LLAMA4_ATTENTION_CHUNK_SIZE = 8192 -LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 - -# DeepSeek Kimi-k2 Constant -MAX_POSITION_EMBEDDINGS = 32768 -FP16_BYTES = 2 -DEFAULT_NUM_HEADS = 64 -KV_LORA_RANK = 512 -ROPE_DIM = 64 - -# Wav2Vec2 Constant -WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) - -# Qwen2_5_vl Constants -QWEN2_5_VL_HEIGHT = 354 -QWEN2_5_VL_WIDTH = 536 - -# Qwen3_vl Constanst -QWEN3_VL_HEIGHT = 354 -QWEN3_VL_WIDTH = 536 - -# Modules to cache while clearing the pytorch weights -CACHE_MODULES = ["get_output_names", "get_dummy_inputs", "get_onnx_dynamic_axes", "get_specializations"] - -# Mistral3 Constants -MISTRAL3_IMAGE_HEIGHT = 1540 -MISTRAL3_IMAGE_WIDTH = 1540 - -# Molmo Constants -MOLMO_IMAGE_HEIGHT = 536 -MOLMO_IMAGE_WIDTH = 354 -# Flux Transformer Constants -FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 -FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 -FLUX_ADALN_HIDDEN_DIM = 3072 -FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context -FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 -FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM - -# Wan Transformer Constants -WAN_TEXT_EMBED_DIM = 5120 -WAN_PROJECTION_DIM = 6 -WAN_ONNX_EXPORT_BATCH_SIZE = 1 -WAN_ONNX_EXPORT_FRAMES = 81 -WAN_ONNX_EXPORT_LATENT_FRAMES = 21 -WAN_ONNX_EXPORT_SEQ_LEN = 512 -WAN_ONNX_EXPORT_ROTARY_DIM = 128 -WAN_DIT_OUT_CHANNELS = 64 -# Wan dims for 45p -WAN_ONNX_EXPORT_CL_45P = 252 -WAN_ONNX_EXPORT_LATENT_HEIGHT_45P = 6 -WAN_ONNX_EXPORT_LATENT_WIDTH_45P = 8 -WAN_ONNX_EXPORT_HEIGHT_45P = 48 -WAN_ONNX_EXPORT_WIDTH_45P = 64 - -# WAN I2V -WAN_DIT_I2V_IMG_LATENT_CHANNELS = 32 - -# For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length -CCL_START_MAP = { - 32768: (4096, 4000), - 65536: 
(8192, 8000), - float("inf"): (16384, 16000), -} -# Limitation in the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic lists generation process. -CCL_MAX_ELEMENTS_LISTS = 5 -CCL_START_CTX_LEN = 4096 -CCL_MIN_CTX_LEN = 1024 -CCL_UNIQNE_STEP = 32 - -# used for gpt-oss prefill-only model Q-blocking -GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256 - - -class Constants: - # Export Constants. - SEQ_LEN = 32 - CTX_LEN = 32 - PROMPT_LEN = 8 - INPUT_STR = ["My name is"] - GB = 2**30 - MAX_QPC_LIMIT = 30 - MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download - NUM_SPECULATIVE_TOKENS = 2 - MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS - SAMPLER_OPS = { - "repetition_penalties", - "presence_penalties", - "temperatures", - "top_ks", - "top_ps", - "min_ps", - "random_numbers", - } - SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"} - SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. - SDK_PLATFORM_XML = ( - "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. - ) - - -@dataclass -class QnnConstants: - # QNN PATH to be read from environment variable. 
- QNN_SDK_PATH_ENV_VAR_NAME = "QNN_SDK_ROOT" - QNN_SDK_YAML = "sdk.yaml" - - # QNN Compilation tools - QAIRT_CONVERTER = "{}/bin/{}/qairt-converter" - QNN_CONTEXT_BIN = "{}/bin/{}/qnn-context-binary-generator" - - # QNN Libraries required for compilation - QNN_CONTEXT_LIB_BACKEND = "{}/lib/{}/libQnnAic.so" - QNN_CONTEXT_LIB_NET_RUN_EXTENSIONS = "{}/lib/{}/libQnnAicNetRunExtensions.so" - - # QNN Compilation target names - MODEL_NAME = "model" - QNN_DATA_FORMAT_CONFIG_NAME = "qnn_data_format_config.json" - CONTEXT_BIN_NAME = "qnngraph.serialized" - CONTEXT_BIN_QPC_NAME = "programqpc.bin" - - # TARGET System Architecture - TARGET = "x86_64-linux-clang" # TODO add support in infer to be override - - # Converter Arguments - FLOAT_BITWIDTH = 16 - FLOAT_BIAS_BITWIDTH = 32 - CONVERTER_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification --target_backend AIC " - - # Context-Binary-Generator Arguments - LOG_LEVEL = "error" - - # qnn_compilation_backend default Arguments - COMPILER_COMPILATION_TARGET = "hardware" - COMPILER_CONVERT_TO_FP16 = True - COMPILER_DO_DDR_TO_MULTICAST = True - COMPILER_HARDWARE_VERSION = "2.0" - COMPILER_PERF_WARNINGS = False - COMPILER_PRINT_DDR_STATS = False - COMPILER_PRINT_PERF_METRICS = False - COMPILER_RETAINED_STATE = True - COMPILER_STAT_LEVEL = 10 - COMPILER_STATS_BATCH_SIZE = 1 - COMPILER_TIME_PASSES = False - GRAPH_NAMES = [f"{MODEL_NAME}_configuration_1", f"{MODEL_NAME}_configuration_2"] - GRAPH_NAMES_PREFILL_ONLY = [f"{MODEL_NAME}"] - - # qnn_config JSON file supported Keys - CONVERTER_ARGS_EXTENSION_STR = "converter_args_extension" - CONTEXT_BIN_ARGS_EXTENSION_STR = "context_binary_generator_args_extension" - QNN_COMPILATION_BACKEND_STR = "qnn_compilation_backend" - SKIP_QNN_CONVERTER_STEP_STR = "SKIP_QNN_CONVERTER_STEP" - - IMMUTABLE_CONVERTER_ARGS = [ - "--input_network ", - "--output_path ", - "--config ", - "--float_bias_bitwidth ", - "--float_bitwidth ", - "--preserve_io_datatype", - "--onnx_skip_simplification", - ] - 
- IMMUTABLE_CONTEXT_BIN_GEN_ARGS = [ - "--binary_file ", - "--backend_binary ", - "--output_dir ", - "--backend ", - "--model ", - "--dlc_path ", - "--config_file ", - ] - - QNN_SAMPLE_CONFIG = { - "converter_args_extension": "--onnx_defer_loading", - "context_binary_generator_args_extension": "--log_level debug", - "qnn_compilation_backend": { - "compiler_enable_depth_first": True, - "compiler_printDDRStats": False, - "compiler_printPerfMetrics": False, - }, - "SKIP_QNN_CONVERTER_STEP": False, - } +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +from dataclasses import dataclass + +UTILS_DIR = os.path.dirname(os.path.abspath(__file__)) +QEFF_DIR = os.path.dirname(UTILS_DIR) +ROOT_DIR = os.path.dirname(QEFF_DIR) +QEFF_CACHE_DIR_NAME = "qeff_cache" + +ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 +ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 +ONNX_EXPORT_EXAMPLE_FBS = 4 +ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep +ONNX_EXPORT_MAX_NUM_IMAGES = 1 +ONNX_EXPORT_MAX_IMAGE_TILES = 4 +ONNX_EXPORT_IMAGE_WIDTH = 560 +ONNX_EXPORT_IMAGE_LENGHT = 560 +ONNX_EXPORT_IMAGE_DEPTH = 3 +ONNX_EXPORT_CTX_LEN = 1024 + +NPI_MAPPING = { + "google/gemma-3-4b-it": os.path.join( + QEFF_DIR, "transformers", "models", "gemma3", "configs", "fp32_nodes_gemma3_4b.yaml" + ), + "google/gemma-3-27b-it": os.path.join( + QEFF_DIR, "transformers", "models", "gemma3", "configs", "gemma_updated_npi.yaml" + ), +} + +# Blocking defaults +VTCM_SIZE_THRESHOLD = 8 * 1024 * 1024 * 0.75 + +# Compiler defaults +DEFAULT_AIC_NUM_CORES = 16 +DEFAULT_AIC_MXPF6_MATMUL = False +MOE_PREFILL_PACKED_CHUNK_SIZE = 256 +# Hashing defaults +HASH_HEXDIGEST_STR_LEN = 16 +KWARGS_INCLUSION_LIST = [ + "state_dict", + "revision", + "key_mapping", + "commit_hash", + "adapter_kwargs", + "adapter_name", + 
"gguf_file", + "pretrained_model_name_or_path", + "attn_implementation", + "_attn_implementation", + "qaic_config", +] + +# Minimum value for causal mask +MIN_MASKED_ATTENTION_VALUE = float("-inf") + + +# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable. +def get_models_dir(): + """ + Determine the directory for storing QEFF models. + Priority: + 1. Use the QEFF_CACHE_DIR_NAME folder under $QEFF_HOME if QEFF_HOME is set in the environment. + 2. Otherwise use the QEFF_CACHE_DIR_NAME folder under $XDG_CACHE_HOME if XDG_CACHE_HOME is set. + 3. Default to the QEFF_CACHE_DIR_NAME folder under ~/.cache. + Does not modify the environment; the module-level QEFF_MODELS_DIR is assigned from the return value. + Returns: + str: Path to the QEFF models directory. + """ + qeff_cache_home = os.environ.get("QEFF_HOME") + # XDG_CACHE_HOME is only consulted when QEFF_HOME is unset or empty + xdg_cache_home = os.environ.get("XDG_CACHE_HOME") + if qeff_cache_home: + qeff_models_dir = os.path.join(qeff_cache_home, QEFF_CACHE_DIR_NAME) + # Fall back to XDG_CACHE_HOME when QEFF_HOME is not set + elif xdg_cache_home: + qeff_models_dir = os.path.join(xdg_cache_home, QEFF_CACHE_DIR_NAME) + else: + # Use the QEFF_CACHE_DIR_NAME folder under ~/.cache as the default + qeff_models_dir = os.path.join(os.path.expanduser("~"), ".cache", QEFF_CACHE_DIR_NAME) + + # The path is only returned; no environment variable is set here + return qeff_models_dir + + +QEFF_MODELS_DIR = get_models_dir() + +ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES = 0.5 +ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES = 0.5 +ONNX_EXPORT_EXAMPLE_TEMPERATURES = 0.80 +ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 +ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 +ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 +ONNX_EXPORT_OPSET = 17 +FILE_CHUNK_SIZE_DEFAULT = 10 * 2**30 # 10 GB +SIZE_THRESHOLD_DEFAULT = 1024 + + +COMPILER = ["/opt/qti-aic/exec/qaic-compile", "-aic-hw"] +DEFAULT_AIC_HW_VERSION = "ai100" +ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 + +# InternVL constants +# Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B +INTERN_FEATURE_SIZE = 256 +INTERN_NUM_PATCHES = 13 +INTERN_IMG_SIZE = 448 +INTERN_CTX_LEN = 4096 
+INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256 # 4096-256 +INTERN_NUM_CHANNELS = 3 +INTERN_IMAGE_HEIGHT = 1000 +INTERN_IMAGE_WIDTH = 747 + +INTERN_IMG_CONTEXT_TOKEN = 151667 +# Specific to InternVL3_5 series, same token won't work for InternVL2_5 series +INTERN_3_5_IMG_CONTEXT_TOKEN = 151671 + +# Granite Vision Constants +# Fixing the feature size with reference to ibm-granite/granite-vision-3.2-2b +GRANITEVISION_FEATURE_SIZE = 5239 +GRANITEVISION_NUM_PATCHES = 10 +GRANITEVISION_IMG_SIZE = 384 +GRANITEVISION_IMG_SIZE_HEIGHT = 1109 +GRANITEVISION_IMG_SIZE_WIDTH = 1610 +GRANITEVISION_PIXEL_VALUE_DIM = 5 +GRANITEVISION_PREFIL_SEQ_LEN = GRANITEVISION_SEQ_LEN = 5500 +GRANITEVISION_CTX_LEN = 6000 +GRANITEVISION_NUM_CHANNELS = 3 + +VISION_MXFP6_MATMUL = False +# Llama4 Constants +LLAMA4_ATTENTION_CHUNK_SIZE = 8192 +LLAMA4_MAX_POSITION_EMBEDDINGS = 65536 + +# DeepSeek Kimi-k2 Constant +MAX_POSITION_EMBEDDINGS = 32768 +FP16_BYTES = 2 +DEFAULT_NUM_HEADS = 64 +KV_LORA_RANK = 512 +ROPE_DIM = 64 + +# Wav2Vec2 Constant +WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) + +# Qwen2_5_vl Constants +QWEN2_5_VL_HEIGHT = 354 +QWEN2_5_VL_WIDTH = 536 + +# Qwen3_vl Constants +QWEN3_VL_HEIGHT = 354 +QWEN3_VL_WIDTH = 536 + +# Modules to cache while clearing the pytorch weights +CACHE_MODULES = ["get_output_names", "get_dummy_inputs", "get_onnx_dynamic_axes", "get_specializations"] + +# Mistral3 Constants +MISTRAL3_IMAGE_HEIGHT = 1540 +MISTRAL3_IMAGE_WIDTH = 1540 + +# Molmo Constants +MOLMO_IMAGE_HEIGHT = 536 +MOLMO_IMAGE_WIDTH = 354 +# Flux Transformer Constants +FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 +FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 +FLUX_ADALN_HIDDEN_DIM = 3072 +FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context +FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 +FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM + +# Wan Transformer Constants +WAN_TEXT_EMBED_DIM = 5120 +WAN_PROJECTION_DIM = 6 
+WAN_ONNX_EXPORT_BATCH_SIZE = 1 +WAN_ONNX_EXPORT_FRAMES = 81 +WAN_ONNX_EXPORT_LATENT_FRAMES = 21 +WAN_ONNX_EXPORT_SEQ_LEN = 512 +WAN_ONNX_EXPORT_ROTARY_DIM = 128 +WAN_DIT_OUT_CHANNELS = 64 +# Wan dims for 45p +WAN_ONNX_EXPORT_CL_45P = 252 +WAN_ONNX_EXPORT_LATENT_HEIGHT_45P = 6 +WAN_ONNX_EXPORT_LATENT_WIDTH_45P = 8 +WAN_ONNX_EXPORT_HEIGHT_45P = 48 +WAN_ONNX_EXPORT_WIDTH_45P = 64 + +# WAN I2V +WAN_DIT_I2V_IMG_LATENT_CHANNELS = 32 + +# For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length +CCL_START_MAP = { + 32768: (4096, 4000), + 65536: (8192, 8000), + float("inf"): (16384, 16000), +} +# Upper limit on the number of elements in the comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic list generation. +CCL_MAX_ELEMENTS_LISTS = 5 +CCL_START_CTX_LEN = 4096 +CCL_MIN_CTX_LEN = 1024 +CCL_UNIQNE_STEP = 32 + +# used for gpt-oss prefill-only model Q-blocking +GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256 + + +class Constants: + # Export Constants. + SEQ_LEN = 32 + CTX_LEN = 32 + PROMPT_LEN = 8 + INPUT_STR = ["My name is"] + GB = 2**30 + MAX_QPC_LIMIT = 30 + MAX_RETRIES = 10 # This constant is used to set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download + NUM_SPECULATIVE_TOKENS = 2 + MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS + SAMPLER_OPS = { + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + } + SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"} + SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. + SDK_PLATFORM_XML = ( + "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. + ) + + +@dataclass +class QnnConstants: + # QNN PATH to be read from environment variable. 
+ QNN_SDK_PATH_ENV_VAR_NAME = "QNN_SDK_ROOT" + QNN_SDK_YAML = "sdk.yaml" + + # QNN Compilation tools + QAIRT_CONVERTER = "{}/bin/{}/qairt-converter" + QNN_CONTEXT_BIN = "{}/bin/{}/qnn-context-binary-generator" + + # QNN Libraries required for compilation + QNN_CONTEXT_LIB_BACKEND = "{}/lib/{}/libQnnAic.so" + QNN_CONTEXT_LIB_NET_RUN_EXTENSIONS = "{}/lib/{}/libQnnAicNetRunExtensions.so" + + # QNN Compilation target names + MODEL_NAME = "model" + QNN_DATA_FORMAT_CONFIG_NAME = "qnn_data_format_config.json" + CONTEXT_BIN_NAME = "qnngraph.serialized" + CONTEXT_BIN_QPC_NAME = "programqpc.bin" + + # TARGET System Architecture + TARGET = "x86_64-linux-clang" # TODO add support in infer to be override + + # Converter Arguments + FLOAT_BITWIDTH = 16 + FLOAT_BIAS_BITWIDTH = 32 + CONVERTER_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification --target_backend AIC " + + # Context-Binary-Generator Arguments + LOG_LEVEL = "error" + + # qnn_compilation_backend default Arguments + COMPILER_COMPILATION_TARGET = "hardware" + COMPILER_CONVERT_TO_FP16 = True + COMPILER_DO_DDR_TO_MULTICAST = True + COMPILER_HARDWARE_VERSION = "2.0" + COMPILER_PERF_WARNINGS = False + COMPILER_PRINT_DDR_STATS = False + COMPILER_PRINT_PERF_METRICS = False + COMPILER_RETAINED_STATE = True + COMPILER_STAT_LEVEL = 10 + COMPILER_STATS_BATCH_SIZE = 1 + COMPILER_TIME_PASSES = False + GRAPH_NAMES = [f"{MODEL_NAME}_configuration_1", f"{MODEL_NAME}_configuration_2"] + GRAPH_NAMES_PREFILL_ONLY = [f"{MODEL_NAME}"] + + # qnn_config JSON file supported Keys + CONVERTER_ARGS_EXTENSION_STR = "converter_args_extension" + CONTEXT_BIN_ARGS_EXTENSION_STR = "context_binary_generator_args_extension" + QNN_COMPILATION_BACKEND_STR = "qnn_compilation_backend" + SKIP_QNN_CONVERTER_STEP_STR = "SKIP_QNN_CONVERTER_STEP" + + IMMUTABLE_CONVERTER_ARGS = [ + "--input_network ", + "--output_path ", + "--config ", + "--float_bias_bitwidth ", + "--float_bitwidth ", + "--preserve_io_datatype", + "--onnx_skip_simplification", + ] + 
+ IMMUTABLE_CONTEXT_BIN_GEN_ARGS = [ + "--binary_file ", + "--backend_binary ", + "--output_dir ", + "--backend ", + "--model ", + "--dlc_path ", + "--config_file ", + ] + + QNN_SAMPLE_CONFIG = { + "converter_args_extension": "--onnx_defer_loading", + "context_binary_generator_args_extension": "--log_level debug", + "qnn_compilation_backend": { + "compiler_enable_depth_first": True, + "compiler_printDDRStats": False, + "compiler_printPerfMetrics": False, + }, + "SKIP_QNN_CONVERTER_STEP": False, + } diff --git a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py index 655de4ef51..3bc9339091 100644 --- a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py +++ b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py @@ -14,14 +14,15 @@ from QEfficient import QEFFAutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession -model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 +# model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 +model_id = "yujiepan/qwen3-moe-tiny-random" prompt = """ Explain quantum computing in simple terms. 
""" config = AutoConfig.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) -PREFILL_SEQ_LEN = 128 -CTX_LEN = 128 * 3 +PREFILL_SEQ_LEN = 256 +CTX_LEN = PREFILL_SEQ_LEN * 3 qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) decode_qpc_path = qeff_model.compile( @@ -48,7 +49,7 @@ num_cores=16, mxfp6_matmul=True, mxint8_kv_cache=True, - num_devices=2, + num_devices=1, split_retained_state_io=True, mos=1, aic_enable_depth_first=True, diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py new file mode 100644 index 0000000000..1b115d79e1 --- /dev/null +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -0,0 +1,202 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Tests for NSP-blocked MoE prefill dispatch (Qwen3MOE + GPT-OSS). +Uses EXPERT_BLOCKING_NUM_NSP=2 so tests run fast on any num_experts. +Covers: parity, decode export, prefill+chunking export (disagg mode). 
+""" + +import os + +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +os.environ.setdefault("EXPERT_BLOCKING_NUM_NSP", "2") + +from QEfficient import QEFFAutoModelForCausalLM + +MODEL_KWARGS = {"attn_implementation": "eager"} + +QWEN3_MOE_CFG = dict( + max_position_embeddings=256, + num_hidden_layers=2, + num_attention_heads=4, + hidden_size=128, + intermediate_size=512, + vocab_size=127, + num_key_value_heads=2, +) +GPTOSS_CFG = dict( + max_position_embeddings=256, + num_hidden_layers=2, + num_attention_heads=2, + hidden_size=32, + intermediate_size=32, + vocab_size=127, + num_key_value_heads=2, +) + + +# ── Qwen3MOE ────────────────────────────────────────────────────────────────── + + +def test_qwen3moe_blocked_forward_parity(): + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( + QEffPrefillChunkedQwen3MoeSparseMoeBlock, + ) + + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + + blocks = [ + m + for _, m in model.named_modules() + if hasattr(m, "experts") and hasattr(m, "gate") and hasattr(m, "num_experts") + ] + assert blocks + + block = blocks[0] + chunked = QEffPrefillChunkedQwen3MoeSparseMoeBlock.__new__(QEffPrefillChunkedQwen3MoeSparseMoeBlock) + chunked.__dict__.update(block.__dict__) + chunked.__class__ = QEffPrefillChunkedQwen3MoeSparseMoeBlock + chunked.__qeff_init__() + + x = torch.randn(1, 8, config.hidden_size) + with torch.no_grad(): + orig, _ = chunked.orig_forward(x) + blocked, _ = chunked.forward(x) + + assert orig.shape == blocked.shape + assert (orig - blocked).abs().max().item() < 0.1, "Qwen3MOE parity failed" + + +def test_qwen3moe_decode_export(tmp_path): + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "decode") + assert 
qeff.onnx_path.is_file() + + +def test_qwen3moe_prefill_chunked_export(tmp_path): + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) + assert qeff.onnx_path.is_file() + + +def test_qwen3moe_prefill_chunked_subfunction_export_contains_cumsum_custom_ops(tmp_path): + import onnx + + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + onnx_path = qeff.export( + tmp_path / "prefill-subfunction", + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, + offload_pt_weights=False, + ) + + onnx_model = onnx.load(str(onnx_path), load_external_data=False) + function_names = {func.name for func in onnx_model.functions} + used_op_types = {node.op_type for node in onnx_model.graph.node} + for function_proto in onnx_model.functions: + used_op_types.update(node.op_type for node in function_proto.node) + + assert "CtxScatter3DInt" in function_names + assert "CtxScatter3D" in function_names + assert "CtxGather3D" in function_names + assert "CtxScatter3DInt" in used_op_types + assert "CtxScatter3D" in used_op_types + assert "CtxGather3D" in used_op_types + + +# ── GPT-OSS ─────────────────────────────────────────────────────────────────── + + +def test_gptoss_blocked_forward_parity(): + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffPrefillOnlyChunkedGptOssMLP, + ) + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + + blocks_orig = [m for _, m in model.named_modules() if m.__class__.__name__ == "GptOssMLP"] + 
assert blocks_orig + + x = torch.randn(1, 8, config.hidden_size) + with torch.no_grad(): + orig, _ = blocks_orig[0].forward(x) + + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + PrefillOnlyChunkedTransform.apply(qeff.model) + + blocks_chunked = [m for _, m in qeff.model.named_modules() if isinstance(m, QEffPrefillOnlyChunkedGptOssMLP)] + assert blocks_chunked + + with torch.no_grad(): + blocked, _ = blocks_chunked[0].forward(x) + + assert orig.shape == blocked.shape + assert (orig - blocked).abs().max().item() < 0.1, "GPT-OSS parity failed" + + +def test_gptoss_decode_export(tmp_path): + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "decode") + assert qeff.onnx_path.is_file() + + +def test_gptoss_prefill_chunked_export(tmp_path): + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) + assert qeff.onnx_path.is_file() + + +def test_gptoss_prefill_chunked_export_traces_packed_chunks(tmp_path): + import onnx + from onnx import numpy_helper + + config = AutoConfig.for_model("gpt_oss", **{**GPTOSS_CFG, "max_position_embeddings": 1024}) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + onnx_path = qeff.export( + tmp_path / "prefill-subfunction-512", + prefill_only=True, + enable_chunking=True, + prefill_seq_len=512, + use_onnx_subfunctions=True, + offload_pt_weights=False, + ) + + onnx_model = onnx.load(str(onnx_path), load_external_data=False) + slice_starts = [] + op_types = [] + for nodes in [onnx_model.graph.node] + [function.node for function in onnx_model.functions]: + constants = {} 
+ for node in nodes: + op_types.append(node.op_type) + if node.op_type == "Constant": + for attr in node.attribute: + if attr.name == "value": + constants[node.output[0]] = numpy_helper.to_array(attr.t).flatten().tolist() + for node in nodes: + if node.op_type == "Slice" and len(node.input) > 1 and node.input[1] in constants: + slice_starts.append(constants[node.input[1]]) + + assert [256] in slice_starts + assert op_types.count("CtxGather3D") >= 2 * op_types.count("CtxScatter3DInt") diff --git a/tests/unit_test/transforms/test_onnx_transforms.py b/tests/unit_test/transforms/test_onnx_transforms.py index 5a43b345d6..d8897364aa 100644 --- a/tests/unit_test/transforms/test_onnx_transforms.py +++ b/tests/unit_test/transforms/test_onnx_transforms.py @@ -531,6 +531,24 @@ def test_custom_op_transform_contains_ctx_gather(self): assert "CtxGatherFunc" in CustomOpTransform._custom_ops + def test_custom_op_transform_contains_ctx_scatter_3d_int(self): + """CustomOpTransform._custom_ops must contain 'CtxScatterFunc3DInt'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxScatterFunc3DInt" in CustomOpTransform._custom_ops + + def test_custom_op_transform_contains_ctx_scatter_3d_generalized(self): + """CustomOpTransform._custom_ops must contain 'CtxScatterFunc3DGeneralized'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxScatterFunc3DGeneralized" in CustomOpTransform._custom_ops + + def test_custom_op_transform_contains_ctx_gather_3d_generalized(self): + """CustomOpTransform._custom_ops must contain 'CtxGatherFunc3DGeneralized'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxGatherFunc3DGeneralized" in CustomOpTransform._custom_ops + def test_custom_op_transform_rms_norm_maps_to_custom_rms_norm(self): """CustomRMSNormFunc must map to CustomRMSNorm class.""" from QEfficient.base.onnx_transforms import CustomOpTransform