From 63be7c7bbaa418952730f1f7123f149c0acd9350 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Fri, 29 May 2026 20:12:42 -0700 Subject: [PATCH 1/5] Enable NVFP4 grouped MLP cuDNN wgrad Signed-off-by: Siddhartha Raman S --- .../pytorch/ops/fused/backward_grouped_mlp.py | 120 +++++++++++++----- 1 file changed, 91 insertions(+), 29 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 792b6d7811..54fd1b436e 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -88,6 +88,10 @@ def _cudnn_compute_wgrad( accumulate: bool, wgrad_kernel_fn, single_grouped_weight: bool, + use_nvfp4: bool, + data_dtype: torch.dtype, + scale_view_dtype: torch.dtype, + sf_vec_size: int, current_stream=None, ): """Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel. @@ -100,8 +104,6 @@ def _cudnn_compute_wgrad( out_features, in_features = weight_shape total_tokens = grouped_dy.logical_shape[0] - fp8_dtype = torch.float8_e4m3fn - sfa_leading_dim = round_up_to_nearest_multiple(out_features, 128) sfb_leading_dim = round_up_to_nearest_multiple(in_features, 128) @@ -110,68 +112,105 @@ def _cudnn_compute_wgrad( # Even for this case, cuteDSL still requires the same # stride requirements for the input and scale tensors. device = grouped_dy.columnwise_data.device - a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device) + a_tensor = torch.empty_strided( + (out_features, 0), + (16, 1), + dtype=data_dtype, + device=device, + ) b_tensor = torch.empty_strided( - (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device + (0, in_features), + (in_features, 1), + dtype=data_dtype, + device=device, ) sfa_tensor = torch.empty_strided( (sfa_leading_dim, 0), (16, 1), - dtype=torch.float8_e8m0fnu, + dtype=scale_view_dtype, device=device, ) sfb_tensor = torch.empty_strided( (sfb_leading_dim, 0), (16, 1), - dtype=torch.float8_e8m0fnu, + dtype=scale_view_dtype, device=device, ) + elif use_nvfp4: + # NVFP4 columnwise data is stored expert-major as per-expert + # (logical_K, group_M / 2) chunks. cuDNN consumes that layout directly + # with input_order="tensor_ragged". + a_tensor = grouped_dy.columnwise_data.view(dtype=data_dtype).view( + out_features, + total_tokens // 2, + ) + b_tensor = grouped_x.columnwise_data.view(dtype=data_dtype).view( + in_features, + total_tokens // 2, + ).T + sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view( + dtype=scale_view_dtype + ) + sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view( + dtype=scale_view_dtype + ) else: a_tensor = ( - grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T + grouped_dy.columnwise_data.view(dtype=data_dtype).view(total_tokens, out_features).T ) - b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features) + b_tensor = grouped_x.columnwise_data.view(dtype=data_dtype).view(total_tokens, in_features) sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view( - dtype=torch.float8_e8m0fnu + dtype=scale_view_dtype ) sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view( - dtype=torch.float8_e8m0fnu + dtype=scale_view_dtype + ) + + global_scale_a = None + global_scale_b = None + if use_nvfp4: + global_scale_denom = 448.0 * 6.0 + global_scale_a = ( + grouped_dy.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom ) + global_scale_b = ( + grouped_x.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom + ) + + common_wgrad_kwargs = { + "a_tensor": a_tensor, + "b_tensor": b_tensor, + "sfa_tensor": sfa_tensor, + "sfb_tensor": sfb_tensor, + "offsets_tensor": offsets, + "global_scale_a": global_scale_a, + "global_scale_b": global_scale_b, + "acc_dtype": torch.float32, + "sf_vec_size": sf_vec_size, + "accumulate_on_output": accumulate, + "current_stream": current_stream, + } + if use_nvfp4: + common_wgrad_kwargs["input_order"] = "tensor_ragged" # Prepare wgrad output if single_grouped_weight: # Dense mode: single (num_groups, out_features, in_features) tensor wgrad_tensor = wgrad_output.rowwise_data.view(offsets.shape[0], out_features, in_features) wgrad_kernel_fn( - a_tensor=a_tensor, - b_tensor=b_tensor, - sfa_tensor=sfa_tensor, - sfb_tensor=sfb_tensor, - offsets_tensor=offsets, + **common_wgrad_kwargs, output_mode="dense", wgrad_tensor=wgrad_tensor, - acc_dtype=torch.float32, wgrad_dtype=wgrad_tensor.dtype, - sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, - accumulate_on_output=accumulate, - current_stream=current_stream, ) else: # Discrete mode: per-expert wgrad device pointers wgrad_ptrs = tex.copy_data_ptrs_to_device(wgrad_output, wgrad_output[0].device) wgrad_kernel_fn( - a_tensor=a_tensor, - b_tensor=b_tensor, - sfa_tensor=sfa_tensor, - sfb_tensor=sfb_tensor, - offsets_tensor=offsets, + **common_wgrad_kwargs, output_mode="discrete", wgrad_ptrs=wgrad_ptrs, - acc_dtype=torch.float32, wgrad_dtype=wgrad_output[0].dtype, - sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, - accumulate_on_output=accumulate, - current_stream=current_stream, ) @@ -189,6 +228,10 @@ def _compute_grad_params( label="", *, cudnn_wgrad_kernel_fn, + use_nvfp4, + data_dtype, + scale_view_dtype, + sf_vec_size, offsets, ): """Compute weight gradients and build grad_params for a GroupedLinear layer. @@ -258,6 +301,10 @@ def _compute_grad_params( accumulate=accumulate_into_main_grad, wgrad_kernel_fn=cudnn_wgrad_kernel_fn, single_grouped_weight=fc_op.single_grouped_weight, + use_nvfp4=use_nvfp4, + data_dtype=data_dtype, + scale_view_dtype=scale_view_dtype, + sf_vec_size=sf_vec_size, current_stream=torch.cuda.current_stream().cuda_stream, ) elif ( @@ -793,7 +840,14 @@ def fuser_backward( ) # FC2 wgrad GEMM - wgrad_kernel_fn = None if use_nvfp4 else self.grouped_gemm_wgrad_kernel() + enable_nvfp4_wgrad = ( + os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4_WGRAD", "0") == "1" + ) + wgrad_kernel_fn = ( + self.grouped_gemm_wgrad_kernel() + if (not use_nvfp4 or enable_nvfp4_wgrad) + else None + ) fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, @@ -807,6 +861,10 @@ def fuser_backward( bias_grad_packed=fc2_bias_grad_packed, label="FC2", cudnn_wgrad_kernel_fn=wgrad_kernel_fn, + use_nvfp4=use_nvfp4, + data_dtype=data_dtype, + scale_view_dtype=scale_view_dtype, + sf_vec_size=sf_vec_size, offsets=split_points, ) @@ -947,6 +1005,10 @@ def fuser_backward( bias_grad_packed=fc1_bias_grad_packed, label="FC1", cudnn_wgrad_kernel_fn=wgrad_kernel_fn, + use_nvfp4=use_nvfp4, + data_dtype=data_dtype, + scale_view_dtype=scale_view_dtype, + sf_vec_size=sf_vec_size, offsets=split_points, ) From 39632abcb0c85daf2be3048bf57bc07015902890 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:41:51 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/ops/fused/backward_grouped_mlp.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 54fd1b436e..f15a75c84f 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -144,10 +144,14 @@ def _cudnn_compute_wgrad( out_features, total_tokens // 2, ) - b_tensor = grouped_x.columnwise_data.view(dtype=data_dtype).view( - in_features, - total_tokens // 2, - ).T + b_tensor = ( + grouped_x.columnwise_data.view(dtype=data_dtype) + .view( + in_features, + total_tokens // 2, + ) + .T + ) sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view( dtype=scale_view_dtype ) @@ -170,12 +174,8 @@ def _cudnn_compute_wgrad( global_scale_b = None if use_nvfp4: global_scale_denom = 448.0 * 6.0 - global_scale_a = ( - grouped_dy.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom - ) - global_scale_b = ( - grouped_x.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom - ) + global_scale_a = grouped_dy.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom + global_scale_b = grouped_x.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom common_wgrad_kwargs = { "a_tensor": a_tensor, @@ -844,9 +844,7 @@ def fuser_backward( os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4_WGRAD", "0") == "1" ) wgrad_kernel_fn = ( - self.grouped_gemm_wgrad_kernel() - if (not use_nvfp4 or enable_nvfp4_wgrad) - else None + self.grouped_gemm_wgrad_kernel() if (not use_nvfp4 or enable_nvfp4_wgrad) else None ) fc2_grad_params = _compute_grad_params( fc_op=fc2_op, From 7e5c96ca48852b45abfac16a4a3dabafd7f5d3ed Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Mon, 1 Jun 2026 15:44:48 -0700 Subject: [PATCH 3/5] Address NVFP4 wgrad review comments Signed-off-by: Siddhartha Raman S --- .../pytorch/ops/fused/backward_grouped_mlp.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index f15a75c84f..1a46cc9fbf 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -170,27 +170,25 @@ def _cudnn_compute_wgrad( dtype=scale_view_dtype ) - global_scale_a = None - global_scale_b = None - if use_nvfp4: - global_scale_denom = 448.0 * 6.0 - global_scale_a = grouped_dy.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom - global_scale_b = grouped_x.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom - common_wgrad_kwargs = { "a_tensor": a_tensor, "b_tensor": b_tensor, "sfa_tensor": sfa_tensor, "sfb_tensor": sfb_tensor, "offsets_tensor": offsets, - "global_scale_a": global_scale_a, - "global_scale_b": global_scale_b, "acc_dtype": torch.float32, "sf_vec_size": sf_vec_size, "accumulate_on_output": accumulate, "current_stream": current_stream, } if use_nvfp4: + global_scale_denom = 448.0 * 6.0 + common_wgrad_kwargs["global_scale_a"] = ( + grouped_dy.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom + ) + common_wgrad_kwargs["global_scale_b"] = ( + grouped_x.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom + ) common_wgrad_kwargs["input_order"] = "tensor_ragged" # Prepare wgrad output @@ -840,12 +838,7 @@ def fuser_backward( ) # FC2 wgrad GEMM - enable_nvfp4_wgrad = ( - os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4_WGRAD", "0") == "1" - ) - wgrad_kernel_fn = ( - self.grouped_gemm_wgrad_kernel() if (not use_nvfp4 or enable_nvfp4_wgrad) else None - ) + wgrad_kernel_fn = self.grouped_gemm_wgrad_kernel() fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, From c03367fe5fe0327b12fb4cc5a31fa687dfff3a5d Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Mon, 1 Jun 2026 16:18:27 -0700 Subject: [PATCH 4/5] Use NVFP4 amax helper for wgrad scales Signed-off-by: Siddhartha Raman S --- transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 1a46cc9fbf..dd784a0d48 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -184,10 +184,10 @@ def _cudnn_compute_wgrad( if use_nvfp4: global_scale_denom = 448.0 * 6.0 common_wgrad_kwargs["global_scale_a"] = ( - grouped_dy.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom + _nvfp4_amax(grouped_dy, columnwise=True).to(torch.float32) / global_scale_denom ) common_wgrad_kwargs["global_scale_b"] = ( - grouped_x.columnwise_amax.view(-1).to(torch.float32) / global_scale_denom + _nvfp4_amax(grouped_x, columnwise=True).to(torch.float32) / global_scale_denom ) common_wgrad_kwargs["input_order"] = "tensor_ragged" From bef69c9cc6d47044320bc6458909d7114e16e2b9 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman S Date: Tue, 2 Jun 2026 10:36:41 -0700 Subject: [PATCH 5/5] Skip NVFP4 wgrad amax lookup for empty tokens Signed-off-by: Siddhartha Raman S --- .../pytorch/ops/fused/backward_grouped_mlp.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index dd784a0d48..4fd1220135 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -183,12 +183,25 @@ def _cudnn_compute_wgrad( } if use_nvfp4: global_scale_denom = 448.0 * 6.0 - common_wgrad_kwargs["global_scale_a"] = ( - _nvfp4_amax(grouped_dy, columnwise=True).to(torch.float32) / global_scale_denom - ) - common_wgrad_kwargs["global_scale_b"] = ( - _nvfp4_amax(grouped_x, columnwise=True).to(torch.float32) / global_scale_denom - ) + if total_tokens == 0: + global_scale_shape = (offsets.shape[0],) + common_wgrad_kwargs["global_scale_a"] = torch.zeros( + global_scale_shape, + dtype=torch.float32, + device=device, + ) + common_wgrad_kwargs["global_scale_b"] = torch.zeros( + global_scale_shape, + dtype=torch.float32, + device=device, + ) + else: + common_wgrad_kwargs["global_scale_a"] = ( + _nvfp4_amax(grouped_dy, columnwise=True).to(torch.float32) / global_scale_denom + ) + common_wgrad_kwargs["global_scale_b"] = ( + _nvfp4_amax(grouped_x, columnwise=True).to(torch.float32) / global_scale_denom + ) common_wgrad_kwargs["input_order"] = "tensor_ragged" # Prepare wgrad output