Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeAttention,
Qwen3_5MoeDecoderLayer,
Qwen3_5MoeExperts,
Qwen3_5MoeForCausalLM,
Qwen3_5MoeForConditionalGeneration,
Qwen3_5MoeGatedDeltaNet,
Expand Down Expand Up @@ -262,6 +263,7 @@
Qwen3VLMoeModel,
Qwen3VLMoeTextAttention,
Qwen3VLMoeTextDecoderLayer,
Qwen3VLMoeTextExperts,
Qwen3VLMoeTextModel,
Qwen3VLMoeTextRMSNorm,
Qwen3VLMoeTextRotaryEmbedding,
Expand Down Expand Up @@ -559,6 +561,7 @@
QEffPrefillChunkedQwen3_5MoeSparseMoeBlock,
QEffQwen3_5MoeAttention,
QEffQwen3_5MoeDecoderLayer,
QEffQwen3_5MoeExperts,
QEffQwen3_5MoeForCausalLM,
QEffQwen3_5MoeForConditionalGeneration,
QEffQwen3_5MoeGatedDeltaNet,
Expand Down Expand Up @@ -595,6 +598,7 @@
QEffQwen3VLMoeModel,
QEffQwen3VLMoeTextAttention,
QEffQwen3VLMoeTextDecoderLayer,
QEffQwen3VLMoeTextExperts,
QEffQwen3VLMoeTextModel,
QEffQwen3VLMoeTextRotaryEmbedding,
QEffQwen3VLMoeTextSparseMoeBlock,
Expand Down Expand Up @@ -743,6 +747,7 @@ class KVCacheTransform(ModuleMappingTransform):
Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer,
Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention,
Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel,
Qwen3VLMoeTextExperts: QEffQwen3VLMoeTextExperts,
Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel,
Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock,
Qwen3VLMoeTextRotaryEmbedding: QEffQwen3VLMoeTextRotaryEmbedding,
Expand Down Expand Up @@ -868,6 +873,7 @@ class KVCacheTransform(ModuleMappingTransform):
Qwen3_5MoeVisionAttention: QEffQwen3_5MoeVisionAttention,
Qwen3_5MoeVisionModel: QEffQwen3_5MoeVisionModel,
Qwen3_5MoeTopKRouter: QEffQwen3_5MoeTopKRouter,
Qwen3_5MoeExperts: QEffQwen3_5MoeExperts,
# Qwen2.5 VL
Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
Qwen2_5_VLModel: QEffQwen2_5_VLModel,
Expand Down
143 changes: 97 additions & 46 deletions QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
from QEfficient.utils.logging_utils import logger

QWEN3_5_ROPE_CACHE_EXPORT_CAP = 76800


class QEffQwen3_5GatedDeltaNetCustomRMSNormAIC(nn.Module):
"""
Expand All @@ -64,7 +66,7 @@ def forward(self, hidden_states, gate):
CustomRMSNormFunc.apply(
hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps
)
) * F.silu(gate.to(torch.float32))
) * F.silu(gate.to(gate.dtype))


class QEffQwen3_5DynamicCache(Cache):
Expand Down Expand Up @@ -215,8 +217,9 @@ class QEffQwen3_5TextRotaryEmbedding(Qwen3_5TextRotaryEmbedding):

def __init__(self, config, device=None):
super().__init__(config=config, device=device)
cached_seq_len = min(int(self.original_max_seq_len), QWEN3_5_ROPE_CACHE_EXPORT_CAP)
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len,
seq_len=cached_seq_len,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)
Expand Down Expand Up @@ -265,7 +268,18 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section):
return freqs_t


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section):
invalid_pos_mask = position_ids < 0
safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids)
flat_pos = safe_position_ids.reshape(-1)
cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1])
sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1])
cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1)
sin = qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1)
return cos, sin


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, mrope_section=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

Explanation:
Expand Down Expand Up @@ -298,14 +312,15 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""

cos = cos[position_ids]
sin = sin[position_ids]
if position_ids is not None:
cos = cos[position_ids]
sin = sin[position_ids]

cos = qeff_apply_interleaved_mrope(cos, mrope_section)
sin = qeff_apply_interleaved_mrope(sin, mrope_section)
cos = qeff_apply_interleaved_mrope(cos, mrope_section)
sin = qeff_apply_interleaved_mrope(sin, mrope_section)

cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)

# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
Expand Down Expand Up @@ -338,7 +353,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down Expand Up @@ -379,12 +394,13 @@ class QEffQwen3_5Attention(Qwen3_5Attention):
"""

