Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions 2025-Ascend-Innovation-Contest/S1/MoE/我想好好睡一觉/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#### 1. MoE模块拆分 prefill 和 decode 阶段逻辑

DeepseekMoE模块
```python
def forward(self, hidden_states):
if orig_shape[1] == 1:
y = self.moe_infer_decode(...).view(*orig_shape)
...
else:
# 复用原代码
y = self.moe_infer_prefill(...).view(*orig_shape)
```

decode阶段(单token推理)
```python
@no_grad()
def moe_infer_decode(self, x, flat_expert_indices, flat_expert_weights):
# Decode时单token直接遍历激活专家
expert_cache = ops.zeros_like(x)
for i in range(self.num_experts_per_tok):
expert_id = flat_expert_indices[i].item()
weight = flat_expert_weights[i].item()
expert = self.experts[expert_id]
expert_out = expert(x)
expert_cache += expert_out * weight
return expert_cache
```
直接遍历门控网络选中的`num_experts_per_tok`个专家,逐个计算并加权累加,计算效率最大化。

prefill阶段(长序列推理)
```python
@no_grad()
def moe_infer_prefill(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = ops.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
token_idxs = idxs // self.num_experts_per_tok
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
length = (end_idx - start_idx).item()
exp_token_idx = ops.narrow(token_idxs, 0, start_idx, length)
expert_tokens = F.embedding(exp_token_idx, x)
expert_out = expert(expert_tokens)
expert_out = expert_out.mul(F.embedding(ops.narrow(idxs, 0, start_idx, length), flat_expert_weights))
expert_cache = mindspore.mint.scatter_add(expert_cache, 0, exp_token_idx.view(-1, 1).tile((1, x.shape[-1])), expert_out)
return expert_cache
```
这里直接复用原代码中的逻辑。

---

Qwen2MoeSparseMoeBlock模块
(deepseek模块的实现参考了培训给出的一些示例,基于这个经验,在qwen2_moe模块也进行了拆分获取了收益)

Qwen2-MoE在同一个`forward`函数内通过分支区分阶段(代码写得没那么漂亮,主打一个效率):
```python
def forward(self, hidden_states):
if sequence_length == 1: # decode阶段
# 仅循环活跃专家的精简逻辑
...
else: # prefill阶段
# 遍历所有专家的标准逻辑
...
```

decode阶段(单token):活跃专家筛选+复用prefill核心逻辑
```python
if sequence_length == 1:
expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# 筛选活跃专家(仅处理被选中的专家,减少循环次数)
expert_usage = ops.sum(expert_mask, dim=(1, 2)) # 统计每个专家被选中的次数
active_experts = ops.nonzero(expert_usage > 0, as_tuple=False).squeeze(-1)
# 遍历活跃专家(复用prefill的index_add逻辑)
for expert_idx_tensor in active_experts:
expert_idx = int(expert_idx_tensor.asnumpy().item())
expert_layer = self.experts[expert_idx]
idx, top_x = ops.nonzero(expert_mask[expert_idx], as_tuple=True)
if 0 not in idx.shape:
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

final_hidden_states = final_hidden_states.index_add(
0, top_x.int(), current_hidden_states.to(hidden_states.dtype)
)
```

decode阶段生成与prefill完全一致的`expert_mask`,消除掩码维度/排列差异导致的结果偏差;
仅循环被选中的专家,单token推理循环次数从`num_experts`降至`top_k`。

prefill阶段(长序列):标准掩码+全专家遍历
```python
else:
expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = ops.nonzero(expert_mask[expert_idx], as_tuple=True)
if 0 not in idx.shape:
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states = final_hidden_states.index_add(
0, top_x.int(), current_hidden_states.to(hidden_states.dtype)
)
```
这部分复用原本的MoE部分代码,遍历全部专家。

这样把prefill和decode分开来写,在decode阶段带来了达到秒级的收益。完成这一步后,总分提升到了160+。


#### 2. 使用ops算子替换索引操作

