From 38143e4f35793584c7d3d3000a1bdca4344f63f4 Mon Sep 17 00:00:00 2001 From: vtirumal Date: Fri, 5 Jun 2026 14:59:47 +0530 Subject: [PATCH 1/3] Fix for logits issue in Qwen3 VL Signed-off-by: vtirumal --- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 4a6259bf8..dfa0942be 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -578,6 +578,11 @@ def forward( layer_idx = 0 start = QEffQwen3VLMoeTextModel._start end = QEffQwen3VLMoeTextModel._end + if QEffQwen3VLMoeTextModel._end == 0: + total_layers = len(self.layers) + end = total_layers + QEffQwen3VLMoeTextModel._end = total_layers + QEffQwen3VLMoeTextModel._total_layers = total_layers layer_indices_to_run = kwargs.get("layer_indices_to_run", None) for layer_idx, decoder_layer in enumerate(self.layers): @@ -849,11 +854,18 @@ def forward( visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, ) - if outputs.last_hidden_state.shape[1] > 1: - hidden_states = outputs.last_hidden_state + if QEffQwen3VLMoeTextModel._end == QEffQwen3VLMoeTextModel._total_layers: + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[ + torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index + ] + logits = self.model.lm_head(hidden_states) else: - hidden_states = outputs.last_hidden_state[:, -1:, :] - logits = hidden_states + if outputs.last_hidden_state.shape[1] > 1: + hidden_states = outputs.last_hidden_state + else: + hidden_states = outputs.last_hidden_state[:, -1:, :] + logits = hidden_states image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return logits, vision_embeds, deepstack_features, image_idx, 
outputs.past_key_values From 5d1847afc651a1a6ee1d35e09ef7795ef1538499 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Sat, 6 Jun 2026 09:42:21 +0530 Subject: [PATCH 2/3] Added minor fix Signed-off-by: Abhishek Kumar Singh --- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 24 ++++++++++++++----- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index c564b644a..e55c1a369 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1013,7 +1013,15 @@ def forward( # position_embeddings = None all_hidden_states = () if output_hidden_states else None layer_indices_to_run = kwargs.get("layer_indices_to_run", None) - + + layer_idx = 0 + if QEffQwen3_5MoeTextModel._end == 0: + total_layers = len(self.layers) + end = total_layers + QEffQwen3_5MoeTextModel._end = total_layers + QEffQwen3_5MoeTextModel._total_layers = total_layers + layer_indices_to_run = kwargs.get("layer_indices_to_run", None) + for layer_idx, decoder_layer in enumerate(self.layers): if layer_idx < start or layer_idx >= end: continue @@ -1498,12 +1506,16 @@ def forward( batch_index=batch_index, use_cache=True, ) - logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) - if outputs.last_hidden_state.shape[1] > 1: - hidden_states = outputs.last_hidden_state + if QEffQwen3VLMoeTextModel._end == 0: + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) else: - hidden_states = outputs.last_hidden_state[:, -1:, :] - logits = hidden_states + if outputs.last_hidden_state.shape[1] > 1: + hidden_states = outputs.last_hidden_state + else: + 
hidden_states = outputs.last_hidden_state[:, -1:, :] + logits = hidden_states image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return logits, vision_embeds, image_idx, outputs.past_key_values diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index dfa0942be..7d57050d9 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -854,7 +854,7 @@ def forward( visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, ) - if QEffQwen3VLMoeTextModel._end == QEffQwen3VLMoeTextModel._total_layers: + if QEffQwen3VLMoeTextModel._end == 0: logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs.last_hidden_state[ torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index From b0e2f806aafb2336da6fd87802573ce88c32af0b Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Singh Date: Sat, 6 Jun 2026 10:00:41 +0530 Subject: [PATCH 3/3] Update modeling_qwen3_5_moe.py Signed-off-by: Abhishek Kumar Singh --- .../transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index e55c1a369..a267c525f 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1506,7 +1506,7 @@ def forward( batch_index=batch_index, use_cache=True, ) - if QEffQwen3VLMoeTextModel._end == 0: + if QEffQwen3_5MoeTextModel._end == 0: logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] logits = self.model.lm_head(hidden_states)