[None][feat] Optimize nemotron-h from python level#13032
[None][feat] Optimize nemotron-h from python level#13032Wanli-Jiang wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
* Enable more c++ routing combinations. * Update mamba tensor operations. Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
db95bb7 to
f9beed0
Compare
📝 WalkthroughWalkthroughTwo changes optimize tensor-handling logic: one simplifies routing eligibility conditions in a fused MOE implementation by removing an edge-case constraint, and another optimizes the Mamba2 decode path by pre-computing and caching expanded tensors rather than recomputing them per call. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Around line 243-251: post_load_weights currently assigns plain tensor
attributes (_A_expanded, _dt_bias_expanded, _D_expanded) which are not moved
with module.to()/cuda(); replace those plain attributes with registered buffers
using self.register_buffer("<name>", tensor, persistent=False) so they follow
module device/dtype moves—i.e., register _A_expanded (ensure
.to(dtype=torch.float32) as before), _dt_bias_expanded, and _D_expanded after
creating them with repeat, using the same names so existing code that references
self._A_expanded, self._dt_bias_expanded, and self._D_expanded continues to
work.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: f424cb9b-0111-4ba8-b308-63bc342fce7d
📒 Files selected for processing (2)
tensorrt_llm/_torch/modules/fused_moe/routing.pytensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
| # Pre-expand A, D, dt_bias for the decode path. | ||
| self._A_expanded = repeat(self.A, | ||
| "h -> h p n", | ||
| p=self.head_dim, | ||
| n=self.d_state).to(dtype=torch.float32) | ||
| self._dt_bias_expanded = repeat(self.dt_bias, | ||
| "h -> h p", | ||
| p=self.head_dim) | ||
| self._D_expanded = repeat(self.D, "h -> h p", p=self.head_dim) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "[1] Verify where expanded caches are created/used and whether buffers are used"
rg -n --type=py '_A_expanded|_dt_bias_expanded|_D_expanded|register_buffer' -C2
echo "[2] Verify all post_load_weights call sites"
rg -n --type=py '\bpost_load_weights\s*\(' -C2
echo "[3] Inspect potential ordering of post_load_weights vs device moves"
rg -n --type=py 'Mamba2Mixer\s*\(|\.to\s*\(|\.cuda\s*\(' -C2Repository: NVIDIA/TensorRT-LLM
Length of output: 50377
🏁 Script executed:
#!/bin/bash
# Get the exact code at the lines mentioned in the review
echo "=== Lines 240-260 (cache creation area) ==="
sed -n '240,260p' tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
echo ""
echo "=== Lines 465-480 (cache usage area) ==="
sed -n '465,480p' tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
echo ""
echo "=== Check for .to() or .cuda() calls in __init__ or forward ==="
grep -n "def __init__\|def forward\|def post_load_weights\|\.to(\|\.cuda(" tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py | head -30
echo ""
echo "=== Check if any buffers are already registered ==="
grep -n "register_buffer" tensorrt_llm/_torch/modules/mamba/mamba2_mixer.pyRepository: NVIDIA/TensorRT-LLM
Length of output: 2113
🏁 Script executed:
#!/bin/bash
echo "=== Find all calls to post_load_weights() ==="
rg -n "\.post_load_weights\(\)" --type=py -A3 -B3
echo ""
echo "=== Check if Mamba2Mixer is instantiated and then moved to device ==="
rg -n "Mamba2Mixer" --type=py tensorrt_llm/ | head -20
echo ""
echo "=== Check initialization of Mamba2Mixer and __init__ signature ==="
sed -n '50,120p' tensorrt_llm/_torch/modules/mamba/mamba2_mixer.pyRepository: NVIDIA/TensorRT-LLM
Length of output: 27744
Cache tensors as buffers to ensure device consistency with module.
The plain tensor attributes _A_expanded, _dt_bias_expanded, and _D_expanded created in post_load_weights() at lines 243–251 are not managed by PyTorch's module.to()/cuda() semantics. If the module is moved to device before post_load_weights() is called, these cached tensors remain on their original device, causing a device mismatch when consumed in the decode path at lines 473–475.
Use register_buffer(persistent=False) instead to ensure device consistency.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` around lines 243 - 251,
post_load_weights currently assigns plain tensor attributes (_A_expanded,
_dt_bias_expanded, _D_expanded) which are not moved with module.to()/cuda();
replace those plain attributes with registered buffers using
self.register_buffer("<name>", tensor, persistent=False) so they follow module
device/dtype moves—i.e., register _A_expanded (ensure .to(dtype=torch.float32)
as before), _dt_bias_expanded, and _D_expanded after creating them with repeat,
using the same names so existing code that references self._A_expanded,
self._dt_bias_expanded, and self._D_expanded continues to work.
|
/bot run --disable-fail-fast |
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.