diff --git "a/2025-Ascend-Innovation-Contest/S1/MoE/\346\210\221\346\203\263\345\245\275\345\245\275\347\235\241\344\270\200\350\247\211/README.md" "b/2025-Ascend-Innovation-Contest/S1/MoE/\346\210\221\346\203\263\345\245\275\345\245\275\347\235\241\344\270\200\350\247\211/README.md" new file mode 100644 index 00000000..bd83306d --- /dev/null +++ "b/2025-Ascend-Innovation-Contest/S1/MoE/\346\210\221\346\203\263\345\245\275\345\245\275\347\235\241\344\270\200\350\247\211/README.md" @@ -0,0 +1,229 @@ +#### 1. MoE模块拆分 prefill 和 decode 阶段逻辑 + +DeepseekMoE模块 +```python +def forward(self, hidden_states): + if orig_shape[1] == 1: + y = self.moe_infer_decode(...).view(*orig_shape) + ... + else: + # 复用原代码 + y = self.moe_infer_prefill(...).view(*orig_shape) +``` + +decode阶段(单token推理) +```python +@no_grad() +def moe_infer_decode(self, x, flat_expert_indices, flat_expert_weights): + # Decode时单token直接遍历激活专家 + expert_cache = ops.zeros_like(x) + for i in range(self.num_experts_per_tok): + expert_id = flat_expert_indices[i].item() + weight = flat_expert_weights[i].item() + expert = self.experts[expert_id] + expert_out = expert(x) + expert_cache += expert_out * weight + return expert_cache +``` +直接遍历门控网络选中的`num_experts_per_tok`个专家,逐个计算并加权累加,计算效率最大化。 + +prefill阶段(长序列推理) +```python +@no_grad() +def moe_infer_prefill(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = ops.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cumsum(0) + token_idxs = idxs // self.num_experts_per_tok + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + if start_idx == end_idx: + continue + expert = self.experts[i] + length = (end_idx - start_idx).item() + exp_token_idx = ops.narrow(token_idxs, 0, start_idx, length) + expert_tokens = F.embedding(exp_token_idx, x) + expert_out = expert(expert_tokens) + expert_out = expert_out.mul(F.embedding(ops.narrow(idxs, 0, start_idx, length), 
flat_expert_weights)) + expert_cache = mindspore.mint.scatter_add(expert_cache, 0, exp_token_idx.view(-1, 1).tile((1, x.shape[-1])), expert_out) + return expert_cache +``` +这里直接复用原代码中逻辑。 + +--- + +Qwen2MoeSparseMoeBlock模块 +(deepseek模块的实现参考了培训给出的一些示例,基于这个经验,在qwen2_moe模块也进行了拆分获取了收益) + +Qwen2-MoE在同一个`forward`函数内通过分支区分阶段(代码写的没那么漂亮,主打一个效率): +```python +def forward(self, hidden_states): + if sequence_length == 1: # decode阶段 + # 仅循环活跃专家的精简逻辑 + ... + else: # prefill阶段 + # 遍历所有专家的标准逻辑 + ... +``` + +decode阶段(单token):活跃专家筛选+复用prefill核心逻辑 +```python +if sequence_length == 1: + expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + # 筛选活跃专家(仅处理被选中的专家,减少循环次数) + expert_usage = ops.sum(expert_mask, dim=(1, 2)) # 统计每个专家被选中的次数 + active_experts = ops.nonzero(expert_usage > 0, as_tuple=False).squeeze(-1) + # 遍历活跃专家(复用prefill的index_add逻辑) + for expert_idx_tensor in active_experts: + expert_idx = int(expert_idx_tensor.asnumpy().item()) + expert_layer = self.experts[expert_idx] + idx, top_x = ops.nonzero(expert_mask[expert_idx], as_tuple=True) + if 0 not in idx.shape: + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states = final_hidden_states.index_add( + 0, top_x.int(), current_hidden_states.to(hidden_states.dtype) + ) +``` + + decode阶段生成与prefill完全一致的`expert_mask`,消除掩码维度/排列差异导致的结果偏差; + 仅循环被选中的专家,单token推理循环次数从`num_experts`降至`top_k`。 + +prefill阶段(长序列):标准掩码+全专家遍历 +```python +else: + expert_mask = nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = ops.nonzero(expert_mask[expert_idx], as_tuple=True) + if 0 not in idx.shape: + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + 
final_hidden_states = final_hidden_states.index_add( + 0, top_x.int(), current_hidden_states.to(hidden_states.dtype) + ) +``` +这部分复用原本的MoE部分代码,遍历全部专家。 + +这样把prefill和decode分开来写,在decode阶段带来了达到秒级的收益。这样做完总分提升到了160+。 + + +#### 2. 使用ops算子替换索引操作 + +```python +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + # x1 = x[..., : x.shape[-1] // 2] + # x2 = x[..., x.shape[-1] // 2 :] + x1, x2 = ops.split(x, x.shape[-1] // 2, dim = -1) + return ops.cat((-x2, x1), dim=-1) + + +... + + # cos = cos[position_ids].unsqueeze(unsqueeze_dim) + # sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = F.embedding(position_ids, cos).unsqueeze(unsqueeze_dim) + sin = F.embedding(position_ids, sin).unsqueeze(unsqueeze_dim) + +... + + hidden_states_expand = ops.unsqueeze(hidden_states, 2) + hidden_states = hidden_states_expand.broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim)) + +... + # self.cos_cached[:seq_len].to(dtype=x.dtype), + # self.sin_cached[:seq_len].to(dtype=x.dtype), + ops.narrow(self.cos_cached, 0, 0, seq_len).to(dtype=x.dtype), + ops.narrow(self.sin_cached, 0, 0, seq_len).to(dtype=x.dtype), + +... + + attention_mask_expanded = ops.unsqueeze(ops.unsqueeze(attention_mask, dim=1), dim=2) + padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask_expanded + +``` +使用ops替换掉tensor索引的一些操作能够带来一些细微的收益(几百ms),有的替换甚至没有收益(也可能是那块代码没执行)。 + +#### 3. 
使用StaticCache,并开启JIT优化 + +使用StaticCache需要先在cache_utils.py文件中修改一下StaticCache类的update函数中的某一部分(被注释掉的是原来的): +```python + else: + # use index_add for mindspore since tensor slice is too slow and no implementation of index_copy + # k_out = ops.index_add(k_out, 2, cache_position.int(), key_states) + # v_out = ops.index_add(v_out, 2, cache_position.int(), value_states) + k_out.index_add_(2, cache_position.int(), key_states) + v_out.index_add_(2, cache_position.int(), value_states) +``` + +由于默认是DynamicCache,因此需要在utils.py中generate接口处修改一些地方,支持StaticCache的创建使用 +```python +else: + model_type = getattr(self.config, 'model_type', '') + supports_cache_position = model_type in ['qwen2_moe'] + # print('StaticCache') + if ( + hasattr(self, '_supports_static_cache') + and self._supports_static_cache + and not requires_cross_attention_cache + and supports_cache_position + ): + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + # print('StaticCache') + model_kwargs[cache_name] = self._get_cache( + cache_implementation="static", + max_batch_size=batch_size, + max_cache_len=max_cache_length, + model_kwargs=model_kwargs, + ) + else: + num_hidden_layers = self.config.get_text_config().num_hidden_layers + model_kwargs[cache_name] = ( + DynamicCache(num_hidden_layers) + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers)) + ) +``` +这里StaticCache只限制了在qwen2_moe中使用,因为我尝试在deepseek中使用最后好像出现了mismatch,可能精度上有一些损失。 + +最后JIT优化使用在了RMSNorm部分,还在Qwen2MoeModel的_update_causal_mask部分加了JIT,都带来了收益。但是不知道为什么使用rms_norm的融合算子会导致mismatch,可能是精度有误差,所以放弃了融合算子,同样flash attention也会有这个问题。 +```python +@mindspore.jit +def forward(self, hidden_states): + # if use_pyboost(): + # return F.rms_norm(hidden_states, self.weight, self.variance_epsilon) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mindspore.float32) + variance = 
ops.mean(hidden_states.pow(2), -1, keepdim=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +... + +@mindspore.jit +def _update_causal_mask( + self, + attention_mask: mindspore.Tensor, + input_tensor: mindspore.Tensor, + cache_position: mindspore.Tensor, + past_key_values: Cache, + output_attentions: bool, +): +``` + +最后JIT优化牺牲了一些prefill的时延,但是对decode的时延带来了很大的提升,所以最后总分也提升到了280+。 + +--- +#### 最终评测结果 + +| 评测指标 | 平均得分 | +| :-------: | :----------: | +| 峰值显存 | 100 | +| Prefill时延 | 62.3694 | +| Decode时延 | 696.3557 | +| **总分** | **286.2417** | diff --git "a/2025-Ascend-Innovation-Contest/S1/MoE/\346\210\221\346\203\263\345\245\275\345\245\275\347\235\241\344\270\200\350\247\211/patches.zip" "b/2025-Ascend-Innovation-Contest/S1/MoE/\346\210\221\346\203\263\345\245\275\345\245\275\347\235\241\344\270\200\350\247\211/patches.zip" new file mode 100644 index 00000000..cd92eb43 Binary files /dev/null and "b/2025-Ascend-Innovation-Contest/S1/MoE/\346\210\221\346\203\263\345\245\275\345\245\275\347\235\241\344\270\200\350\247\211/patches.zip" differ