Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions auto_round/modeling/fused_moe/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

# convert router indices into OHE list
# reshape to be (num_experts, top_k, batch_size * sequence_length)
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)

for expert_idx, expert_layer in enumerate(self.experts):
idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))

if self.calibrate_all_experts:
with torch.no_grad():
Comment thread
WeiweiZhang1 marked this conversation as resolved.
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
if self.calibrate_all_experts:
for expert_idx, expert_layer in enumerate(self.experts):
idx, token_idx = torch.where(expert_mask[expert_idx])
expert_out = expert_layer(hidden_states)[token_idx]
else:
expert_out = expert_layer(hidden_states[token_idx])

if len(token_idx) > 0:
# if there are tokens meant for this expert, further scale the expert
# output by the score
if len(token_idx) > 0:
weighted_output = expert_out * routing_weights[token_idx, idx, None]
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
else:
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
Comment thread
WeiweiZhang1 marked this conversation as resolved.
idx, token_idx = torch.where(expert_mask[expert_idx])
expert_out = self.experts[expert_idx](hidden_states[token_idx])
weighted_output = expert_out * routing_weights[token_idx, idx, None]
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
next_states = next_states.reshape(batch_size, sequence_length, hidden_dim)
Expand Down