@@ -240,6 +240,16 @@ def post_load_weights(self):
                 and self.norm.nvfp4_scale is None):
             self._try_attach_nvfp4_scale()
 
+        # Pre-expand A, D, dt_bias for the decode path.
+        self._A_expanded = repeat(self.A,
+                                  "h -> h p n",
+                                  p=self.head_dim,
+                                  n=self.d_state).to(dtype=torch.float32)
+        self._dt_bias_expanded = repeat(self.dt_bias,
+                                        "h -> h p",
+                                        p=self.head_dim)
+        self._D_expanded = repeat(self.D, "h -> h p", p=self.head_dim)
+
     def _try_attach_nvfp4_scale(self):
         """Attach input_scale from out_proj to norm for fused RMSNorm+Quant."""
 
@@ -454,22 +464,15 @@ def convert_dt():
             ],
             dim=-1,
         )
-        # Use .contiguous() to ensure proper 128-byte alignment required by
-        # flashinfer's selective_state_update kernel. x_d, B_d, C_d are views
-        # into sliced tensors which may not be 128-byte aligned.
-        x_d = rearrange(x_d, "b (h p) -> b h p",
-                        p=self.head_dim).contiguous()
+        x_d = rearrange(x_d, "b (h p) -> b h p", p=self.head_dim)
         dt_d = repeat(dt_d, "b h -> b h p", p=self.head_dim)
-        B_d = rearrange(B_d, "b (g n) -> b g n",
-                        g=self.tp_ngroups).contiguous()
-        C_d = rearrange(C_d, "b (g n) -> b g n",
-                        g=self.tp_ngroups).contiguous()
+        B_d = rearrange(B_d, "b (g n) -> b g n", g=self.tp_ngroups)
+        C_d = rearrange(C_d, "b (g n) -> b g n", g=self.tp_ngroups)
         z_d = rearrange(z_d, "b (h p) -> b h p", p=self.head_dim)
 
-        A = repeat(self.A, "h -> h p n", p=self.head_dim,
-                   n=self.d_state).to(dtype=torch.float32)
-        dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim)
-        D = repeat(self.D, "h -> h p", p=self.head_dim)
+        A = self._A_expanded
+        dt_bias = self._dt_bias_expanded
+        D = self._D_expanded
         if is_target_verify:
             intermediate_ssm_states = layer_cache.intermediate_ssm
             # Build kwargs for MTP selective_state_update
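The dropped .contiguous() calls were paired with the comment removed above. As a hedged standalone sketch (not from this repo), the two properties that comment cited, contiguity and 128-byte alignment, can be inspected with stock PyTorch calls:

# Standalone sketch: inspects the properties the removed comment cited.
# Whether flashinfer's selective_state_update kernel still requires them
# after this change is not shown by the diff itself.
import torch

def report(name: str, t: torch.Tensor, alignment: int = 128) -> None:
    # data_ptr() is the address of the first element; a sliced view starts
    # mid-buffer, so its offset modulo `alignment` is generally nonzero.
    print(f"{name}: contiguous={t.is_contiguous()}, "
          f"offset % {alignment} = {t.data_ptr() % alignment}")

x = torch.empty(8, 256)
report("fresh tensor", x)        # contiguous; alignment set by the allocator
report("sliced view", x[:, 1:])  # a view into the middle of x's buffer
report("after .contiguous()", x[:, 1:].contiguous())  # compact fresh copy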