```python
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
x1, x2 = ops.split(x, x.shape[-1] // 2, dim = -1)
return ops.cat((-x2, x1), dim=-1)


...

# cos = cos[position_ids].unsqueeze(unsqueeze_dim)
# sin = sin[position_ids].unsqueeze(unsqueeze_dim)
cos = F.embedding(position_ids, cos).unsqueeze(unsqueeze_dim)
sin = F.embedding(position_ids, sin).unsqueeze(unsqueeze_dim)

...

hidden_states_expand = ops.unsqueeze(hidden_states, 2)
hidden_states = hidden_states_expand.broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim))

...
# self.cos_cached[:seq_len].to(dtype=x.dtype),
# self.sin_cached[:seq_len].to(dtype=x.dtype),
ops.narrow(self.cos_cached, 0, 0, seq_len).to(dtype=x.dtype),
ops.narrow(self.sin_cached, 0, 0, seq_len).to(dtype=x.dtype),

...

attention_mask_expanded = ops.unsqueeze(ops.unsqueeze(attention_mask, dim=1), dim=2)
padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask_expanded

```
使用ops替换掉tensor索引的一些操作能够带来一些细微的收益(几百ms),有的替换甚至没有收益(也可能是那块代码没执行)。

#### 3. 使用StaticCache,并开启JIT优化

使用StaticCache需要先在cache_utils.py文件中修改一下StaticCache类的update函数中的某一部分(被注释掉的是原来的):
```python
else:
# use index_add for mindspore since tensor slice is too slow and no implementation of index_copy
# k_out = ops.index_add(k_out, 2, cache_position.int(), key_states)
# v_out = ops.index_add(v_out, 2, cache_position.int(), value_states)
k_out.index_add_(2, cache_position.int(), key_states)
v_out.index_add_(2, cache_position.int(), value_states)
```

由于默认是DynamicCache,因此需要在utils.py中generate接口处修改一些地方,支持StaticCache的创建使用
```python
else:
model_type = getattr(self.config, 'model_type', '')
supports_cache_position = model_type in ['qwen2_moe']
# print('StaticCache')
if (
hasattr(self, '_supports_static_cache')
and self._supports_static_cache
and not requires_cross_attention_cache
and supports_cache_position
):
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
# print('StaticCache')
model_kwargs[cache_name] = self._get_cache(
cache_implementation="static",
max_batch_size=batch_size,
max_cache_len=max_cache_length,
model_kwargs=model_kwargs,
)
else:
num_hidden_layers = self.config.get_text_config().num_hidden_layers
model_kwargs[cache_name] = (
DynamicCache(num_hidden_layers)
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
)
```
这里StaticCache只限制在qwen2_moe中使用,因为我尝试在deepseek中使用时,最后出现了mismatch,可能是精度上有一些损失。

最后JIT优化使用在了RMSNorm部分,还在Qwen2MoeModel的_update_causal_mask部分加了JIT,都带来了收益。但是不知道为什么使用rms_norm的融合算子会导致mismatch,可能是精度有误差,所以放弃了融合算子,同样flash attention也会有这个问题。
```python
@mindspore.jit
def forward(self, hidden_states):
# if use_pyboost():
# return F.rms_norm(hidden_states, self.weight, self.variance_epsilon)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(mindspore.float32)
variance = ops.mean(hidden_states.pow(2), -1, keepdim=True)
hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

...

@mindspore.jit
def _update_causal_mask(
self,
attention_mask: mindspore.Tensor,
input_tensor: mindspore.Tensor,
cache_position: mindspore.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
```

最后JIT优化牺牲了一些prefill的时延,但是对decode的时延带来了很大的提升,所以最后总分也提升到了280+。

---
#### 最终评测结果

| 评测指标 | 平均得分 |
| :-------: | :----------: |
| 峰值显存 | 100 |
| Prefill时延 | 62.3694 |
| Decode时延 | 696.3557 |
| **总分** | **286.2417** |
Binary file not shown.