feat(transformers): add longcat_flash model #1443
iugoood wants to merge 1 commit into mindspore-lab:master
Conversation
Code Review
This pull request introduces the longcat_flash model. The implementation is comprehensive, but there are a couple of significant performance concerns in the model's implementation that should be addressed, particularly an inefficient loop in the Mixture of Experts (MoE) layer and a suboptimal implementation of rotary position embeddings. Additionally, a bug in the new test file will prevent the tests from running successfully. My review provides specific feedback on these points.
```python
input_mask = None
if self.use_input_mask:
    input_mask = np.tril(np.ones_like(self.batch_size, self.seq_length))
```
The use of np.ones_like here is incorrect. np.ones_like expects an array-like object as its first argument to determine the shape and dtype, but self.batch_size is an integer. This will raise a TypeError. You should use np.ones((self.batch_size, self.seq_length)) instead to create an array of the desired shape.
Suggested change:

```diff
- input_mask = np.tril(np.ones_like(self.batch_size, self.seq_length))
+ input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
```
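For context, the corrected call builds the intended lower-triangular mask. A quick standalone NumPy check (illustrative sizes, not the model's actual config):

```python
import numpy as np

batch_size, seq_length = 2, 4  # illustrative sizes

# np.ones takes an explicit shape tuple; np.ones_like would instead treat
# its first argument as a prototype array and the second as a dtype.
input_mask = np.tril(np.ones((batch_size, seq_length)))
# Row i keeps positions 0..i:
#   [[1, 0, 0, 0],
#    [1, 1, 0, 0]]
```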
```python
for expert_idx in range(len(self.experts)):
    expert = self.experts[expert_idx]
    mask = expert_mask[expert_idx]
    token_indices, weight_indices = mindspore.mint.where(mask)

    if token_indices.numel() > 0:
        expert_weights = topk_weights[token_indices, weight_indices]
        expert_input = hidden_states[token_indices]
        expert_output = expert(expert_input)
        weighted_output = expert_output * expert_weights.unsqueeze(-1)
        final_hidden_states.index_add_(0, token_indices, weighted_output)
```
The for loop over experts in the moe method is inefficient and will be a significant performance bottleneck, especially for models with a large number of experts. This should be vectorized to process experts in a batch. A common approach is to use batched matrix multiplication or similar techniques. The docstring for this method already calls this out as needing optimization.
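To illustrate one common batching strategy, here is a minimal NumPy sketch of "dense dispatch": all tokens are pushed through all experts in one batched einsum, and a dense combine matrix built from the top-k routing mixes the outputs. The two-layer ReLU experts and all shapes here are hypothetical stand-ins, not the LongCat expert layers or MindSpore APIs:

```python
import numpy as np

def moe_dense_dispatch(hidden_states, w1, w2, topk_ids, topk_weights):
    """Loop-free MoE combine via batched einsum (dense dispatch).

    Trades extra FLOPs (every expert runs on every token) for the
    removal of the Python-level loop over experts.

    Hypothetical shapes:
      hidden_states: (T, D)   w1: (E, D, F)   w2: (E, F, D)
      topk_ids:      (T, K)   topk_weights:   (T, K)
    """
    T, _ = hidden_states.shape
    E = w1.shape[0]
    # All experts on all tokens: (E, T, F) -> (E, T, D)
    h = np.maximum(np.einsum('td,edf->etf', hidden_states, w1), 0.0)
    expert_out = np.einsum('etf,efd->etd', h, w2)
    # Scatter the top-k routing weights into a dense (T, E) combine matrix.
    combine = np.zeros((T, E), dtype=hidden_states.dtype)
    np.put_along_axis(combine, topk_ids, topk_weights, axis=1)
    # Weighted mix over the expert axis: (T, D)
    return np.einsum('te,etd->td', combine, expert_out)
```

For very large expert counts, a gather/scatter (sorted-dispatch) kernel is usually preferred over dense dispatch, since the latter pays for every expert on every token.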
```python
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
```
This implementation of applying rotary position embeddings is inefficient due to multiple view, transpose, and reshape operations, as noted in the function's docstring. These operations can be computationally expensive and should be refactored for better performance, for instance by using a more direct computation method.
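As one illustration of a more direct computation, the widely used "rotate_half" formulation applies the rotary embedding with a single split/concat instead of chained view/transpose/reshape ops. This NumPy sketch assumes a non-interleaved half-split head layout; the actual LongCat layout may need a one-time permutation first:

```python
import numpy as np

def rotate_half(x):
    # Split the head dim into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

def apply_rope(x, cos, sin):
    """Rotary position embedding in the half-split layout.

    x:   (b, h, s, d) queries or keys
    cos: (s, d) cosines, each frequency repeated across both halves
    sin: (s, d) sines, likewise
    """
    return x * cos + rotate_half(x) * sin
```

Elementwise, each pair (x_i, x_{i+d/2}) is rotated by its position-dependent angle, so vector norms are preserved.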
Force-pushed from ffbf152 to b4fb9f0
Force-pushed from 5aefb5d to 1297b37
Add
1. add longcat_flash model
2. add UT

Notes
MoE model with 560B parameters; not yet validated against real weights.