def __qeff_init__(self):
self.rotary_emb = QEffQwen3_5TextRotaryEmbedding(config=self.config)
# RoPE tensors are prepared once in the text model and passed down.
return

def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[QEffQwen3_5DynamicCache] = None,
position_ids: Optional[torch.LongTensor] = None,
Expand All @@ -405,14 +421,10 @@ def forward(
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
kv_seq_len = past_key_values.get_seq_length(self.layer_idx, cache_position)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids[1:], self.rotary_emb.mrope_section
)
if position_embeddings is None:
raise ValueError("`position_embeddings` must be provided for QEffQwen3_5Attention.")
cos, sin = position_embeddings
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin)

past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig())
Expand Down Expand Up @@ -530,7 +542,7 @@ def torch_chunk_gated_delta_rule_qeff(
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
x.transpose(1, 2).contiguous().to(query.dtype) for x in (query, key, value, beta, g)
]

mask = (position_ids[0] != -1).unsqueeze(1)
Expand Down Expand Up @@ -587,11 +599,11 @@ def torch_chunk_gated_delta_rule_qeff(
decay_mask = decay_mask * (~mask_strict).float() # ensure upper is zero

attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
# for i in range(1, chunk_size):
# row = attn[..., i, :i].clone()
# sub = attn[..., :i, :i].clone()
# attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
# attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)

## Approximation code ##
# A = attn
Expand Down Expand Up @@ -619,18 +631,43 @@ def torch_chunk_gated_delta_rule_qeff(

## Horners method

# A = attn.masked_fill(mask, 0)
A = attn.masked_fill(mask, 0)
acc_dtype = torch.float32
A64 = A.to(acc_dtype)
I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size)
strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size)

K = chunk_size - 1
S64 = I64.clone()
for _ in range(K):
S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0)

attn = S64

# Newton-Schulz
# Eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
# L = attn.masked_fill(mask, 0)

# X = Eye
# for _ in range(int(math.log2(chunk_size)) + 2):
# R = Eye - (Eye - L) @ X
# X = X + X @ R

# attn = X

# Newton-Schulz updated
# acc_dtype = torch.float32
# A64 = A.to(acc_dtype)
# I64 = torch.eye(chunk_size, device=attn.device, dtype=acc_dtype).view(1, 1, 1, chunk_size, chunk_size)
# strict_lower = (~mask).view(1, 1, 1, chunk_size, chunk_size)
# L64 = attn.masked_fill(mask, 0).to(acc_dtype)

# K = chunk_size - 1
# S64 = I64.clone()
# for _ in range(K):
# S64 = I64 + (A64 @ S64).masked_fill(~strict_lower, 0)
# # X0 = I
# Xj = I64.clone()
# for _ in range(chunk_size - 1):
# Rj = I64 - ((I64 - L64) @ Xj)
# Xj = Xj + (Xj @ Rj).masked_fill(~strict_lower, 0)
# Xj = Xj.masked_fill(~strict_lower, 0) + I64

# attn = S64
# attn = Xj.to(attn.dtype)

