2025 Ascend AI Innovation Contest - MindSpore Model Development Challenge (Season S1) - MoE Track - Submission by team 桶桶罐罐 (#131)
Status: Open. chen-luzl wants to merge 3 commits into mindspore-lab:dev from chen-luzl:dev.
# MindNLP MoE Model Inference Optimization Notes

## Optimization Overview

This submission tunes the performance of the Mixture-of-Experts (MoE) modules in the **DeepSeek** and **Qwen2-MoE** models. To address the operator fragmentation and frequent Host-Device synchronization caused by dynamic routing in the MoE architecture, it adopts a **CPU (NumPy) assisted index computation** strategy, which markedly reduces NPU pipeline bubbles and improves inference throughput.
## Core Optimizations

### 1. DeepSeek Model Optimization

**MoE Routing Refactor (CPU Offload)**

- **Problem**: the original logic either ran dynamic-shape operators such as `bincount` and `cumsum` on the NPU, or looped over every expert indiscriminately, wasting compute resources and making graph compilation difficult.
- **Approach**:
  1. **Metadata readback**: use `.asnumpy()` to copy the token-count tensor back to the CPU.
  2. **CPU preprocessing**: compute the cumulative sum (cumsum) in NumPy and select the "active experts" that were actually assigned tokens.
  3. **Targeted compute**: launch NPU kernels only for experts holding valid tokens, skipping idle experts.
  4. **Operator upgrade**: replace the legacy `scatter_reduce_` with `mindspore.mint.scatter_add` for faster aggregation.
<!-- end list -->
```python
# Sketch of the optimized logic
tokens_count_cpu = tokens_count.asnumpy()  # copy metadata back to the CPU
cumulative_sum = tokens_count_cpu.cumsum()

# Build the task list on the CPU, filtering out experts with count == 0
valid_experts = []
start_idx = 0
for i, end_idx in enumerate(cumulative_sum):
    count = end_idx - start_idx
    if count > 0:
        valid_experts.append((i, start_idx, end_idx))
    start_idx = end_idx

# The NPU only executes useful work
for expert_idx, start_idx, end_idx in valid_experts:
    expert_out = expert(expert_tokens)
    expert_cache = mindspore.mint.scatter_add(...)
```
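The CPU-side filtering step can be exercised standalone with plain NumPy. The sketch below uses hypothetical token counts and omits the MindSpore device calls, so it runs on its own:

```python
import numpy as np

# Hypothetical per-expert token counts after routing (experts 1 and 3 idle)
tokens_count_cpu = np.array([4, 0, 7, 0, 2])
cumulative_sum = tokens_count_cpu.cumsum()

# Build (expert_id, start, end) tasks on the CPU, skipping empty experts
valid_experts = []
start = 0
for expert_idx, end in enumerate(cumulative_sum):
    if end - start > 0:
        valid_experts.append((expert_idx, int(start), int(end)))
    start = end

print(valid_experts)  # [(0, 0, 4), (2, 4, 11), (4, 11, 13)]
```

Only experts 0, 2, and 4 produce entries, so the NPU loop later issues exactly three expert kernels instead of five.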
**General Operator Optimizations**

- **RoPE split**: replace the Python slice `x[..., :half]` with `ops.split`; an explicit operator helps the backend plan memory and avoids implicit copies.
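The slice being replaced is the usual rotate-half helper of RoPE. A NumPy illustration (hypothetical head dimension, with `np.split` standing in for `ops.split`):

```python
import numpy as np

# Rotate-half as used in RoPE: split the last dim in two and recombine
x = np.arange(8, dtype=np.float32).reshape(1, 8)  # hypothetical head_dim = 8

# One explicit split instead of two overlapping Python slices
x1, x2 = np.split(x, 2, axis=-1)
rotated = np.concatenate((-x2, x1), axis=-1)

print(rotated[0])  # [-4. -5. -6. -7.  0.  1.  2.  3.]
```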
### 2. Qwen2-MoE Model Optimization

**Batched Sparse-Index Precomputation**

- **Problem**: the original code called `ops.nonzero` inside the per-expert loop. `nonzero` is a synchronizing operator: the NPU must wait for its result to reach the CPU before proceeding, creating a severe **pipeline bubble**.
- **Approach**:
  1. **Batch mask readback**: convert the entire `expert_mask` to a NumPy array.
  2. **One-shot indexing**: call `np.nonzero` once to obtain the index information for all experts.
  3. **Dictionary build**: precompute an `{expert_id: indices}` map on the CPU.
  4. **Non-blocking execution**: inside the loop, call `index_add` with the precomputed indices; the NPU never stalls and can issue kernels back to back.
<!-- end list -->
```python
# Before: synchronizing inside the loop (slow)
# for expert_idx in range(num_experts):
#     idx, top_x = ops.nonzero(expert_mask[expert_idx])  # blocking point

# After: precompute on the CPU (fast)
expert_mask_np = expert_mask.asnumpy()
all_expert_idxs, ... = np.nonzero(expert_mask_np)  # one-shot computation
# ... build the expert_indices dict ...

for expert_idx, (idx, top_x) in expert_indices.items():
    # no blocking; kernels are issued at full speed
    final_hidden_states = final_hidden_states.index_add(...)
```
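The dictionary-building step can be sketched end to end with NumPy alone. The mask shape and routing below are hypothetical (3 experts, `top_k = 1`, 5 tokens):

```python
import numpy as np

# Hypothetical one-hot routing mask: [num_experts, top_k, num_tokens]
expert_mask = np.zeros((3, 1, 5), dtype=np.int64)
for token, expert in enumerate([0, 2, 0, 2, 2]):
    expert_mask[expert, 0, token] = 1

# One np.nonzero call replaces a per-expert ops.nonzero inside the loop
expert_ids, idxs, top_xs = np.nonzero(expert_mask)
expert_indices = {}
for e, i, t in zip(expert_ids, idxs, top_xs):
    idx, top_x = expert_indices.setdefault(int(e), ([], []))
    idx.append(int(i))
    top_x.append(int(t))

print(expert_indices)  # expert 1 received no tokens and is simply absent
```

Because idle experts never appear in the map, the subsequent `index_add` loop only visits experts with work to do.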
## Technical Summary

1. **Eliminate Host-Device synchronization**: moving data-dependent dynamic operators such as `nonzero` and `bincount` to the CPU removes the synchronization stalls from NPU inference.
2. **Sparse to dense**: filtering out idle experts on the CPU keeps the NPU running dense computation, improving compute utilization.
3. **Fast metadata handling in NumPy**: for small metadata such as token counts and indices, processing on the CPU with NumPy is far cheaper than launching many tiny kernels on the NPU.
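The metadata in question is tiny. With hypothetical routing decisions for 13 tokens, both the per-expert counts and a stable sort order are one NumPy call each:

```python
import numpy as np

# Hypothetical routing decisions: target expert for each of 13 tokens
selected_experts = np.array([0, 2, 0, 4, 2, 2, 0, 2, 4, 2, 2, 0, 2])

# Token counts and a stable sort order are small metadata; computing them
# on the CPU avoids launching many micro-kernels on the NPU
tokens_count = np.bincount(selected_experts, minlength=5)
sort_order = np.argsort(selected_experts, kind="stable")

print(tokens_count.tolist())  # [4, 0, 7, 0, 2] -- experts 1 and 3 idle
```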
## Optimization Gains

- **DeepSeek**: wasted computation is removed, and Decode-stage latency drops noticeably.
- **Qwen2-MoE**: in-loop synchronization is eliminated, greatly improving the parallel execution efficiency of the expert layers.
> **Collaborator:** Please add the final score list.
>
> **Author:** Updated.
## Evaluation Results

| Metric | Average Score |
|---------|---------|
| Peak memory score | 100 |
| Prefill latency score | 145.0792 |
| Decode latency score | 308.0527 |
| **Total** | **184.3773** |