diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index b2b447a78..6c07d61cd 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -225,6 +225,7 @@ from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( Qwen3_5MoeAttention, Qwen3_5MoeDecoderLayer, + Qwen3_5MoeExperts, Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeGatedDeltaNet, @@ -262,6 +263,7 @@ Qwen3VLMoeModel, Qwen3VLMoeTextAttention, Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextExperts, Qwen3VLMoeTextModel, Qwen3VLMoeTextRMSNorm, Qwen3VLMoeTextRotaryEmbedding, @@ -559,6 +561,7 @@ QEffPrefillChunkedQwen3_5MoeSparseMoeBlock, QEffQwen3_5MoeAttention, QEffQwen3_5MoeDecoderLayer, + QEffQwen3_5MoeExperts, QEffQwen3_5MoeForCausalLM, QEffQwen3_5MoeForConditionalGeneration, QEffQwen3_5MoeGatedDeltaNet, @@ -595,6 +598,7 @@ QEffQwen3VLMoeModel, QEffQwen3VLMoeTextAttention, QEffQwen3VLMoeTextDecoderLayer, + QEffQwen3VLMoeTextExperts, QEffQwen3VLMoeTextModel, QEffQwen3VLMoeTextRotaryEmbedding, QEffQwen3VLMoeTextSparseMoeBlock, @@ -743,6 +747,7 @@ class KVCacheTransform(ModuleMappingTransform): Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer, Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, + Qwen3VLMoeTextExperts: QEffQwen3VLMoeTextExperts, Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, Qwen3VLMoeTextRotaryEmbedding: QEffQwen3VLMoeTextRotaryEmbedding, @@ -868,6 +873,7 @@ class KVCacheTransform(ModuleMappingTransform): Qwen3_5MoeVisionAttention: QEffQwen3_5MoeVisionAttention, Qwen3_5MoeVisionModel: QEffQwen3_5MoeVisionModel, Qwen3_5MoeTopKRouter: QEffQwen3_5MoeTopKRouter, + Qwen3_5MoeExperts: QEffQwen3_5MoeExperts, # Qwen2.5 VL Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, Qwen2_5_VLModel: QEffQwen2_5_VLModel, diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index c115022a6..abe7ff0cd 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -53,6 +53,8 @@ from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE from QEfficient.utils.logging_utils import logger +QWEN3_5_ROPE_CACHE_EXPORT_CAP = 76800 + class QEffQwen3_5GatedDeltaNetCustomRMSNormAIC(nn.Module): """ @@ -64,7 +66,7 @@ def forward(self, hidden_states, gate): CustomRMSNormFunc.apply( hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps ) - ) * F.silu(gate.to(torch.float32)) + ) * F.silu(gate.to(gate.dtype)) class QEffQwen3_5DynamicCache(Cache): @@ -215,8 +217,9 @@ class QEffQwen3_5TextRotaryEmbedding(Qwen3_5TextRotaryEmbedding): def __init__(self, config, device=None): super().__init__(config=config, device=device) + cached_seq_len = min(int(self.original_max_seq_len), QWEN3_5_ROPE_CACHE_EXPORT_CAP) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, + seq_len=cached_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype(), ) @@ -265,7 +268,18 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): return freqs_t -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section): + invalid_pos_mask = position_ids < 0 + safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) + flat_pos = safe_position_ids.reshape(-1) + cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) + sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) + cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1) + sin = qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, mrope_section=None, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -298,14 +312,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] - cos = qeff_apply_interleaved_mrope(cos, mrope_section) - sin = qeff_apply_interleaved_mrope(sin, mrope_section) + cos = qeff_apply_interleaved_mrope(cos, mrope_section) + sin = qeff_apply_interleaved_mrope(sin, mrope_section) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] @@ -338,7 +353,7 @@ def eager_attention_forward( attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.dtype), attn_weights ) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -379,12 +394,13 @@ class QEffQwen3_5Attention(Qwen3_5Attention): """ def __qeff_init__(self): - self.rotary_emb = QEffQwen3_5TextRotaryEmbedding(config=self.config) + # RoPE tensors are prepared once in the text model and passed down. + return def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], attention_mask: Optional[torch.Tensor], past_key_values: Optional[QEffQwen3_5DynamicCache] = None, position_ids: Optional[torch.LongTensor] = None, @@ -405,14 +421,10 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_values.get_seq_length(self.layer_idx, cache_position) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids[1:], self.rotary_emb.mrope_section - ) + if position_embeddings is None: + raise ValueError("`position_embeddings` must be provided for QEffQwen3_5Attention.") + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) @@ -530,7 +542,7 @@ def torch_chunk_gated_delta_rule_qeff( query = l2norm(query, dim=-1, eps=1e-6) key = l2norm(key, dim=-1, eps=1e-6) query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + x.transpose(1, 2).contiguous().to(query.dtype) for x in (query, key, value, beta, g) ] mask = (position_ids[0] != -1).unsqueeze(1) @@ -587,11 +599,11 @@ def torch_chunk_gated_delta_rule_qeff( decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # for i in range(1, chunk_size): + # row = attn[..., i, :i].clone() + # sub = attn[..., :i, :i].clone() + # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -619,18 +631,43 @@ def torch_chunk_gated_delta_rule_qeff( ## Horners method - # A = attn.masked_fill(mask, 0) + A = attn.masked_fill(mask, 0) + acc_dtype = torch.float32 + A64 = A.to(acc_dtype) + I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) + strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) + + K = chunk_size - 1 + S64 = I64.clone() + for _ in range(K): + S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) + + attn = S64 + + # Newton-Schulz + # Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # L = attn.masked_fill(mask, 0) + + # X = Eye + # for _ in range(int(math.log2(chunk_size)) + 2): + # R = Eye - (Eye - L) @ X + # X = X + X @ R + + # attn = X + + # Newton-Schulz updated # acc_dtype = torch.float32 - # A64 = A.to(acc_dtype) # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - # strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) + # L64 = attn.masked_fill(mask, 0).to(acc_dtype) - # K = chunk_size - 1 - # S64 = I64.clone() - # for _ in range(K): - # S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) + # # X0 = I + # Xj = I64.clone() + # for _ in range(chunk_size - 1): + # Rj = I64 - ((I64 - L64) @ Xj) + # Xj = Xj + (Xj @ Rj).masked_fill(~strict_lower, 0) + # Xj = Xj.masked_fill(~strict_lower, 0) + I64 - # attn = S64 + # attn = Xj.to(attn.dtype) value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) @@ -919,6 +956,19 @@ def forward( class QEffQwen3_5TextModel(Qwen3_5TextModel): + def __qeff_init__(self): + self.rotary_emb = QEffQwen3_5TextRotaryEmbedding(config=self.config) + # Export-only cap to avoid serializing oversized RoPE tables that are + # unreachable for deployed context lengths (e.g. 4K/8K/16K). This keeps + # non-layerwise QPC size aligned with layerwise exports. + rope_rows = min(int(self.rotary_emb.sin_cached.shape[0]), QWEN3_5_ROPE_CACHE_EXPORT_CAP) + self.sin_cached = torch.nn.Parameter( + (self.rotary_emb.sin_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) + self.cos_cached = torch.nn.Parameter( + (self.rotary_emb.cos_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -968,8 +1018,9 @@ def forward( hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) - # position_embeddings = None + mrope_section = self.config.rope_parameters.get("mrope_section", [11, 11, 10]) + cos, sin = qeff_prepare_mrope_cos_sin(self.cos_cached, self.sin_cached, position_ids[1:], mrope_section) + position_embeddings = (cos, sin) all_hidden_states = () if output_hidden_states else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -1047,8 +1098,8 @@ def get_onnx_retained_state_specs( if layer_type == "full_attention": layer_names = [f"past_key.{layer_idx}", f"past_value.{layer_idx}"] layer_tensors = [ - torch.zeros(tuple(kv_cache_shape), dtype=torch.float32), - torch.zeros(tuple(kv_cache_shape), dtype=torch.float32), + torch.zeros(tuple(kv_cache_shape), dtype=self.model.config.dtype), + torch.zeros(tuple(kv_cache_shape), dtype=self.model.config.dtype), ] layer_axes = [ {0: batch_axis_name, 2: "ctx_len"}, @@ -1060,8 +1111,8 @@ def get_onnx_retained_state_specs( recurrent_shape = (batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) layer_names = [f"conv_state.{layer_idx}", f"recurrent_state.{layer_idx}"] layer_tensors = [ - torch.zeros(conv_shape, dtype=torch.float32), - torch.zeros(recurrent_shape, dtype=torch.float32), + torch.zeros(conv_shape, dtype=self.model.config.dtype), + torch.zeros(recurrent_shape, dtype=self.model.config.dtype), ] layer_axes = [{0: batch_axis_name}, {0: batch_axis_name}] @@ -1362,7 +1413,7 @@ def forward( col_mask = (cols >= start) & (cols < end) block_mask = row_mask & col_mask - final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask = torch.ones((seq_len, seq_len), dtype=self.config.dtype) final_mask[block_mask.any(dim=0)] = 0 final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) attention_mask[0] = final_mask @@ -1758,10 +1809,10 @@ def get_dummy_inputs( vision_inputs = {} lang_inputs = {} - vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=self.model.config.dtype) vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) - lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.model.config.dtype) lang_inputs["position_ids"] = ( ( torch.arange(dummy_seq_len, dtype=torch.int64) @@ -1788,13 +1839,13 @@ def get_dummy_inputs( for i in range(self.model.config.text_config.num_hidden_layers): if self.model.config.text_config.layer_types[i] == "full_attention": for kv in ["key", "value"]: - lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=self.model.config.dtype)) else: layer = self.model.language_model.layers[i].linear_attn conv_shape = (linear_batch_size, layer.conv_dim, layer.conv_kernel_size) recurrent_shape = (linear_batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) - lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=torch.float32)) - lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=torch.float32)) + lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=self.model.config.dtype)) + lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=self.model.config.dtype)) # if continuous_batching: diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 39dbb0e1b..22f8aeeaf 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -20,6 +20,7 @@ Qwen3_5MoeAttention, Qwen3_5MoeCausalLMOutputWithPast, Qwen3_5MoeDecoderLayer, + Qwen3_5MoeExperts, Qwen3_5MoeForCausalLM, Qwen3_5MoeForConditionalGeneration, Qwen3_5MoeGatedDeltaNet, @@ -328,7 +329,18 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): return freqs_t -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section): + invalid_pos_mask = position_ids < 0 + safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) + flat_pos = safe_position_ids.reshape(-1) + cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) + sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) + cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1) + sin = qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, mrope_section=None, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -361,14 +373,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] + if position_ids is not None: + cos = cos[position_ids] + sin = sin[position_ids] - cos = qeff_apply_interleaved_mrope(cos, mrope_section) - sin = qeff_apply_interleaved_mrope(sin, mrope_section) + cos = qeff_apply_interleaved_mrope(cos, mrope_section) + sin = qeff_apply_interleaved_mrope(sin, mrope_section) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] @@ -399,11 +412,10 @@ def eager_attention_forward( value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - # - # MIN_MASKED_ATTENTION_VALUE = -10000 + if attention_mask is not None: attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.dtype), attn_weights ) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -444,13 +456,13 @@ class QEffQwen3_5MoeAttention(Qwen3_5MoeAttention): """ def __qeff_init__(self): - # pass - self.rotary_emb = QEffQwen3_5MoeTextRotaryEmbedding(config=self.config) + # RoPE tensors are prepared once in the text model and passed down. + return def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], attention_mask: Optional[torch.Tensor], past_key_values: Optional[QEffQwen3_5MoeDynamicCache] = None, position_ids: Optional[torch.LongTensor] = None, @@ -471,14 +483,11 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_values.get_seq_length(self.layer_idx, cache_position) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if position_embeddings is None: + raise ValueError("`position_embeddings` must be provided for QEffQwen3_5MoeAttention.") - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids[1:], self.rotary_emb.mrope_section - ) + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) @@ -599,7 +608,7 @@ def torch_chunk_gated_delta_rule_qeff( query = query * torch.rsqrt(torch.einsum("bthd,bthd->bth", query, query).unsqueeze(-1) + 1e-6) key = key * torch.rsqrt(torch.einsum("bthd,bthd->bth", key, key).unsqueeze(-1) + 1e-6) query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + x.transpose(1, 2).contiguous().to(query.dtype) for x in (query, key, value, beta, g) ] mask = (position_ids[0] != -1).unsqueeze(1) @@ -658,17 +667,17 @@ def torch_chunk_gated_delta_rule_qeff( # decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() # original decay_mask diff = g.unsqueeze(-1) - g.unsqueeze(-2) # (B, H, num_chunks, C, C) - diff = diff * (~mask_strict).float() # zero upper triangle (strict) - decay_mask = diff.exp().float() - decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero + diff = diff * (~mask_strict).to(diff.dtype) # zero upper triangle (strict) + decay_mask = diff.exp().to(diff.dtype) + decay_mask = decay_mask * (~mask_strict).to(decay_mask.dtype) # ensure upper is zero attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # for i in range(1, chunk_size): + # row = attn[..., i, :i].clone() + # sub = attn[..., :i, :i].clone() + # # attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + # attn[..., i, :i] = row + torch.einsum("bghi,bghij->bghj", row, sub) + # attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) ## Approximation code ## # A = attn @@ -695,18 +704,30 @@ def torch_chunk_gated_delta_rule_qeff( # attn = L # Horners Method - # A = attn.masked_fill(mask, 0) - # acc_dtype = torch.float32 - # A64 = A.to(acc_dtype) - # I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) - # strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) - - # K = chunk_size - 1 - # S64 = I64.clone() - # for _ in range(K): - # S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) + A = attn.masked_fill(mask, 0) + acc_dtype = torch.float32 + A64 = A.to(acc_dtype) + I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size) + strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size) - # attn = S64.to(A.dtype) + K = chunk_size - 1 + S64 = I64.clone() + for _ in range(K): + S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0) + + attn = S64.to(A.dtype) + + # Newton-Schulz + + # Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # L = attn.masked_fill(mask, 0) + + # X = Eye + # for _ in range(int(math.log2(chunk_size)) + 2): + # R = Eye - (Eye - L) @ X + # X = X + X @ R + + # attn = X value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) @@ -1008,8 +1029,19 @@ class QEffQwen3_5MoeTextModel(Qwen3_5MoeTextModel): _start = 0 _end = 0 _total_layers = None - # def __qeff_init__(self): - # self.rotary_emb = QEffQwen3_5MoeTextRotaryEmbedding(config=self.config) + + def __qeff_init__(self): + self.rotary_emb = QEffQwen3_5MoeTextRotaryEmbedding(config=self.config) + # Export-only cap to avoid serializing oversized RoPE tables that are + # unreachable for deployed context lengths (e.g. 4K/8K/16K). This keeps + # non-layerwise QPC size aligned with layerwise exports. + rope_rows = min(int(self.rotary_emb.sin_cached.shape[0]), QWEN3_5_MOE_ROPE_CACHE_EXPORT_CAP) + self.sin_cached = torch.nn.Parameter( + (self.rotary_emb.sin_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) + self.cos_cached = torch.nn.Parameter( + (self.rotary_emb.cos_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous() + ) def forward( self, @@ -1073,8 +1105,9 @@ def forward( hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) - # position_embeddings = None + mrope_section = self.config.rope_parameters.get("mrope_section", [11, 11, 10]) + cos, sin = qeff_prepare_mrope_cos_sin(self.cos_cached, self.sin_cached, position_ids[1:], mrope_section) + position_embeddings = (cos, sin) all_hidden_states = () if output_hidden_states else None layer_indices_to_run = kwargs.get("layer_indices_to_run", None) @@ -1479,8 +1512,7 @@ def forward( row_mask = (rows >= start) & (rows < end) col_mask = (cols >= start) & (cols < end) block_mask = row_mask & col_mask - - final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask = torch.ones((seq_len, seq_len), dtype=self.config.dtype) final_mask[block_mask.any(dim=0)] = 0 final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) attention_mask[0] = final_mask @@ -1783,10 +1815,12 @@ def get_specializations( image_factor = constants.IMAGE_FACTOR_QWEN_3 image_min_token_num = constants.IMAGE_MIN_TOKEN_NUM image_max_token_num = constants.IMAGE_MAX_TOKEN_NUM - mm_processor_kwargs = compiler_options.pop("mm_processor_kwargs", None) - if mm_processor_kwargs: - min_pixels = mm_processor_kwargs.get("min_pixels", image_min_token_num * image_factor**2) - max_pixels = mm_processor_kwargs.get("max_pixels", image_max_token_num * image_factor**2) + # mm_processor_kwargs = compiler_options.pop("mm_processor_kwargs", None) + # if mm_processor_kwargs: + # min_pixels = mm_processor_kwargs.get("min_pixels", image_min_token_num * image_factor**2) + # max_pixels = mm_processor_kwargs.get("max_pixels", image_max_token_num * image_factor**2) + min_pixels = image_min_token_num * image_factor**2 + max_pixels = image_max_token_num * image_factor**2 vision = [] max_vision_size = 0 @@ -2074,16 +2108,14 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=32, batch_size=1 return inputs -class QEffQwen3_5MoeTopKRouter(Qwen3_5MoeTopKRouter): - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) - router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = router_top_value / torch.einsum("bk->b", router_top_value).unsqueeze(-1) - router_top_value = router_top_value.to(router_logits.dtype) - router_scores = router_top_value - return router_logits, router_scores, router_indices +class QEffQwen3_5MoeExperts(Qwen3_5MoeExperts): + def __qeff_init__(self): + # transformers>=5 uses fused gate_up projections. Keep backward-compatible + # aliases expected by existing QEff paths. + self.expert_dim = getattr(self, "intermediate_size", self.gate_up_proj.shape[-2] // 2) + self.gate_proj = nn.Parameter(self.gate_up_proj[:, : self.expert_dim, :].detach().clone().transpose(1, 2)) + self.up_proj = nn.Parameter(self.gate_up_proj[:, self.expert_dim :, :].detach().clone().transpose(1, 2)) + self.down_proj_t = nn.Parameter(self.down_proj.detach().clone().transpose(1, 2)) class QEffQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): @@ -2095,18 +2127,18 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = top_w.to(x.dtype) idx = top_i.reshape(-1) - w_up = self.experts.gate_up_proj[idx.flatten()] - w_dn = self.experts.down_proj[idx.flatten()] + gate_proj = self.experts.gate_proj[idx.flatten()] + up_proj = self.experts.up_proj[idx.flatten()] + w_dn = self.experts.down_proj_t[idx.flatten()] xk = x.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous() xk = xk.view(-1, 1, H) - gate_proj, up_proj = torch.chunk(w_up, 2, dim=1) - gate = torch.bmm(xk, gate_proj.transpose(1, 2)) - up = torch.bmm(xk, up_proj.transpose(1, 2)) + gate = torch.bmm(xk, gate_proj) + up = torch.bmm(xk, up_proj) intermediate = up * self.experts.act_fn(gate) - experts_out = torch.bmm(intermediate, w_dn.transpose(1, 2)) + experts_out = torch.bmm(intermediate, w_dn) experts_out = experts_out.view(T, self.gate.top_k, H) * top_w.unsqueeze(-1) experts_out = torch.einsum("bnd->bd", experts_out) @@ -2122,7 +2154,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: - """Build packed->original token index""" + """Build packed->original token index.""" batch_size, seq_len = T2Ei.shape int32_max = torch.iinfo(torch.int32).max int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) @@ -2130,9 +2162,7 @@ def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) valid_dest = valid_prefix - 1 scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) - # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this - # can be switched back to ``torch.full_like(token_idx, int32_max)``. - matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = torch.full_like(token_idx, int32_max) matched_idx = CtxScatterFunc3DInt.apply( matched_idx.unsqueeze(-1), scatter_pos, @@ -2148,33 +2178,24 @@ def _cumsum_scatter_gather_update_expert_blocked( W_u: torch.Tensor, W_d: torch.Tensor, routing_weight: torch.Tensor, - experts_out: torch.Tensor, + expert_out: torch.Tensor, act_fn, - T: int, packed_chunk_size: int, ) -> torch.Tensor: """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. - Accumulates one local expert's contribution in-place onto ``experts_out``. + Accumulates one local expert's contribution in-place onto ``expert_out``. Uses a packed/cumsum layout so the MLP runs only over active rows, then scatters the weighted output back to original token positions. - - Shapes: - x : [T, H] - T2Ei : [num_nsp, T] (bool) - W_g, W_u : [num_nsp, H, I] - W_d : [num_nsp, I, H] - routing_weight : [num_nsp, T] - experts_out : [num_nsp, T, H] (accumulator, in-out) """ batch_size, seq_len = T2Ei.shape - packed_chunk_size = int(max(1, min(packed_chunk_size, seq_len))) + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) matched_idx = _build_matched_idx_from_cumsum(T2Ei) - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + # valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + valid_rows = torch.einsum("ij->i", T2Ei.to(torch.int32)).unsqueeze(1) row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) - rw_expanded = routing_weight.unsqueeze(-1) for packed_start in range(0, seq_len, packed_chunk_size): packed_stop = packed_start + packed_chunk_size chunk_matched_idx = matched_idx[:, packed_start:packed_stop] @@ -2185,132 +2206,84 @@ def _cumsum_scatter_gather_update_expert_blocked( up_prime = x_chunk @ W_u down_chunk = (up_prime * act_fn(gate_prime)) @ W_d - rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + rw_chunk = CtxGatherFunc3DGeneralized.apply(routing_weight, chunk_matched_idx) down_chunk = down_chunk * rw_chunk - - expert_out_chunk = CtxGatherFunc3DGeneralized.apply(experts_out, chunk_matched_idx) + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) updated_chunk = expert_out_chunk + down_chunk - chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + chunk_valid_rows = torch.clamp( + valid_rows - packed_start, + min=torch.zeros_like(valid_rows), + max=torch.full_like(valid_rows, packed_chunk_size), + ) updated_chunk = torch.where( (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) ) - experts_out = CtxScatterFunc3DGeneralized.apply(experts_out, chunk_matched_idx, updated_chunk) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) - return experts_out + return expert_out + + +class QEffQwen3_5MoeTopKRouter(Qwen3_5MoeTopKRouter): + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1).to(router_logits.dtype) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value = router_top_value / torch.einsum("bk->b", router_top_value).unsqueeze(-1) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices class QEffPrefillChunkedQwen3_5MoeSparseMoeBlock(Qwen3_5MoeSparseMoeBlock): - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - act_fn = getattr(self.experts, "act_fn", F.silu) - T, H = x.shape - num_nsp = EXPERT_BLOCKING_NUM_NSP - if self.gate.num_experts % num_nsp != 0: - raise ValueError( - f"num_experts ({self.gate.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" - ) - local_experts = self.gate.num_experts // num_nsp - rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - experts_out = x.new_zeros((num_nsp, T, H)) - inter = self.experts.gate_up_proj.shape[1] // 2 + supports_moe_prefill_blocking = True - # gate_up_proj is [E, 2I, H]. After split we get [E, I, H], so transpose to [E, H, I] - # before grouping into [num_nsp, local_experts, H, I]. - wt_g, wt_u = torch.split(self.experts.gate_up_proj, inter, dim=1) - wt_g = wt_g.transpose(1, 2).contiguous() - wt_u = wt_u.transpose(1, 2).contiguous() - W_g = wt_g.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - W_u = wt_u.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + def __qeff_init__(self): + self.top_k = getattr(self.gate, "top_k", None) + self.norm_topk_prob = getattr(self.gate, "norm_topk_prob", False) + self.num_experts = getattr(self.gate, "num_experts", self.experts.gate_proj.shape[0]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits, top_w, top_i = self.gate(x) + if self.norm_topk_prob: + top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) - # down_proj is [E, H, I]; blocked matmul expects [num_nsp, local_experts, I, H]. - W_d = self.experts.down_proj.transpose(1, 2).contiguous() - W_d = W_d.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + num_nsp = getattr(self, "expert_blocking_num_nsp", self.num_experts) + packed_chunk_size = getattr(self, "expert_blocking_packed_chunk_size", T) + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by expert_blocking_num_nsp ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.experts.down_proj_t.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() + expert_out = x.new_zeros((num_nsp, T, H)) + routing_weights_unsqueezed = rw.unsqueeze(-1) + act_fn = getattr(self.experts, "act_fn", F.silu) for slot in range(local_experts): - routing_weight = rw[:, slot, :] - T2Ei = routing_weight > 0 - experts_out = _cumsum_scatter_gather_update_expert_blocked( + T2Ei = rw[:, slot, :] > 0 + expert_out = _cumsum_scatter_gather_update_expert_blocked( x=x, T2Ei=T2Ei, W_g=W_g[:, slot], W_u=W_u[:, slot], W_d=W_d[:, slot], - routing_weight=routing_weight, - experts_out=experts_out, + routing_weight=routing_weights_unsqueezed[:, slot], + expert_out=expert_out, act_fn=act_fn, - T=T, - packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, - ) - return experts_out.sum(dim=0) - - # def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # B, S, H = hidden_states.shape - # T = B * S - # x = hidden_states.view(T, H) - # router_logits = self.gate(x) # [T, E] - # prob = F.softmax(router_logits, -1, dtype=torch.float) - # top_w, top_i = torch.topk(prob, self.top_k, -1) - # if self.norm_topk_prob: # only diff with mixtral sparse moe block! - # top_w /= top_w.sum(-1, keepdim=True) - # top_w = top_w.to(hidden_states.dtype) - # masked_logits = torch.zeros_like(router_logits) - # masked_logits.scatter_(1, top_i, top_w) - # routing_weights = masked_logits - # experts_out = x.new_zeros((T, H)) - # for e in range(self.gate.num_experts): - # routing_weight = routing_weights[:, e].unsqueeze(-1) - # W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T - # W_d = self.experts[e].down_proj.weight.T - # gate = x @ W_g - # up = x @ W_u - # down = (up * self.experts[e].act_fn(gate)) @ W_d - # experts_out += down * routing_weight - - # shared_expert_output = self.shared_expert(x) - # shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - - # experts_out = experts_out + shared_expert_output - # return experts_out.view(B, S, H), router_logits - - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - B, S, H = hidden_states.shape - T = B * S - x = hidden_states.view(T, H) - act = getattr(self.experts, "act_fn", F.silu) - - prob, top_w, top_i = self.gate(hidden_states) - top_w = top_w.to(x.dtype) - routing_weights = torch.zeros((T, self.gate.num_experts), dtype=x.dtype) - routing_weights.scatter_(1, top_i, top_w) - - # if self.gate.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: - # experts_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) - - # shared_expert_output = self.shared_expert(x) - # shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - # expert_output = experts_out + shared_expert_output - # return expert_output.view(B, S, H) - - experts_out = torch.zeros_like(x, dtype=x.dtype) - # breakpoint() - for e in range(self.gate.num_experts): - routing_weight = routing_weights[:, e].unsqueeze(-1) - - W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] - W_dn_e = self.experts.down_proj[e] # [I, H] - # - gate_up = x @ W_gate_up_e.T # [T, 2I] - - I2 = gate_up.shape[-1] // 2 - gate = gate_up[:, :I2] # [T, I] - up = gate_up[:, I2:] # [T, I] - intermediate = up * act(gate) - down = intermediate @ W_dn_e.T - masked_down = torch.where( - routing_weight > 0, down * routing_weight, torch.zeros_like(experts_out, dtype=down.dtype) + packed_chunk_size=packed_chunk_size, ) - # masked_down = down * routing_weight - experts_out += masked_down + experts_out = torch.einsum("ijk->jk", expert_out) shared_expert_output = self.shared_expert(x) shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * shared_expert_output diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index cbef36f64..e8573159f 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -22,6 +22,7 @@ Qwen3VLMoeTextAttention, Qwen3VLMoeTextConfig, Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextExperts, Qwen3VLMoeTextModel, Qwen3VLMoeTextRotaryEmbedding, Qwen3VLMoeTextSparseMoeBlock, @@ -662,6 +663,27 @@ def _deepstack_process( return hidden_states + (visual_embeds * visual_mask) +class QEffQwen3VLMoeTextExperts(Qwen3VLMoeTextExperts): + def __qeff_init__(self): + hidden_dim = self.hidden_dim + if self.gate_up_proj.shape[1] == hidden_dim: + inter = self.gate_up_proj.shape[-1] // 2 + all_gate, all_up = torch.split(self.gate_up_proj, inter, dim=-1) + else: + inter = self.gate_up_proj.shape[1] // 2 + all_gate, all_up = torch.split(self.gate_up_proj, inter, dim=1) + all_gate = all_gate.transpose(1, 2) + all_up = all_up.transpose(1, 2) + + all_down = self.down_proj + if all_down.shape[1] == hidden_dim: + all_down = all_down.transpose(1, 2) + + self.all_gate = nn.Parameter(all_gate.contiguous().detach(), requires_grad=False) + self.all_up = nn.Parameter(all_up.contiguous().detach(), requires_grad=False) + self.all_down = nn.Parameter(all_down.contiguous().detach(), requires_grad=False) + + class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape @@ -674,7 +696,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w, top_i = torch.topk(router_logits, self.gate.top_k, dim=-1) top_w = F.softmax(top_w, dim=-1, dtype=torch.float) top_w = top_w.to(hidden_states.dtype) - num_experts = getattr(self, "num_experts", self.gate.num_experts) + num_experts = getattr(self.experts, "num_experts", self.gate.num_experts) routing_weights = torch.zeros((T, num_experts), dtype=x.dtype) routing_weights.scatter_(1, top_i, top_w) @@ -683,19 +705,10 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens for e in range(num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) - W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] - W_dn_e = self.experts.down_proj[e] # [I, H] - if W_gate_up_e.shape[0] != H: - W_gate_up_e = W_gate_up_e.transpose(0, 1) - gate_up = x @ W_gate_up_e # [T, 2I] - - I2 = gate_up.shape[-1] // 2 - gate = gate_up[:, :I2] # [T, I] - up = gate_up[:, I2:] # [T, I] + gate = x @ self.experts.all_gate[e] + up = x @ self.experts.all_up[e] intermediate = up * act(gate) - if W_dn_e.shape[0] != I2: - W_dn_e = W_dn_e.transpose(0, 1) - down = intermediate @ W_dn_e + down = intermediate @ self.experts.all_down[e] masked_down = torch.where( routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype) ) # TODO: verify and remove @@ -958,16 +971,15 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = F.softmax(top_w, dim=-1, dtype=torch.float) top_w = top_w.to(x.dtype) idx = top_i.reshape(-1) - w_up = self.experts.gate_up_proj.transpose(1, 2).index_select(0, idx) - w_dn = self.experts.down_proj.transpose(1, 2).index_select(0, idx) + w_gate = self.experts.all_gate[idx] + w_up = self.experts.all_up[idx] + w_dn = self.experts.all_down[idx] top_k = top_i.shape[-1] xk = x.unsqueeze(1).expand(-1, top_k, -1).contiguous() xk = xk.view(-1, 1, H) - gate_up = torch.bmm(xk, w_up) - I2 = gate_up.size(-1) - half = I2 // 2 - gate, up = gate_up[..., :half], gate_up[..., half:] + gate = torch.bmm(xk, w_gate) + up = torch.bmm(xk, w_up) intermediate = up * self.experts.act_fn(gate) experts_out = torch.bmm(intermediate, w_dn) experts_out = experts_out.view(T, top_k, H) * top_w.unsqueeze(-1) diff --git a/dbg.log b/dbg.log new file mode 100644 index 000000000..e69de29bb diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py index 9f8f498b4..1000c3a25 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import requests +import torch import transformers from PIL import Image from qwen_vl_utils import process_vision_info @@ -13,16 +14,15 @@ from QEfficient import QEFFAutoModelForImageTextToText -model_id = "Qwen/Qwen3.5-0.8B" +model_id = "Qwen/Qwen3.5-27B" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 -config.text_config.num_hidden_layers = 2 -config.torch_dtype = "float32" +# config.text_config.num_hidden_layers = 2 qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) @@ -31,11 +31,11 @@ # Enable KV blocking for full-attention layers with 2 KV blocks # To disable KV blocking, comment out the qaic_config line below # Set skip_kv=True to skip future KV blocks during inference (optimization) -qaic_config = {"blocking_mode": "kv", "num_kv_blocks": 2, "skip_kv": True} +qaic_config = {"enable_blocking": True, "blocking_mode": "kv", "num_kv_blocks": 2, "skip_kv": True} enable_blocking = False # By default blocking is false ### use skip_vision=True, if want to run only text, or false ### -skip_vision = False +skip_vision = True BS = 1 PREFILL_SEQ_LEN = 64 @@ -49,8 +49,9 @@ prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - num_devices=1, + num_devices=4, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=False, skip_vision=True, @@ -121,6 +122,7 @@ height=354, width=536, mxfp6_matmul=False, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=False, mos=1, @@ -159,7 +161,7 @@ "role": "user", "content": [ {"type": "image", "image": image}, - {"type": "text", "text": "Describe all the colors seen in the image."}, + {"type": "text", "text": "Describe this image."}, ], }, ] diff --git a/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py b/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py index ffead3420..3b601cb12 100644 --- a/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py +++ b/examples/image_text_to_text/models/qwen3_5/qwen3_5_continous_batching.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import torch import transformers from transformers import AutoConfig, AutoProcessor @@ -23,6 +24,7 @@ attn_implementation="eager", kv_offload=True, config=config, + dtype=torch.float32, continuous_batching=True, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) @@ -40,6 +42,7 @@ height=354, width=536, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=True, aic_enable_depth_first=True, mos=1, diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py index 4a2a15e21..76e8ed694 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_disagg_mode.py @@ -26,13 +26,13 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 config.text_config.num_hidden_layers = 4 -config.torch_dtype = "float16" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config, + dtype=torch.float32, layerwise=LAYERWISE, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) @@ -79,7 +79,7 @@ width=536, num_cores=16, num_devices=1, - mxfp6_matmul=False, + mxfp6_matmul=True, mxint8_kv_cache=False, retain_full_kv=True, split_model_io=True, # This should be used for disagg serving via VLLM @@ -151,11 +151,13 @@ lang_decode_session = QAICInferenceSession(decode_qpc_path.get("lang_decode_qpc_path")) if skip_vision: + text_prompt_2 = "Describe yourself as a large language model, including your purpose, capabilities, and limitations. Explain how you process and generate responses, interact with users, and handle uncertainty, while emphasizing accuracy, safety, and helpfulness in diverse conversations across various topics and domains." + messages = [ { "role": "user", "content": [ - {"type": "text", "text": "Tell me about yourself."}, + {"type": "text", "text": text_prompt_2}, ], }, ] @@ -169,7 +171,7 @@ "role": "user", "content": [ {"type": "image", "image": image}, - {"type": "text", "text": "Describe all the colors seen in the image."}, + {"type": "text", "text": "Describe the image."}, ], }, ] diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py index 435f352b3..cdc5df85e 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import requests +import torch import transformers from PIL import Image from qwen_vl_utils import process_vision_info @@ -19,10 +20,9 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 config.text_config.num_hidden_layers = 4 -config.torch_dtype = "float32" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float32 ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) @@ -31,7 +31,7 @@ # Enable KV blocking for full-attention layers with 2 KV blocks # To disable KV blocking, comment out the qaic_config line below # Set skip_kv=True to skip future KV blocks during inference (optimization) -qaic_config = {"blocking_mode": "kv", "num_kv_blocks": 2, "skip_kv": True} +qaic_config = {"enable_blocking": True, "blocking_mode": "kv", "num_kv_blocks": 4, "skip_kv": True} enable_blocking = False ## By default blocking is false ### use skip_vision=Ture, if want to run only text, or false ### @@ -49,14 +49,15 @@ prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - num_devices=1, + num_devices=4, height=354, width=536, mxfp6_matmul=True, + split_model_io=True, aic_enable_depth_first=True, skip_vision=True, mos=1, - # qaic_config=qaic_config, # Enable KV blocking - comment out to disable + qaic_config=qaic_config, # Enable KV blocking - comment out to disable ) if enable_blocking: @@ -121,10 +122,11 @@ prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - num_devices=1, + num_devices=4, height=354, width=536, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=True, mos=1, diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py index 95ae66d12..97bbb6531 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_continous_batching.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import torch import transformers from transformers import AutoConfig, AutoProcessor @@ -16,13 +17,13 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only config.vision_config.depth = 4 config.text_config.num_hidden_layers = 2 -config.torch_dtype = "float32" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config, + dtype=torch.float32, continuous_batching=True, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) @@ -40,6 +41,7 @@ height=354, width=536, mxfp6_matmul=True, + split_model_io=True, mxint8_kv_cache=False, aic_enable_depth_first=False, mos=1, diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py b/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py index d323ad608..99fe38ead 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/multi_specialization_inference.py @@ -28,7 +28,7 @@ processor = AutoProcessor.from_pretrained(model_id) # use skip_vision=True, if want to run only text -skip_vision = False +skip_vision = True if skip_vision: # Only Text batch_size = 1 diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index f057b01bc..8d4b199a3 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -22,8 +22,8 @@ from QEfficient import QEFFAutoModelForImageTextToText # MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -# MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" -MODEL_ID = "tiny-random/qwen3-vl-moe" +MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# MODEL_ID = "tiny-random/qwen3-vl-moe" def main():