From 0d9bfb0261547cb0461ac0cc769f18d784b48cd0 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Thu, 4 Jun 2026 18:15:11 +0530 Subject: [PATCH 01/12] [EB] Qwen_3_5_Moe Signed-off-by: Mohit Soni --- .../transformers/models/pytorch_transforms.py | 3 + .../qwen3_5_moe/modeling_qwen3_5_moe.py | 99 ++++++++++++------- 2 files changed, 65 insertions(+), 37 deletions(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index b2b447a78..7af7167cb 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -236,6 +236,7 @@ Qwen3_5MoeTopKRouter, Qwen3_5MoeVisionAttention, Qwen3_5MoeVisionModel, + Qwen3_5MoeExperts, ) from transformers.models.qwen3_moe.modeling_qwen3_moe import ( Qwen3MoeAttention, @@ -569,6 +570,7 @@ QEffQwen3_5MoeTopKRouter, QEffQwen3_5MoeVisionAttention, QEffQwen3_5MoeVisionModel, + QEffQwen3_5MoeExperts, ) from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( QEffPrefillChunkedQwen3MoeSparseMoeBlock, @@ -868,6 +870,7 @@ class KVCacheTransform(ModuleMappingTransform): Qwen3_5MoeVisionAttention: QEffQwen3_5MoeVisionAttention, Qwen3_5MoeVisionModel: QEffQwen3_5MoeVisionModel, Qwen3_5MoeTopKRouter: QEffQwen3_5MoeTopKRouter, + Qwen3_5MoeExperts: QEffQwen3_5MoeExperts, # Qwen2.5 VL Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, Qwen2_5_VLModel: QEffQwen2_5_VLModel, diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 39dbb0e1b..79009392b 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -31,6 +31,7 @@ Qwen3_5MoeTopKRouter, Qwen3_5MoeVisionAttention, Qwen3_5MoeVisionModel, + Qwen3_5MoeExperts, apply_rotary_pos_emb_vision, repeat_kv, rotate_half, @@ -2085,6 +2086,13 @@ def 
forward(self, hidden_states): router_scores = router_top_value return router_logits, router_scores, router_indices +class QEffQwen3_5MoeExperts(Qwen3_5MoeExperts): + def __qeff_init__(self): + # transformers>=5 uses fused gate_up projections. Keep backward-compatible + # aliases expected by existing QEff paths. + self.expert_dim = getattr(self, "intermediate_size", self.gate_up_proj.shape[-2] // 2) + self.gate_proj = nn.Parameter(self.gate_up_proj[:, :self.expert_dim, :].detach().clone()) + self.up_proj = nn.Parameter(self.gate_up_proj[:, self.expert_dim :, :].detach().clone()) class QEffQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -2095,13 +2103,13 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = top_w.to(x.dtype) idx = top_i.reshape(-1) - w_up = self.experts.gate_up_proj[idx.flatten()] + gate_proj = self.experts.gate_proj[idx.flatten()] + up_proj = self.experts.up_proj[idx.flatten()] w_dn = self.experts.down_proj[idx.flatten()] xk = x.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous() xk = xk.view(-1, 1, H) - gate_proj, up_proj = torch.chunk(w_up, 2, dim=1) gate = torch.bmm(xk, gate_proj.transpose(1, 2)) up = torch.bmm(xk, up_proj.transpose(1, 2)) @@ -2122,7 +2130,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: - """Build packed->original token index""" + """Build packed->original token index.""" batch_size, seq_len = T2Ei.shape int32_max = torch.iinfo(torch.int32).max int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) @@ -2130,9 +2138,7 @@ def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) valid_dest = valid_prefix - 1 scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) - # Once the compiler fix 
for ConstantOfShape(INT32_MAX) is available, this - # can be switched back to ``torch.full_like(token_idx, int32_max)``. - matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = torch.full_like(token_idx, int32_max) matched_idx = CtxScatterFunc3DInt.apply( matched_idx.unsqueeze(-1), scatter_pos, @@ -2148,33 +2154,23 @@ def _cumsum_scatter_gather_update_expert_blocked( W_u: torch.Tensor, W_d: torch.Tensor, routing_weight: torch.Tensor, - experts_out: torch.Tensor, + expert_out: torch.Tensor, act_fn, - T: int, packed_chunk_size: int, ) -> torch.Tensor: """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. - Accumulates one local expert's contribution in-place onto ``experts_out``. + Accumulates one local expert's contribution in-place onto ``expert_out``. Uses a packed/cumsum layout so the MLP runs only over active rows, then scatters the weighted output back to original token positions. - - Shapes: - x : [T, H] - T2Ei : [num_nsp, T] (bool) - W_g, W_u : [num_nsp, H, I] - W_d : [num_nsp, I, H] - routing_weight : [num_nsp, T] - experts_out : [num_nsp, T, H] (accumulator, in-out) """ batch_size, seq_len = T2Ei.shape - packed_chunk_size = int(max(1, min(packed_chunk_size, seq_len))) + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) matched_idx = _build_matched_idx_from_cumsum(T2Ei) valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) - rw_expanded = routing_weight.unsqueeze(-1) for packed_start in range(0, seq_len, packed_chunk_size): packed_stop = packed_start + packed_chunk_size chunk_matched_idx = matched_idx[:, packed_start:packed_stop] @@ -2185,26 +2181,54 @@ def _cumsum_scatter_gather_update_expert_blocked( up_prime = x_chunk @ W_u down_chunk = (up_prime * act_fn(gate_prime)) @ W_d - rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + rw_chunk = 
CtxGatherFunc3DGeneralized.apply(routing_weight, chunk_matched_idx) down_chunk = down_chunk * rw_chunk - - expert_out_chunk = CtxGatherFunc3DGeneralized.apply(experts_out, chunk_matched_idx) + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + chunk_valid_rows = torch.clamp( + valid_rows - packed_start, + min=torch.zeros_like(valid_rows), + max=torch.full_like(valid_rows, packed_chunk_size), + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) - experts_out = CtxScatterFunc3DGeneralized.apply(experts_out, chunk_matched_idx, updated_chunk) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) - return experts_out + return expert_out class QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): + + def __qeff_init__(self): + self.top_k = getattr(self.gate, "top_k", None) + self.norm_topk_prob = getattr(self.gate, "norm_topk_prob", False) + self.num_experts = getattr(self.gate, "num_experts", self.experts.gate_up_proj.shape[0]) + self.gate_up_proj_w = self.experts.gate_up_proj + self.down_proj_w = self.experts.down_proj + + def _split_expert_weights(self, hidden_size: int): + gate_up_proj_w = self.gate_up_proj_w + if gate_up_proj_w.shape[2] != hidden_size: + gate_up_proj_w = gate_up_proj_w.transpose(1, 2) + intermediate_size = gate_up_proj_w.shape[1] // 2 + gate_proj_w = gate_up_proj_w[:, :intermediate_size, :] + up_proj_w = gate_up_proj_w[:, intermediate_size:, :] + + down_proj_w = self.down_proj_w + if down_proj_w.shape[1] != intermediate_size: + down_proj_w = down_proj_w.transpose(1, 2) + return gate_proj_w, up_proj_w, down_proj_w + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: act_fn = getattr(self.experts, 
"act_fn", F.silu) T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP + + packed_chunk_size = getattr(self, "expert_blocking_packed_chunk_size", T) + if self.gate.num_experts % num_nsp != 0: raise ValueError( f"num_experts ({self.gate.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" @@ -2216,15 +2240,17 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor # gate_up_proj is [E, 2I, H]. After split we get [E, I, H], so transpose to [E, H, I] # before grouping into [num_nsp, local_experts, H, I]. - wt_g, wt_u = torch.split(self.experts.gate_up_proj, inter, dim=1) + # wt_g, wt_u = torch.split(self.experts.gate_up_proj, inter, dim=1) + wt_g, wt_u, W_d = self._split_expert_weights(H) wt_g = wt_g.transpose(1, 2).contiguous() wt_u = wt_u.transpose(1, 2).contiguous() W_g = wt_g.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() W_u = wt_u.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() # down_proj is [E, H, I]; blocked matmul expects [num_nsp, local_experts, I, H]. 
- W_d = self.experts.down_proj.transpose(1, 2).contiguous() + W_d = W_d.transpose(1, 2).contiguous() W_d = W_d.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + routing_weights_unsqueezed = rw.unsqueeze(-1) for slot in range(local_experts): routing_weight = rw[:, slot, :] @@ -2235,11 +2261,10 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor W_g=W_g[:, slot], W_u=W_u[:, slot], W_d=W_d[:, slot], - routing_weight=routing_weight, - experts_out=experts_out, + routing_weight=routing_weights_unsqueezed[:, slot], + expert_out=experts_out, act_fn=act_fn, - T=T, - packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, + packed_chunk_size=packed_chunk_size, ) return experts_out.sum(dim=0) @@ -2283,13 +2308,13 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights = torch.zeros((T, self.gate.num_experts), dtype=x.dtype) routing_weights.scatter_(1, top_i, top_w) - # if self.gate.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - # experts_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + if self.gate.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + experts_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) - # shared_expert_output = self.shared_expert(x) - # shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - # expert_output = experts_out + shared_expert_output - # return expert_output.view(B, S, H) + shared_expert_output = self.shared_expert(x) + shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output + expert_output = experts_out + shared_expert_output + return expert_output.view(B, S, H) experts_out = torch.zeros_like(x, dtype=x.dtype) # breakpoint() From b4748d7eb9134b3d467533ceb53d54cfedd882a3 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Thu, 4 Jun 2026 19:45:57 +0530 Subject: [PATCH 02/12] Updating attention to Newton-Schulz Signed-off-by: Mohit Soni --- 
.../transformers/models/pytorch_transforms.py | 4 +-- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 35 ++++++++++++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 7af7167cb..2d88bbb38 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -225,6 +225,7 @@ from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( Qwen3_5MoeAttention, Qwen3_5MoeDecoderLayer, + Qwen3_5MoeExperts, Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeGatedDeltaNet, @@ -236,7 +237,6 @@ Qwen3_5MoeTopKRouter, Qwen3_5MoeVisionAttention, Qwen3_5MoeVisionModel, - Qwen3_5MoeExperts, ) from transformers.models.qwen3_moe.modeling_qwen3_moe import ( Qwen3MoeAttention, @@ -560,6 +560,7 @@ QEffPrefillChunkedQwen3_5MoeSparseMoeBlock, QEffQwen3_5MoeAttention, QEffQwen3_5MoeDecoderLayer, + QEffQwen3_5MoeExperts, QEffQwen3_5MoeForCausalLM, QEffQwen3_5MoeForConditionalGeneration, QEffQwen3_5MoeGatedDeltaNet, @@ -570,7 +571,6 @@ QEffQwen3_5MoeTopKRouter, QEffQwen3_5MoeVisionAttention, QEffQwen3_5MoeVisionModel, - QEffQwen3_5MoeExperts, ) from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( QEffPrefillChunkedQwen3MoeSparseMoeBlock, diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 79009392b..48e201484 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -20,6 +20,7 @@ Qwen3_5MoeAttention, Qwen3_5MoeCausalLMOutputWithPast, Qwen3_5MoeDecoderLayer, + Qwen3_5MoeExperts, Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeGatedDeltaNet, @@ -31,7 +32,6 @@ Qwen3_5MoeTopKRouter, Qwen3_5MoeVisionAttention, Qwen3_5MoeVisionModel, - 
Qwen3_5MoeExperts, apply_rotary_pos_emb_vision, repeat_kv, rotate_half, @@ -664,12 +664,12 @@ def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # for i in range(1, chunk_size): + # row = attn[..., i, :i].clone() + # sub = attn[..., :i, :i].clone() + # # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + # attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) + # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -709,6 +709,17 @@ def torch_chunk_gated_delta_rule_qeff( # attn = S64.to(A.dtype) + # Newton-Schulz + I = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + L = attn.masked_fill(mask, 0) + + X = I + for _ in range(int(math.log2(chunk_size)) + 2): + R = I - (I - L) @ X + X = X + X @ R + + attn = X + value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) @@ -2091,9 +2102,10 @@ def __qeff_init__(self): # transformers>=5 uses fused gate_up projections. Keep backward-compatible # aliases expected by existing QEff paths. 
self.expert_dim = getattr(self, "intermediate_size", self.gate_up_proj.shape[-2] // 2) - self.gate_proj = nn.Parameter(self.gate_up_proj[:, :self.expert_dim, :].detach().clone()) + self.gate_proj = nn.Parameter(self.gate_up_proj[:, : self.expert_dim, :].detach().clone()) self.up_proj = nn.Parameter(self.gate_up_proj[:, self.expert_dim :, :].detach().clone()) + class QEffQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape @@ -2200,14 +2212,13 @@ def _cumsum_scatter_gather_update_expert_blocked( class QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): - def __qeff_init__(self): self.top_k = getattr(self.gate, "top_k", None) self.norm_topk_prob = getattr(self.gate, "norm_topk_prob", False) self.num_experts = getattr(self.gate, "num_experts", self.experts.gate_up_proj.shape[0]) self.gate_up_proj_w = self.experts.gate_up_proj self.down_proj_w = self.experts.down_proj - + def _split_expert_weights(self, hidden_size: int): gate_up_proj_w = self.gate_up_proj_w if gate_up_proj_w.shape[2] != hidden_size: @@ -2221,7 +2232,6 @@ def _split_expert_weights(self, hidden_size: int): down_proj_w = down_proj_w.transpose(1, 2) return gate_proj_w, up_proj_w, down_proj_w - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: act_fn = getattr(self.experts, "act_fn", F.silu) T, H = x.shape @@ -2236,9 +2246,8 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor local_experts = self.gate.num_experts // num_nsp rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() experts_out = x.new_zeros((num_nsp, T, H)) - inter = self.experts.gate_up_proj.shape[1] // 2 - # gate_up_proj is [E, 2I, H]. After split we get [E, I, H], so transpose to [E, H, I] + # gate_up_proj is [E, 2I, H]. 
after split we get [E, I, H], so transpose to [E, H, I] # before grouping into [num_nsp, local_experts, H, I]. # wt_g, wt_u = torch.split(self.experts.gate_up_proj, inter, dim=1) wt_g, wt_u, W_d = self._split_expert_weights(H) From 437a97293e1eef70f61be4640b9d18dc90be4234 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Fri, 5 Jun 2026 15:01:32 +0530 Subject: [PATCH 03/12] Adding fp16 support and comments addresses Signed-off-by: Mohit Soni --- .../models/qwen3_5/modeling_qwen3_5.py | 26 ++++++------ .../qwen3_5_moe/modeling_qwen3_5_moe.py | 42 ++++++++++++------- .../models/qwen3_5/qwen3_5.py | 4 +- .../qwen3_5/qwen3_5_continous_batching.py | 2 + .../models/qwen3_5_moe/qwen3_5_disagg_mode.py | 8 ++++ .../models/qwen3_5_moe/qwen3_5_moe.py | 4 +- .../qwen3_5_moe_continous_batching.py | 3 +- 7 files changed, 56 insertions(+), 33 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index c115022a6..b757de76b 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -64,7 +64,7 @@ def forward(self, hidden_states, gate): CustomRMSNormFunc.apply( hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps ) - ) * F.silu(gate.to(torch.float32)) + ) * F.silu(gate.to(gate.dtype)) class QEffQwen3_5DynamicCache(Cache): @@ -338,7 +338,7 @@ def eager_attention_forward( attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.dtype), attn_weights ) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -530,7 +530,7 @@ def torch_chunk_gated_delta_rule_qeff( query = l2norm(query, 
dim=-1, eps=1e-6) key = l2norm(key, dim=-1, eps=1e-6) query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + x.transpose(1, 2).contiguous().to(query.dtype) for x in (query, key, value, beta, g) ] mask = (position_ids[0] != -1).unsqueeze(1) @@ -1047,8 +1047,8 @@ def get_onnx_retained_state_specs( if layer_type == "full_attention": layer_names = [f"past_key.{layer_idx}", f"past_value.{layer_idx}"] layer_tensors = [ - torch.zeros(tuple(kv_cache_shape), dtype=torch.float32), - torch.zeros(tuple(kv_cache_shape), dtype=torch.float32), + torch.zeros(tuple(kv_cache_shape), dtype=self.model.config.dtype), + torch.zeros(tuple(kv_cache_shape), dtype=self.model.config.dtype), ] layer_axes = [ {0: batch_axis_name, 2: "ctx_len"}, @@ -1060,8 +1060,8 @@ def get_onnx_retained_state_specs( recurrent_shape = (batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) layer_names = [f"conv_state.{layer_idx}", f"recurrent_state.{layer_idx}"] layer_tensors = [ - torch.zeros(conv_shape, dtype=torch.float32), - torch.zeros(recurrent_shape, dtype=torch.float32), + torch.zeros(conv_shape, dtype=self.model.config.dtype), + torch.zeros(recurrent_shape, dtype=self.model.config.dtype), ] layer_axes = [{0: batch_axis_name}, {0: batch_axis_name}] @@ -1362,7 +1362,7 @@ def forward( col_mask = (cols >= start) & (cols < end) block_mask = row_mask & col_mask - final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask = torch.ones((seq_len, seq_len), dtype=self.config.dtype) final_mask[block_mask.any(dim=0)] = 0 final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) attention_mask[0] = final_mask @@ -1758,10 +1758,10 @@ def get_dummy_inputs( vision_inputs = {} lang_inputs = {} - vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), 
dtype=self.model.config.dtype) vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) - lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.model.config.dtype) lang_inputs["position_ids"] = ( ( torch.arange(dummy_seq_len, dtype=torch.int64) @@ -1788,13 +1788,13 @@ def get_dummy_inputs( for i in range(self.model.config.text_config.num_hidden_layers): if self.model.config.text_config.layer_types[i] == "full_attention": for kv in ["key", "value"]: - lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=self.model.config.dtype)) else: layer = self.model.language_model.layers[i].linear_attn conv_shape = (linear_batch_size, layer.conv_dim, layer.conv_kernel_size) recurrent_shape = (linear_batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) - lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=torch.float32)) - lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=torch.float32)) + lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=self.model.config.dtype)) + lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=self.model.config.dtype)) # if continuous_batching: diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 48e201484..450ea796f 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -400,11 +400,10 @@ def eager_attention_forward( value_states = repeat_kv(value, module.num_key_value_groups) 
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - # - # MIN_MASKED_ATTENTION_VALUE = -10000 + if attention_mask is not None: attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.dtype), attn_weights ) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -600,7 +599,7 @@ def torch_chunk_gated_delta_rule_qeff( query = query * torch.rsqrt(torch.einsum("bthd,bthd->bth", query, query).unsqueeze(-1) + 1e-6) key = key * torch.rsqrt(torch.einsum("bthd,bthd->bth", key, key).unsqueeze(-1) + 1e-6) query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + x.transpose(1, 2).contiguous().to(query.dtype) for x in (query, key, value, beta, g) ] mask = (position_ids[0] != -1).unsqueeze(1) @@ -659,9 +658,9 @@ def torch_chunk_gated_delta_rule_qeff( # decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() # original decay_mask diff = g.unsqueeze(-1) - g.unsqueeze(-2) # (B, H, num_chunks, C, C) - diff = diff * (~mask_strict).float() # zero upper triangle (strict) - decay_mask = diff.exp().float() - decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero + diff = diff * (~mask_strict).to(diff.dtype) # zero upper triangle (strict) + decay_mask = diff.exp().to(diff.dtype) + decay_mask = decay_mask * (~mask_strict).to(decay_mask.dtype) # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) # for i in range(1, chunk_size): @@ -709,13 +708,14 @@ def torch_chunk_gated_delta_rule_qeff( # attn = S64.to(A.dtype) - # Newton-Schulz - I = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # Newton-Schulz + + Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) L = attn.masked_fill(mask, 0) - X = I + X = Eye for _ 
in range(int(math.log2(chunk_size)) + 2): - R = I - (I - L) @ X + R = Eye - (Eye - L) @ X X = X + X @ R attn = X @@ -1491,8 +1491,7 @@ def forward( row_mask = (rows >= start) & (rows < end) col_mask = (cols >= start) & (cols < end) block_mask = row_mask & col_mask - - final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask = torch.ones((seq_len, seq_len), dtype=self.config.dtype) final_mask[block_mask.any(dim=0)] = 0 final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) attention_mask[0] = final_mask @@ -2180,7 +2179,8 @@ def _cumsum_scatter_gather_update_expert_blocked( packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) matched_idx = _build_matched_idx_from_cumsum(T2Ei) - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + # valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) for packed_start in range(0, seq_len, packed_chunk_size): @@ -2211,6 +2211,18 @@ def _cumsum_scatter_gather_update_expert_blocked( return expert_out +class QEffQwen3_5MoeTopKRouter(Qwen3_5MoeTopKRouter): + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1).to(router_logits.dtype) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value = router_top_value / torch.einsum("bk->b", router_top_value).unsqueeze(-1) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices + + class 
QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): def __qeff_init__(self): self.top_k = getattr(self.gate, "top_k", None) @@ -2275,7 +2287,7 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor act_fn=act_fn, packed_chunk_size=packed_chunk_size, ) - return experts_out.sum(dim=0) + return torch.einsum("ijk->jk", experts_out) # def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # B, S, H = hidden_states.shape diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py index 9f8f498b4..a8e0f5af8 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import requests +import torch import transformers from PIL import Image from qwen_vl_utils import process_vision_info @@ -19,10 +20,9 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 config.text_config.num_hidden_layers = 2 -config.torch_dtype = "float32" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py index ffead3420..f365e976a 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import torch import transformers from transformers import 
AutoConfig, AutoProcessor @@ -23,6 +24,7 @@ attn_implementation="eager", kv_offload=True, config=config, + dtype=torch.float32, continuous_batching=True, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py index 4a2a15e21..fc221c94b 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py @@ -24,6 +24,7 @@ LAYERWISE = True # For faster execution user can run with lesser layers, For Testing Purpose Only +<<<<<<< HEAD config.vision_config.depth = 4 config.text_config.num_hidden_layers = 4 config.torch_dtype = "float16" @@ -34,6 +35,13 @@ kv_offload=True, config=config, layerwise=LAYERWISE, +======= +config.vision_config.depth = 5 +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 +>>>>>>> 2932096 (Adding fp16 support and comments addresses) ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py index 435f352b3..27ad5d49f 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import requests +import torch import transformers from PIL import Image from qwen_vl_utils import process_vision_info @@ -19,10 +20,9 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 config.text_config.num_hidden_layers = 4 
-config.torch_dtype = "float32" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py index 95ae66d12..f83647fd7 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import torch import transformers from transformers import AutoConfig, AutoProcessor @@ -16,13 +17,13 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 config.text_config.num_hidden_layers = 2 -config.torch_dtype = "float32" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config, + dtype=torch.float32, continuous_batching=True, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) From 2d1893f4958a399e84f7ef86939c905326fc3416 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Fri, 5 Jun 2026 15:33:24 +0530 Subject: [PATCH 04/12] Adding Newton Schulz method in qwen3_5 Signed-off-by: Mohit Soni --- .../models/qwen3_5/modeling_qwen3_5.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index b757de76b..ac8a10c03 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -587,11 +587,11 @@ 
def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # for i in range(1, chunk_size): + # row = attn[..., i, :i].clone() + # sub = attn[..., :i, :i].clone() + # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -632,6 +632,17 @@ def torch_chunk_gated_delta_rule_qeff( # attn = S64 + # Newton-Schulz + Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + L = attn.masked_fill(mask, 0) + + X = Eye + for _ in range(int(math.log2(chunk_size)) + 2): + R = Eye - (Eye - L) @ X + X = X + X @ R + + attn = X + value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) From 8d0a886356f737724d5ab8825a09a3944aeb84dd Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 8 Jun 2026 16:38:00 +0530 Subject: [PATCH 05/12] updating moe code and minor fixes Signed-off-by: Mohit Soni --- .../models/qwen3_5/modeling_qwen3_5.py | 30 +++- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 155 ++++-------------- .../models/qwen3_5/qwen3_5.py | 6 +- .../models/qwen3_5_moe/qwen3_5_disagg_mode.py | 14 +- .../models/qwen3_5_moe/qwen3_5_moe.py | 8 +- 5 files changed, 75 insertions(+), 138 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index ac8a10c03..99d414976 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -623,7 +623,7 @@ def torch_chunk_gated_delta_rule_qeff( # acc_dtype = torch.float32 # A64 = 
A.to(acc_dtype) # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - # strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) + strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) # K = chunk_size - 1 # S64 = I64.clone() @@ -633,15 +633,29 @@ def torch_chunk_gated_delta_rule_qeff( # attn = S64 # Newton-Schulz - Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - L = attn.masked_fill(mask, 0) + # Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # L = attn.masked_fill(mask, 0) - X = Eye - for _ in range(int(math.log2(chunk_size)) + 2): - R = Eye - (Eye - L) @ X - X = X + X @ R + # X = Eye + # for _ in range(int(math.log2(chunk_size)) + 2): + # R = Eye - (Eye - L) @ X + # X = X + X @ R - attn = X + # attn = X + + # Newton-Schulz updated + acc_dtype = torch.float32 + I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) + L64 = attn.masked_fill(mask, 0).to(acc_dtype) + + # X0 = I + Xj = I64.clone() + for _ in range(chunk_size - 1): + Rj = I64 - ((I64 - L64) @ Xj) + Xj = Xj + (Xj @ Rj).masked_fill(~strict_lower, 0) + Xj = Xj.masked_fill(~strict_lower, 0) + I64 + + attn = Xj.to(attn.dtype) value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 450ea796f..db687d8d5 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -2101,8 +2101,9 @@ def __qeff_init__(self): # transformers>=5 uses fused gate_up projections. Keep backward-compatible # aliases expected by existing QEff paths. 
self.expert_dim = getattr(self, "intermediate_size", self.gate_up_proj.shape[-2] // 2) - self.gate_proj = nn.Parameter(self.gate_up_proj[:, : self.expert_dim, :].detach().clone()) - self.up_proj = nn.Parameter(self.gate_up_proj[:, self.expert_dim :, :].detach().clone()) + self.gate_proj = nn.Parameter(self.gate_up_proj[:, : self.expert_dim, :].detach().clone().transpose(1, 2)) + self.up_proj = nn.Parameter(self.gate_up_proj[:, self.expert_dim :, :].detach().clone().transpose(1, 2)) + self.down_proj_t = nn.Parameter(self.down_proj.detach().clone().transpose(1, 2)) class QEffQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): @@ -2116,16 +2117,16 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens gate_proj = self.experts.gate_proj[idx.flatten()] up_proj = self.experts.up_proj[idx.flatten()] - w_dn = self.experts.down_proj[idx.flatten()] + w_dn = self.experts.down_proj_t[idx.flatten()] xk = x.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous() xk = xk.view(-1, 1, H) - gate = torch.bmm(xk, gate_proj.transpose(1, 2)) - up = torch.bmm(xk, up_proj.transpose(1, 2)) + gate = torch.bmm(xk, gate_proj) + up = torch.bmm(xk, up_proj) intermediate = up * self.experts.act_fn(gate) - experts_out = torch.bmm(intermediate, w_dn.transpose(1, 2)) + experts_out = torch.bmm(intermediate, w_dn) experts_out = experts_out.view(T, self.gate.top_k, H) * top_w.unsqueeze(-1) experts_out = torch.einsum("bnd->bd", experts_out) @@ -2224,139 +2225,53 @@ def forward(self, hidden_states): class QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): + supports_moe_prefill_blocking = True + def __qeff_init__(self): self.top_k = getattr(self.gate, "top_k", None) self.norm_topk_prob = getattr(self.gate, "norm_topk_prob", False) - self.num_experts = getattr(self.gate, "num_experts", self.experts.gate_up_proj.shape[0]) - self.gate_up_proj_w = self.experts.gate_up_proj - self.down_proj_w = self.experts.down_proj - - def _split_expert_weights(self, 
hidden_size: int): - gate_up_proj_w = self.gate_up_proj_w - if gate_up_proj_w.shape[2] != hidden_size: - gate_up_proj_w = gate_up_proj_w.transpose(1, 2) - intermediate_size = gate_up_proj_w.shape[1] // 2 - gate_proj_w = gate_up_proj_w[:, :intermediate_size, :] - up_proj_w = gate_up_proj_w[:, intermediate_size:, :] - - down_proj_w = self.down_proj_w - if down_proj_w.shape[1] != intermediate_size: - down_proj_w = down_proj_w.transpose(1, 2) - return gate_proj_w, up_proj_w, down_proj_w - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - act_fn = getattr(self.experts, "act_fn", F.silu) - T, H = x.shape - num_nsp = EXPERT_BLOCKING_NUM_NSP + self.num_experts = getattr(self.gate, "num_experts", self.experts.gate_proj.shape[0]) - packed_chunk_size = getattr(self, "expert_blocking_packed_chunk_size", T) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits, top_w, top_i = self.gate(x) + if self.norm_topk_prob: + top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) - if self.gate.num_experts % num_nsp != 0: + num_nsp = getattr(self, "expert_blocking_num_nsp", self.num_experts) + packed_chunk_size = getattr(self, "expert_blocking_packed_chunk_size", T) + if self.num_experts % num_nsp != 0: raise ValueError( - f"num_experts ({self.gate.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + f"num_experts ({self.num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})" ) - local_experts = self.gate.num_experts // num_nsp + + local_experts = self.num_experts // num_nsp rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - experts_out = x.new_zeros((num_nsp, T, H)) - - # gate_up_proj is [E, 2I, H]. 
after split we get [E, I, H], so transpose to [E, H, I] - # before grouping into [num_nsp, local_experts, H, I]. - # wt_g, wt_u = torch.split(self.experts.gate_up_proj, inter, dim=1) - wt_g, wt_u, W_d = self._split_expert_weights(H) - wt_g = wt_g.transpose(1, 2).contiguous() - wt_u = wt_u.transpose(1, 2).contiguous() - W_g = wt_g.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - W_u = wt_u.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - - # down_proj is [E, H, I]; blocked matmul expects [num_nsp, local_experts, I, H]. - W_d = W_d.transpose(1, 2).contiguous() - W_d = W_d.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.experts.down_proj_t.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) routing_weights_unsqueezed = rw.unsqueeze(-1) - + act_fn = getattr(self.experts, "act_fn", F.silu) for slot in range(local_experts): - routing_weight = rw[:, slot, :] - T2Ei = routing_weight > 0 - experts_out = _cumsum_scatter_gather_update_expert_blocked( + T2Ei = rw[:, slot, :] > 0 + expert_out = _cumsum_scatter_gather_update_expert_blocked( x=x, T2Ei=T2Ei, W_g=W_g[:, slot], W_u=W_u[:, slot], W_d=W_d[:, slot], routing_weight=routing_weights_unsqueezed[:, slot], - expert_out=experts_out, + expert_out=expert_out, act_fn=act_fn, packed_chunk_size=packed_chunk_size, ) - return torch.einsum("ijk->jk", experts_out) - - # def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # B, S, H = hidden_states.shape - # T = B * S - # x = hidden_states.view(T, H) - # router_logits = self.gate(x) # [T, E] - # prob = F.softmax(router_logits, -1, dtype=torch.float) - # top_w, top_i = torch.topk(prob, self.top_k, -1) - # if self.norm_topk_prob: # only diff 
with mixtral sparse moe block! - # top_w /= top_w.sum(-1, keepdim=True) - # top_w = top_w.to(hidden_states.dtype) - # masked_logits = torch.zeros_like(router_logits) - # masked_logits.scatter_(1, top_i, top_w) - # routing_weights = masked_logits - # experts_out = x.new_zeros((T, H)) - # for e in range(self.gate.num_experts): - # routing_weight = routing_weights[:, e].unsqueeze(-1) - # W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T - # W_d = self.experts[e].down_proj.weight.T - # gate = x @ W_g - # up = x @ W_u - # down = (up * self.experts[e].act_fn(gate)) @ W_d - # experts_out += down * routing_weight - - # shared_expert_output = self.shared_expert(x) - # shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - - # experts_out = experts_out + shared_expert_output - # return experts_out.view(B, S, H), router_logits - - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - B, S, H = hidden_states.shape - T = B * S - x = hidden_states.view(T, H) - act = getattr(self.experts, "act_fn", F.silu) - - prob, top_w, top_i = self.gate(hidden_states) - top_w = top_w.to(x.dtype) - routing_weights = torch.zeros((T, self.gate.num_experts), dtype=x.dtype) - routing_weights.scatter_(1, top_i, top_w) - - if self.gate.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - experts_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) - - shared_expert_output = self.shared_expert(x) - shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - expert_output = experts_out + shared_expert_output - return expert_output.view(B, S, H) - - experts_out = torch.zeros_like(x, dtype=x.dtype) - # breakpoint() - for e in range(self.gate.num_experts): - routing_weight = routing_weights[:, e].unsqueeze(-1) - - W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] - W_dn_e = self.experts.down_proj[e] # [I, H] - # - gate_up = x @ W_gate_up_e.T # [T, 2I] - - I2 = 
gate_up.shape[-1] // 2 - gate = gate_up[:, :I2] # [T, I] - up = gate_up[:, I2:] # [T, I] - intermediate = up * act(gate) - down = intermediate @ W_dn_e.T - masked_down = torch.where( - routing_weight > 0, down * routing_weight, torch.zeros_like(experts_out, dtype=down.dtype) - ) - # masked_down = down * routing_weight - experts_out += masked_down + experts_out = torch.einsum("ijk->jk", expert_out) shared_expert_output = self.shared_expert(x) shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py index a8e0f5af8..d116bcb73 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py @@ -31,11 +31,11 @@ # Enable KV blocking for full-attention layers with 2 KV blocks # To disable KV blocking, comment out the qaic_config line below # Set skip_kv=True to skip future KV blocks during inference (optimization) -qaic_config = {"blocking_mode": "kv", "num_kv_blocks": 2, "skip_kv": True} +qaic_config = {"enable_blocking": True, "blocking_mode": "kv", "num_kv_blocks": 2, "skip_kv": True} enable_blocking = False # By default blocking is false ### use skip_vision=True, if want to run only text, or false ### -skip_vision = False +skip_vision = True BS = 1 PREFILL_SEQ_LEN = 64 @@ -159,7 +159,7 @@ "role": "user", "content": [ {"type": "image", "image": image}, - {"type": "text", "text": "Describe all the colors seen in the image."}, + {"type": "text", "text": "Describe this image."}, ], }, ] diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py index fc221c94b..ca93259c7 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py @@ -87,7 +87,7 @@ width=536, num_cores=16, 
num_devices=1, - mxfp6_matmul=False, + mxfp6_matmul=True, mxint8_kv_cache=False, retain_full_kv=True, split_model_io=True, # This should be used for disagg serving via VLLM @@ -98,9 +98,12 @@ enable_chunking=True, skip_vision=True, use_onnx_subfunctions=False, +<<<<<<< HEAD layerwise_window_size=1, layerwise=LAYERWISE, offload_pt_weights=False, +======= +>>>>>>> 3d3435c (updating moe code and minor fixes) # qaic_config=qaic_config, # Enable KV blocking - comment out to disable ) @@ -122,8 +125,11 @@ prefill_only=False, skip_vision=True, use_onnx_subfunctions=False, +<<<<<<< HEAD layerwise=LAYERWISE, layerwise_window_size=1, +======= +>>>>>>> 3d3435c (updating moe code and minor fixes) # qaic_config=qaic_config, # Enable KV blocking - comment out to disable ) @@ -159,11 +165,13 @@ lang_decode_session = QAICInferenceSession(decode_qpc_path.get("lang_decode_qpc_path")) if skip_vision: + text_prompt_2 = "Describe yourself as a large language model, including your purpose, capabilities, and limitations. Explain how you process and generate responses, interact with users, and handle uncertainty, while emphasizing accuracy, safety, and helpfulness in diverse conversations across various topics and domains." 
+ messages = [ { "role": "user", "content": [ - {"type": "text", "text": "Tell me about yourself."}, + {"type": "text", "text": text_prompt_2}, ], }, ] @@ -177,7 +185,7 @@ "role": "user", "content": [ {"type": "image", "image": image}, - {"type": "text", "text": "Describe all the colors seen in the image."}, + {"type": "text", "text": "Describe the image."}, ], }, ] diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py index 27ad5d49f..025c66bc5 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py @@ -31,7 +31,7 @@ # Enable KV blocking for full-attention layers with 2 KV blocks # To disable KV blocking, comment out the qaic_config line below # Set skip_kv=True to skip future KV blocks during inference (optimization) -qaic_config = {"blocking_mode": "kv", "num_kv_blocks": 2, "skip_kv": True} +qaic_config = {"enable_blocking": True, "blocking_mode": "kv", "num_kv_blocks": 4, "skip_kv": True} enable_blocking = False ## By default blocking is false ### use skip_vision=Ture, if want to run only text, or false ### @@ -49,14 +49,14 @@ prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - num_devices=1, + num_devices=4, height=354, width=536, mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, mos=1, - # qaic_config=qaic_config, # Enable KV blocking - comment out to disable + qaic_config=qaic_config, # Enable KV blocking - comment out to disable ) if enable_blocking: @@ -121,7 +121,7 @@ prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - num_devices=1, + num_devices=4, height=354, width=536, mxfp6_matmul=True, From 794b2accc45aa7d91985ee93e4afcdb46769d030 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 8 Jun 2026 16:58:01 +0530 Subject: [PATCH 06/12] Updating attention Signed-off-by: Mohit Soni --- .../models/qwen3_5/modeling_qwen3_5.py | 36 
+++++++++---------- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 26 +++++++------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index 99d414976..f4e556392 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -587,11 +587,11 @@ def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - # for i in range(1, chunk_size): - # row = attn[..., i, :i].clone() - # sub = attn[..., :i, :i].clone() - # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -623,7 +623,7 @@ def torch_chunk_gated_delta_rule_qeff( # acc_dtype = torch.float32 # A64 = A.to(acc_dtype) # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) + # strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) # K = chunk_size - 1 # S64 = I64.clone() @@ -644,18 +644,18 @@ def torch_chunk_gated_delta_rule_qeff( # attn = X # Newton-Schulz updated - acc_dtype = torch.float32 - I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - L64 = attn.masked_fill(mask, 0).to(acc_dtype) - - # X0 = I - Xj = I64.clone() - for _ in range(chunk_size - 1): - Rj = I64 - ((I64 - L64) @ Xj) - Xj = Xj + (Xj @ Rj).masked_fill(~strict_lower, 0) - Xj = 
Xj.masked_fill(~strict_lower, 0) + I64 - - attn = Xj.to(attn.dtype) + # acc_dtype = torch.float32 + # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) + # L64 = attn.masked_fill(mask, 0).to(acc_dtype) + + # # X0 = I + # Xj = I64.clone() + # for _ in range(chunk_size - 1): + # Rj = I64 - ((I64 - L64) @ Xj) + # Xj = Xj + (Xj @ Rj).masked_fill(~strict_lower, 0) + # Xj = Xj.masked_fill(~strict_lower, 0) + I64 + + # attn = Xj.to(attn.dtype) value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index db687d8d5..71c820723 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -663,12 +663,12 @@ def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).to(decay_mask.dtype) # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - # for i in range(1, chunk_size): - # row = attn[..., i, :i].clone() - # sub = attn[..., :i, :i].clone() - # # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - # attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) - # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -710,15 +710,15 @@ def torch_chunk_gated_delta_rule_qeff( # Newton-Schulz - Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - L = attn.masked_fill(mask, 0) + # Eye = torch.eye(chunk_size, 
dtype=attn.dtype, device=attn.device) + # L = attn.masked_fill(mask, 0) - X = Eye - for _ in range(int(math.log2(chunk_size)) + 2): - R = Eye - (Eye - L) @ X - X = X + X @ R + # X = Eye + # for _ in range(int(math.log2(chunk_size)) + 2): + # R = Eye - (Eye - L) @ X + # X = X + X @ R - attn = X + # attn = X value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) From 09db8ab9957b98ebbc931f5473891f8236e5b9df Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 10 Jun 2026 11:52:15 +0530 Subject: [PATCH 07/12] Updating substituion with Horners method Signed-off-by: Mohit Soni --- .../models/qwen3_5/modeling_qwen3_5.py | 30 ++++++++-------- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 36 +++++++++---------- .../models/qwen3_5/qwen3_5.py | 8 +++-- .../qwen3_5/qwen3_5_continous_batching.py | 1 + .../models/qwen3_5_moe/qwen3_5_moe.py | 2 ++ .../qwen3_5_moe_continous_batching.py | 1 + 6 files changed, 42 insertions(+), 36 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index f4e556392..609155b5c 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -587,11 +587,11 @@ def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # for i in range(1, chunk_size): + # row = attn[..., i, :i].clone() + # sub = attn[..., :i, :i].clone() + # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -619,18 
+619,18 @@ def torch_chunk_gated_delta_rule_qeff( ## Horners method - # A = attn.masked_fill(mask, 0) - # acc_dtype = torch.float32 - # A64 = A.to(acc_dtype) - # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - # strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) + A = attn.masked_fill(mask, 0) + acc_dtype = torch.float32 + A64 = A.to(acc_dtype) + I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) + strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) - # K = chunk_size - 1 - # S64 = I64.clone() - # for _ in range(K): - # S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) + K = chunk_size - 1 + S64 = I64.clone() + for _ in range(K): + S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) - # attn = S64 + attn = S64 # Newton-Schulz # Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 71c820723..7036886b7 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -663,12 +663,12 @@ def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).to(decay_mask.dtype) # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # for i in range(1, chunk_size): + # row = attn[..., i, :i].clone() + # sub = attn[..., :i, :i].clone() + # # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + # attn[..., i, :i] = row + 
torch.einsum("bghi,bghij->bghj", row, sub) + # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -695,18 +695,18 @@ def torch_chunk_gated_delta_rule_qeff( # attn = L # Horners Method - # A = attn.masked_fill(mask, 0) - # acc_dtype = torch.float32 - # A64 = A.to(acc_dtype) - # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - # strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) - - # K = chunk_size - 1 - # S64 = I64.clone() - # for _ in range(K): - # S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) - - # attn = S64.to(A.dtype) + A = attn.masked_fill(mask, 0) + acc_dtype = torch.float32 + A64 = A.to(acc_dtype) + I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) + strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) + + K = chunk_size - 1 + S64 = I64.clone() + for _ in range(K): + S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) + + attn = S64.to(A.dtype) # Newton-Schulz diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py index d116bcb73..1000c3a25 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py @@ -14,12 +14,12 @@ from QEfficient import QEFFAutoModelForImageTextToText -model_id = "Qwen/Qwen3.5-0.8B" +model_id = "Qwen/Qwen3.5-27B" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 -config.text_config.num_hidden_layers = 2 +# config.text_config.num_hidden_layers = 2 qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 @@ -49,8 +49,9 @@ prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - num_devices=1, + 
num_devices=4, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=False, skip_vision=True, @@ -121,6 +122,7 @@ height=354, width=536, mxfp6_matmul=False, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=False, mos=1, diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py index f365e976a..3b601cb12 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py @@ -42,6 +42,7 @@ height=354, width=536, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=True, aic_enable_depth_first=True, mos=1, diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py index 025c66bc5..cdc5df85e 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py @@ -53,6 +53,7 @@ height=354, width=536, mxfp6_matmul=True, + split_model_io=True, aic_enable_depth_first=True, skip_vision=True, mos=1, @@ -125,6 +126,7 @@ height=354, width=536, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=True, mos=1, diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py index f83647fd7..97bbb6531 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py @@ -41,6 +41,7 @@ height=354, width=536, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=False, mos=1, From c7da3622ab1a257e3b09e310ccb41323475d84b5 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 10 Jun 2026 
14:54:03 +0530 Subject: [PATCH 08/12] minor fixes Signed-off-by: Mohit Soni --- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 10 ++++++---- dbg.log | 0 .../models/qwen3_5_moe/qwen3_5_disagg_mode.py | 16 +--------------- 3 files changed, 7 insertions(+), 19 deletions(-) create mode 100644 dbg.log diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 7036886b7..178772462 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1795,10 +1795,12 @@ def get_specializations( image_min_token_num = constants.IMAGE_MIN_TOKEN_NUM image_max_token_num = constants.IMAGE_MAX_TOKEN_NUM mm_processor_kwargs = compiler_options.pop("mm_processor_kwargs", None) - if mm_processor_kwargs: - min_pixels = mm_processor_kwargs.get("min_pixels", image_min_token_num * image_factor**2) - max_pixels = mm_processor_kwargs.get("max_pixels", image_max_token_num * image_factor**2) - + # if mm_processor_kwargs: + # min_pixels = mm_processor_kwargs.get("min_pixels", image_min_token_num * image_factor**2) + # max_pixels = mm_processor_kwargs.get("max_pixels", image_max_token_num * image_factor**2) + min_pixels = image_min_token_num * image_factor**2 + max_pixels = image_max_token_num * image_factor**2 + vision = [] max_vision_size = 0 user_vision_size = compiler_options.pop("vision_size", None) diff --git a/dbg.log b/dbg.log new file mode 100644 index 000000000..e69de29bb diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py index ca93259c7..76e8ed694 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py @@ -24,24 +24,16 @@ LAYERWISE = True # For faster execution user can run with lesser layers, For 
Testing Purpose Only -<<<<<<< HEAD config.vision_config.depth = 4 config.text_config.num_hidden_layers = 4 -config.torch_dtype = "float16" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config, + dtype=torch.float32, layerwise=LAYERWISE, -======= -config.vision_config.depth = 5 -config.text_config.num_hidden_layers = 2 - -qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 ->>>>>>> 2932096 (Adding fp16 support and comments addresses) ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) @@ -98,12 +90,9 @@ enable_chunking=True, skip_vision=True, use_onnx_subfunctions=False, -<<<<<<< HEAD layerwise_window_size=1, layerwise=LAYERWISE, offload_pt_weights=False, -======= ->>>>>>> 3d3435c (updating moe code and minor fixes) # qaic_config=qaic_config, # Enable KV blocking - comment out to disable ) @@ -125,11 +114,8 @@ prefill_only=False, skip_vision=True, use_onnx_subfunctions=False, -<<<<<<< HEAD layerwise=LAYERWISE, layerwise_window_size=1, -======= ->>>>>>> 3d3435c (updating moe code and minor fixes) # qaic_config=qaic_config, # Enable KV blocking - comment out to disable ) From 00720e90a58521c956e971ccb10fc518e3486cd9 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 10 Jun 2026 14:58:48 +0530 Subject: [PATCH 09/12] removing redundant code Signed-off-by: Mohit Soni --- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 178772462..dcc68d767 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1794,13 
+1794,13 @@ def get_specializations( image_factor = constants.IMAGE_FACTOR_QWEN_3 image_min_token_num = constants.IMAGE_MIN_TOKEN_NUM image_max_token_num = constants.IMAGE_MAX_TOKEN_NUM - mm_processor_kwargs = compiler_options.pop("mm_processor_kwargs", None) + # mm_processor_kwargs = compiler_options.pop("mm_processor_kwargs", None) # if mm_processor_kwargs: # min_pixels = mm_processor_kwargs.get("min_pixels", image_min_token_num * image_factor**2) # max_pixels = mm_processor_kwargs.get("max_pixels", image_max_token_num * image_factor**2) min_pixels = image_min_token_num * image_factor**2 max_pixels = image_max_token_num * image_factor**2 - + vision = [] max_vision_size = 0 user_vision_size = compiler_options.pop("vision_size", None) @@ -2086,18 +2086,6 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=32, batch_size=1 inputs.pop("mm_token_type_ids") return inputs - -class QEffQwen3_5MoeTopKRouter(Qwen3_5MoeTopKRouter): - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) - router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = router_top_value / torch.einsum("bk->b", router_top_value).unsqueeze(-1) - router_top_value = router_top_value.to(router_logits.dtype) - router_scores = router_top_value - return router_logits, router_scores, router_indices - class QEffQwen3_5MoeExperts(Qwen3_5MoeExperts): def __qeff_init__(self): # transformers>=5 uses fused gate_up projections. 
Keep backward-compatible @@ -2225,7 +2213,6 @@ def forward(self, hidden_states): router_scores = router_top_value return router_logits, router_scores, router_indices - class QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): supports_moe_prefill_blocking = True From 918001cff3cc5207c0ed9113451f58c8f5c66732 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 10 Jun 2026 15:00:15 +0530 Subject: [PATCH 10/12] minor fixes Signed-off-by: Mohit Soni --- .../transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index dcc68d767..a042cc3ca 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -2086,6 +2086,7 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=32, batch_size=1 inputs.pop("mm_token_type_ids") return inputs + class QEffQwen3_5MoeExperts(Qwen3_5MoeExperts): def __qeff_init__(self): # transformers>=5 uses fused gate_up projections. 
Keep backward-compatible @@ -2213,6 +2214,7 @@ def forward(self, hidden_states): router_scores = router_top_value return router_logits, router_scores, router_indices + class QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): supports_moe_prefill_blocking = True From 96a9768cf32598db1cfec63a7d8ba9cad9bc52bf Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Wed, 10 Jun 2026 15:25:54 +0530 Subject: [PATCH 11/12] Updating rope Signed-off-by: Mohit Soni --- .../models/qwen3_5/modeling_qwen3_5.py | 66 +++++++++++++------ .../qwen3_5_moe/modeling_qwen3_5_moe.py | 63 ++++++++++++------ 2 files changed, 88 insertions(+), 41 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index 609155b5c..abe7ff0cd 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -53,6 +53,8 @@ from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE from QEfficient.utils.logging_utils import logger +QWEN3_5_ROPE_CACHE_EXPORT_CAP = 76800 + class QEffQwen3_5GatedDeltaNetCustomRMSNormAIC(nn.Module): """ @@ -215,8 +217,9 @@ class QEffQwen3_5TextRotaryEmbedding(Qwen3_5TextRotaryEmbedding): def __init__(self, config, device=None): super().__init__(config=config, device=device) + cached_seq_len = min(int(self.original_max_seq_len), QWEN3_5_ROPE_CACHE_EXPORT_CAP) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, + seq_len=cached_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype(), ) @@ -265,7 +268,18 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): return freqs_t -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section): + invalid_pos_mask = position_ids < 0 + safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) + 
flat_pos = safe_position_ids.reshape(-1) + cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) + sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) + cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1) + sin = qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, mrope_section=None, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -298,14 +312,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] - cos = qeff_apply_interleaved_mrope(cos, mrope_section) - sin = qeff_apply_interleaved_mrope(sin, mrope_section) + cos = qeff_apply_interleaved_mrope(cos, mrope_section) + sin = qeff_apply_interleaved_mrope(sin, mrope_section) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] @@ -379,12 +394,13 @@ class QEffQwen3_5Attention(Qwen3_5Attention): """ def __qeff_init__(self): - self.rotary_emb = QEffQwen3_5TextRotaryEmbedding(config=self.config) + # RoPE tensors are prepared once in the text model and passed down. 
+ return def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], attention_mask: Optional[torch.Tensor], past_key_values: Optional[QEffQwen3_5DynamicCache] = None, position_ids: Optional[torch.LongTensor] = None, @@ -405,14 +421,10 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_values.get_seq_length(self.layer_idx, cache_position) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids[1:], self.rotary_emb.mrope_section - ) + if position_embeddings is None: + raise ValueError("`position_embeddings` must be provided for QEffQwen3_5Attention.") + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) @@ -944,6 +956,19 @@ def forward( class QEffQwen3_5TextModel(Qwen3_5TextModel): + def __qeff_init__(self): + self.rotary_emb = QEffQwen3_5TextRotaryEmbedding(config=self.config) + # Export-only cap to avoid serializing oversized RoPE tables that are + # unreachable for deployed context lengths (e.g. 4K/8K/16K). This keeps + # non-layerwise QPC size aligned with layerwise exports. 
+ rope_rows = min(int(self.rotary_emb.sin_cached.shape[0]), QWEN3_5_ROPE_CACHE_EXPORT_CAP) + self.sin_cached = torch.nn.Parameter( + (self.rotary_emb.sin_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) + self.cos_cached = torch.nn.Parameter( + (self.rotary_emb.cos_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -993,8 +1018,9 @@ def forward( hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) - # position_embeddings = None + mrope_section = self.config.rope_parameters.get("mrope_section", [11, 11, 10]) + cos, sin = qeff_prepare_mrope_cos_sin(self.cos_cached, self.sin_cached, position_ids[1:], mrope_section) + position_embeddings = (cos, sin) all_hidden_states = () if output_hidden_states else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index a042cc3ca..22f8aeeaf 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -329,7 +329,18 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): return freqs_t -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section): + invalid_pos_mask = position_ids < 0 + safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) + flat_pos = safe_position_ids.reshape(-1) + cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) + sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) + cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1) + sin = 
qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, mrope_section=None, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -362,14 +373,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] - cos = qeff_apply_interleaved_mrope(cos, mrope_section) - sin = qeff_apply_interleaved_mrope(sin, mrope_section) + cos = qeff_apply_interleaved_mrope(cos, mrope_section) + sin = qeff_apply_interleaved_mrope(sin, mrope_section) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] @@ -444,13 +456,13 @@ class QEffQwen3_5MoeAttention(Qwen3_5MoeAttention): """ def __qeff_init__(self): - # pass - self.rotary_emb = QEffQwen3_5MoeTextRotaryEmbedding(config=self.config) + # RoPE tensors are prepared once in the text model and passed down. 
+ return def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], attention_mask: Optional[torch.Tensor], past_key_values: Optional[QEffQwen3_5MoeDynamicCache] = None, position_ids: Optional[torch.LongTensor] = None, @@ -471,14 +483,11 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_values.get_seq_length(self.layer_idx, cache_position) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if position_embeddings is None: + raise ValueError("`position_embeddings` must be provided for QEffQwen3_5MoeAttention.") - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids[1:], self.rotary_emb.mrope_section - ) + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) @@ -1020,8 +1029,19 @@ class QEffQwen3_5MoeTextModel(Qwen3_5MoeTextModel): _start = 0 _end = 0 _total_layers = None - # def __qeff_init__(self): - # self.rotary_emb = QEffQwen3_5MoeTextRotaryEmbedding(config=self.config) + + def __qeff_init__(self): + self.rotary_emb = QEffQwen3_5MoeTextRotaryEmbedding(config=self.config) + # Export-only cap to avoid serializing oversized RoPE tables that are + # unreachable for deployed context lengths (e.g. 4K/8K/16K). This keeps + # non-layerwise QPC size aligned with layerwise exports. 
+ rope_rows = min(int(self.rotary_emb.sin_cached.shape[0]), QWEN3_5_MOE_ROPE_CACHE_EXPORT_CAP) + self.sin_cached = torch.nn.Parameter( + (self.rotary_emb.sin_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) + self.cos_cached = torch.nn.Parameter( + (self.rotary_emb.cos_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) def forward( self, @@ -1085,8 +1105,9 @@ def forward( hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) - # position_embeddings = None + mrope_section = self.config.rope_parameters.get("mrope_section", [11, 11, 10]) + cos, sin = qeff_prepare_mrope_cos_sin(self.cos_cached, self.sin_cached, position_ids[1:], mrope_section) + position_embeddings = (cos, sin) all_hidden_states = () if output_hidden_states else None layer_indices_to_run = kwargs.get("layer_indices_to_run", None) From 7f566fdac92008a6eca154cc840ccdb51c1a8a9e Mon Sep 17 00:00:00 2001 From: ochougul Date: Wed, 10 Jun 2026 17:11:15 +0530 Subject: [PATCH 12/12] fixed MOE Signed-off-by: ochougul --- .../transformers/models/pytorch_transforms.py | 3 ++ .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 50 ++++++++++++------- .../multi_specialization_inference.py | 2 +- .../qwen3_vl_moe_layerwise_decode.py | 4 +- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2d88bbb38..6c07d61cd 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -263,6 +263,7 @@ Qwen3VLMoeModel, Qwen3VLMoeTextAttention, Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextExperts, Qwen3VLMoeTextModel, Qwen3VLMoeTextRMSNorm, Qwen3VLMoeTextRotaryEmbedding, @@ -597,6 +598,7 @@ QEffQwen3VLMoeModel, QEffQwen3VLMoeTextAttention, QEffQwen3VLMoeTextDecoderLayer, + QEffQwen3VLMoeTextExperts, QEffQwen3VLMoeTextModel, QEffQwen3VLMoeTextRotaryEmbedding, 
QEffQwen3VLMoeTextSparseMoeBlock, @@ -745,6 +747,7 @@ class KVCacheTransform(ModuleMappingTransform): Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer, Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, + Qwen3VLMoeTextExperts: QEffQwen3VLMoeTextExperts, Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, Qwen3VLMoeTextRotaryEmbedding: QEffQwen3VLMoeTextRotaryEmbedding, diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index cbef36f64..e8573159f 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -22,6 +22,7 @@ Qwen3VLMoeTextAttention, Qwen3VLMoeTextConfig, Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextExperts, Qwen3VLMoeTextModel, Qwen3VLMoeTextRotaryEmbedding, Qwen3VLMoeTextSparseMoeBlock, @@ -662,6 +663,27 @@ def _deepstack_process( return hidden_states + (visual_embeds * visual_mask) +class QEffQwen3VLMoeTextExperts(Qwen3VLMoeTextExperts): + def __qeff_init__(self): + hidden_dim = self.hidden_dim + if self.gate_up_proj.shape[1] == hidden_dim: + inter = self.gate_up_proj.shape[-1] // 2 + all_gate, all_up = torch.split(self.gate_up_proj, inter, dim=-1) + else: + inter = self.gate_up_proj.shape[1] // 2 + all_gate, all_up = torch.split(self.gate_up_proj, inter, dim=1) + all_gate = all_gate.transpose(1, 2) + all_up = all_up.transpose(1, 2) + + all_down = self.down_proj + if all_down.shape[1] == hidden_dim: + all_down = all_down.transpose(1, 2) + + self.all_gate = nn.Parameter(all_gate.contiguous().detach(), requires_grad=False) + self.all_up = nn.Parameter(all_up.contiguous().detach(), requires_grad=False) + self.all_down = nn.Parameter(all_down.contiguous().detach(), requires_grad=False) + + class 
QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape @@ -674,7 +696,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w, top_i = torch.topk(router_logits, self.gate.top_k, dim=-1) top_w = F.softmax(top_w, dim=-1, dtype=torch.float) top_w = top_w.to(hidden_states.dtype) - num_experts = getattr(self, "num_experts", self.gate.num_experts) + num_experts = getattr(self.experts, "num_experts", self.gate.num_experts) routing_weights = torch.zeros((T, num_experts), dtype=x.dtype) routing_weights.scatter_(1, top_i, top_w) @@ -683,19 +705,10 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens for e in range(num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) - W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] - W_dn_e = self.experts.down_proj[e] # [I, H] - if W_gate_up_e.shape[0] != H: - W_gate_up_e = W_gate_up_e.transpose(0, 1) - gate_up = x @ W_gate_up_e # [T, 2I] - - I2 = gate_up.shape[-1] // 2 - gate = gate_up[:, :I2] # [T, I] - up = gate_up[:, I2:] # [T, I] + gate = x @ self.experts.all_gate[e] + up = x @ self.experts.all_up[e] intermediate = up * act(gate) - if W_dn_e.shape[0] != I2: - W_dn_e = W_dn_e.transpose(0, 1) - down = intermediate @ W_dn_e + down = intermediate @ self.experts.all_down[e] masked_down = torch.where( routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype) ) # TODO: verify and remove @@ -958,16 +971,15 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = F.softmax(top_w, dim=-1, dtype=torch.float) top_w = top_w.to(x.dtype) idx = top_i.reshape(-1) - w_up = self.experts.gate_up_proj.transpose(1, 2).index_select(0, idx) - w_dn = self.experts.down_proj.transpose(1, 2).index_select(0, idx) + w_gate = self.experts.all_gate[idx] + w_up = 
self.experts.all_up[idx] + w_dn = self.experts.all_down[idx] top_k = top_i.shape[-1] xk = x.unsqueeze(1).expand(-1, top_k, -1).contiguous() xk = xk.view(-1, 1, H) - gate_up = torch.bmm(xk, w_up) - I2 = gate_up.size(-1) - half = I2 // 2 - gate, up = gate_up[..., :half], gate_up[..., half:] + gate = torch.bmm(xk, w_gate) + up = torch.bmm(xk, w_up) intermediate = up * self.experts.act_fn(gate) experts_out = torch.bmm(intermediate, w_dn) experts_out = experts_out.view(T, top_k, H) * top_w.unsqueeze(-1) diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py b/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py index d323ad608..99fe38ead 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py @@ -28,7 +28,7 @@ processor = AutoProcessor.from_pretrained(model_id) # use skip_vision=True, if want to run only text -skip_vision = False +skip_vision = True if skip_vision: # Only Text batch_size = 1 diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index f057b01bc..8d4b199a3 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -22,8 +22,8 @@ from QEfficient import QEFFAutoModelForImageTextToText # MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -# MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" -MODEL_ID = "tiny-random/qwen3-vl-moe" +MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# MODEL_ID = "tiny-random/qwen3-vl-moe" def main():