Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,15 @@ def forward(
# position_embeddings = None
all_hidden_states = () if output_hidden_states else None
layer_indices_to_run = kwargs.get("layer_indices_to_run", None)


layer_idx = 0
if QEffQwen3_5MoeTextModel._end == 0:
total_layers = len(self.layers)
end = total_layers
QEffQwen3_5MoeTextModel._end = total_layers
QEffQwen3_5MoeTextModel._total_layers = total_layers
layer_indices_to_run = kwargs.get("layer_indices_to_run", None)

for layer_idx, decoder_layer in enumerate(self.layers):
if layer_idx < start or layer_idx >= end:
continue
Expand Down Expand Up @@ -1498,12 +1506,16 @@ def forward(
batch_index=batch_index,
use_cache=True,
)
logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
if outputs.last_hidden_state.shape[1] > 1:
hidden_states = outputs.last_hidden_state
if QEffQwen3_5MoeTextModel._end == 0:
logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
logits = self.model.lm_head(hidden_states)
else:
hidden_states = outputs.last_hidden_state[:, -1:, :]
logits = hidden_states
if outputs.last_hidden_state.shape[1] > 1:
hidden_states = outputs.last_hidden_state
else:
hidden_states = outputs.last_hidden_state[:, -1:, :]
logits = hidden_states
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return logits, vision_embeds, image_idx, outputs.past_key_values

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ def forward(
layer_idx = 0
start = QEffQwen3VLMoeTextModel._start
end = QEffQwen3VLMoeTextModel._end
if QEffQwen3VLMoeTextModel._end == 0:
total_layers = len(self.layers)
end = total_layers
QEffQwen3VLMoeTextModel._end = total_layers
QEffQwen3VLMoeTextModel._total_layers = total_layers
layer_indices_to_run = kwargs.get("layer_indices_to_run", None)

for layer_idx, decoder_layer in enumerate(self.layers):
Expand Down Expand Up @@ -849,11 +854,18 @@ def forward(
visual_pos_masks=visual_pos_masks,
deepstack_visual_embeds=deepstack_visual_embeds,
)
if outputs.last_hidden_state.shape[1] > 1:
hidden_states = outputs.last_hidden_state
if QEffQwen3VLMoeTextModel._end == 0:
logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[
torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index
]
logits = self.model.lm_head(hidden_states)
else:
hidden_states = outputs.last_hidden_state[:, -1:, :]
logits = hidden_states
if outputs.last_hidden_state.shape[1] > 1:
hidden_states = outputs.last_hidden_state
else:
hidden_states = outputs.last_hidden_state[:, -1:, :]
logits = hidden_states
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values

Expand Down
Loading