value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
Expand Down Expand Up @@ -919,6 +956,19 @@ def forward(


class QEffQwen3_5TextModel(Qwen3_5TextModel):
def __qeff_init__(self):
self.rotary_emb = QEffQwen3_5TextRotaryEmbedding(config=self.config)
# Export-only cap to avoid serializing oversized RoPE tables that are
# unreachable for deployed context lengths (e.g. 4K/8K/16K). This keeps
# non-layerwise QPC size aligned with layerwise exports.
rope_rows = min(int(self.rotary_emb.sin_cached.shape[0]), QWEN3_5_ROPE_CACHE_EXPORT_CAP)
self.sin_cached = torch.nn.Parameter(
(self.rotary_emb.sin_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous()
)
self.cos_cached = torch.nn.Parameter(
(self.rotary_emb.cos_cached[:rope_rows] * self.rotary_emb.attention_scaling).contiguous()
)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down Expand Up @@ -968,8 +1018,9 @@ def forward(

hidden_states = inputs_embeds

position_embeddings = self.rotary_emb(hidden_states, position_ids[1:])
# position_embeddings = None
mrope_section = self.config.rope_parameters.get("mrope_section", [11, 11, 10])
cos, sin = qeff_prepare_mrope_cos_sin(self.cos_cached, self.sin_cached, position_ids[1:], mrope_section)
position_embeddings = (cos, sin)
all_hidden_states = () if output_hidden_states else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
Expand Down Expand Up @@ -1047,8 +1098,8 @@ def get_onnx_retained_state_specs(
if layer_type == "full_attention":
layer_names = [f"past_key.{layer_idx}", f"past_value.{layer_idx}"]
layer_tensors = [
torch.zeros(tuple(kv_cache_shape), dtype=torch.float32),
torch.zeros(tuple(kv_cache_shape), dtype=torch.float32),
torch.zeros(tuple(kv_cache_shape), dtype=self.model.config.dtype),
torch.zeros(tuple(kv_cache_shape), dtype=self.model.config.dtype),
]
layer_axes = [
{0: batch_axis_name, 2: "ctx_len"},
Expand All @@ -1060,8 +1111,8 @@ def get_onnx_retained_state_specs(
recurrent_shape = (batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim)
layer_names = [f"conv_state.{layer_idx}", f"recurrent_state.{layer_idx}"]
layer_tensors = [
torch.zeros(conv_shape, dtype=torch.float32),
torch.zeros(recurrent_shape, dtype=torch.float32),
torch.zeros(conv_shape, dtype=self.model.config.dtype),
torch.zeros(recurrent_shape, dtype=self.model.config.dtype),
]
layer_axes = [{0: batch_axis_name}, {0: batch_axis_name}]

Expand Down Expand Up @@ -1362,7 +1413,7 @@ def forward(
col_mask = (cols >= start) & (cols < end)
block_mask = row_mask & col_mask

final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32)
final_mask = torch.ones((seq_len, seq_len), dtype=self.config.dtype)
final_mask[block_mask.any(dim=0)] = 0
final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask)
attention_mask[0] = final_mask
Expand Down Expand Up @@ -1758,10 +1809,10 @@ def get_dummy_inputs(

vision_inputs = {}
lang_inputs = {}
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=self.model.config.dtype)
vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64)
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.model.config.dtype)
lang_inputs["position_ids"] = (
(
torch.arange(dummy_seq_len, dtype=torch.int64)
Expand All @@ -1788,13 +1839,13 @@ def get_dummy_inputs(
for i in range(self.model.config.text_config.num_hidden_layers):
if self.model.config.text_config.layer_types[i] == "full_attention":
for kv in ["key", "value"]:
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=self.model.config.dtype))
else:
layer = self.model.language_model.layers[i].linear_attn
conv_shape = (linear_batch_size, layer.conv_dim, layer.conv_kernel_size)
recurrent_shape = (linear_batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim)
lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=torch.float32))
lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=torch.float32))
lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=self.model.config.dtype))
lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=self.model.config.dtype))

#
if continuous_batching:
Expand Down
Loading
Loading