
Commit db95bb7

[None][feat] Optimize nemotron-h from python level

* Enable more c++ routing combinations.
* Update mamba tensor operations.

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>

1 parent e79eac2

2 files changed: 19 additions & 16 deletions


tensorrt_llm/_torch/modules/fused_moe/routing.py

Lines changed: 3 additions & 3 deletions
@@ -264,9 +264,9 @@ def noaux_tc(self, logits, e_score_correction_bias):
                 "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
             )
             self.is_fused = False
-        elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22)
-              or (self.topk_group == 1 and self.top_k != 22)):
-            # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
+        elif num_experts > 512 or (self.top_k > 8 and self.top_k != 22):
+            # The fused noaux_tc_op kernel supports n_group==1 with top_k<=8
+            # or top_k==22, and num_experts<=512.
             if self.is_fused:
                 warnings.warn(
                     "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 16 additions & 13 deletions
@@ -240,6 +240,16 @@ def post_load_weights(self):
                 and self.norm.nvfp4_scale is None):
             self._try_attach_nvfp4_scale()

+        # Pre-expand A, D, dt_bias for the decode path.
+        self._A_expanded = repeat(self.A,
+                                  "h -> h p n",
+                                  p=self.head_dim,
+                                  n=self.d_state).to(dtype=torch.float32)
+        self._dt_bias_expanded = repeat(self.dt_bias,
+                                        "h -> h p",
+                                        p=self.head_dim)
+        self._D_expanded = repeat(self.D, "h -> h p", p=self.head_dim)
+
     def _try_attach_nvfp4_scale(self):
         """Attach input_scale from out_proj to norm for fused RMSNorm+Quant."""

@@ -454,22 +464,15 @@ def convert_dt():
                 ],
                 dim=-1,
             )
-            # Use .contiguous() to ensure proper 128-byte alignment required by
-            # flashinfer's selective_state_update kernel. x_d, B_d, C_d are views
-            # into sliced tensors which may not be 128-byte aligned.
-            x_d = rearrange(x_d, "b (h p) -> b h p",
-                            p=self.head_dim).contiguous()
+            x_d = rearrange(x_d, "b (h p) -> b h p", p=self.head_dim)
             dt_d = repeat(dt_d, "b h -> b h p", p=self.head_dim)
-            B_d = rearrange(B_d, "b (g n) -> b g n",
-                            g=self.tp_ngroups).contiguous()
-            C_d = rearrange(C_d, "b (g n) -> b g n",
-                            g=self.tp_ngroups).contiguous()
+            B_d = rearrange(B_d, "b (g n) -> b g n", g=self.tp_ngroups)
+            C_d = rearrange(C_d, "b (g n) -> b g n", g=self.tp_ngroups)
             z_d = rearrange(z_d, "b (h p) -> b h p", p=self.head_dim)

-            A = repeat(self.A, "h -> h p n", p=self.head_dim,
-                       n=self.d_state).to(dtype=torch.float32)
-            dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim)
-            D = repeat(self.D, "h -> h p", p=self.head_dim)
+            A = self._A_expanded
+            dt_bias = self._dt_bias_expanded
+            D = self._D_expanded
             if is_target_verify:
                 intermediate_ssm_states = layer_cache.intermediate_ssm
                 # Build kwargs for MTP selective_state_update
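
The mamba2_mixer change hoists the einops repeat() expansions of A, dt_bias, and D out of the per-step decode path and into post_load_weights, so the expanded tensors are built once per weight load and merely referenced on every selective_state_update call; it also drops the .contiguous() copies that the deleted comment tied to flashinfer's 128-byte alignment requirement. A standalone sketch of the caching pattern, assuming einops and torch are installed (the shapes and values below are illustrative, not taken from the module):

import torch
from einops import repeat

nheads, head_dim, d_state = 8, 64, 128

# Per-head SSM parameters as they exist after weight loading (illustrative).
A = torch.randn(nheads)
dt_bias = torch.randn(nheads)
D = torch.randn(nheads)

# One-time pre-expansion, analogous to what the commit adds to post_load_weights.
A_expanded = repeat(A, "h -> h p n", p=head_dim, n=d_state).to(dtype=torch.float32)
dt_bias_expanded = repeat(dt_bias, "h -> h p", p=head_dim)
D_expanded = repeat(D, "h -> h p", p=head_dim)

# The decode path can then reuse the cached tensors on every generated token
# instead of re-running repeat() each step.
assert A_expanded.shape == (nheads, head_dim, d_state)
assert dt_bias_expanded.shape == (nheads, head_dim)
assert D_expanded.shape == (nheads, head_dim)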
