diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 792b6d7811..4fd1220135 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -88,6 +88,10 @@ def _cudnn_compute_wgrad( accumulate: bool, wgrad_kernel_fn, single_grouped_weight: bool, + use_nvfp4: bool, + data_dtype: torch.dtype, + scale_view_dtype: torch.dtype, + sf_vec_size: int, current_stream=None, ): """Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel. @@ -100,8 +104,6 @@ def _cudnn_compute_wgrad( out_features, in_features = weight_shape total_tokens = grouped_dy.logical_shape[0] - fp8_dtype = torch.float8_e4m3fn - sfa_leading_dim = round_up_to_nearest_multiple(out_features, 128) sfb_leading_dim = round_up_to_nearest_multiple(in_features, 128) @@ -110,68 +112,116 @@ def _cudnn_compute_wgrad( # Even for this case, cuteDSL still requires the same # stride requirements for the input and scale tensors. device = grouped_dy.columnwise_data.device - a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device) + a_tensor = torch.empty_strided( + (out_features, 0), + (16, 1), + dtype=data_dtype, + device=device, + ) b_tensor = torch.empty_strided( - (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device + (0, in_features), + (in_features, 1), + dtype=data_dtype, + device=device, ) sfa_tensor = torch.empty_strided( (sfa_leading_dim, 0), (16, 1), - dtype=torch.float8_e8m0fnu, + dtype=scale_view_dtype, device=device, ) sfb_tensor = torch.empty_strided( (sfb_leading_dim, 0), (16, 1), - dtype=torch.float8_e8m0fnu, + dtype=scale_view_dtype, device=device, ) + elif use_nvfp4: + # NVFP4 columnwise data is stored expert-major as per-expert + # (logical_K, group_M / 2) chunks. cuDNN consumes that layout directly + # with input_order="tensor_ragged". + a_tensor = grouped_dy.columnwise_data.view(dtype=data_dtype).view( + out_features, + total_tokens // 2, + ) + b_tensor = ( + grouped_x.columnwise_data.view(dtype=data_dtype) + .view( + in_features, + total_tokens // 2, + ) + .T + ) + sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view( + dtype=scale_view_dtype + ) + sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view( + dtype=scale_view_dtype + ) else: a_tensor = ( - grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T + grouped_dy.columnwise_data.view(dtype=data_dtype).view(total_tokens, out_features).T ) - b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features) + b_tensor = grouped_x.columnwise_data.view(dtype=data_dtype).view(total_tokens, in_features) sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view( - dtype=torch.float8_e8m0fnu + dtype=scale_view_dtype ) sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view( - dtype=torch.float8_e8m0fnu + dtype=scale_view_dtype ) + common_wgrad_kwargs = { + "a_tensor": a_tensor, + "b_tensor": b_tensor, + "sfa_tensor": sfa_tensor, + "sfb_tensor": sfb_tensor, + "offsets_tensor": offsets, + "acc_dtype": torch.float32, + "sf_vec_size": sf_vec_size, + "accumulate_on_output": accumulate, + "current_stream": current_stream, + } + if use_nvfp4: + global_scale_denom = 448.0 * 6.0 + if total_tokens == 0: + global_scale_shape = (offsets.shape[0],) + common_wgrad_kwargs["global_scale_a"] = torch.zeros( + global_scale_shape, + dtype=torch.float32, + device=device, + ) + common_wgrad_kwargs["global_scale_b"] = torch.zeros( + global_scale_shape, + dtype=torch.float32, + device=device, + ) + else: + common_wgrad_kwargs["global_scale_a"] = ( + _nvfp4_amax(grouped_dy, columnwise=True).to(torch.float32) / global_scale_denom + ) + common_wgrad_kwargs["global_scale_b"] = ( + _nvfp4_amax(grouped_x, columnwise=True).to(torch.float32) / global_scale_denom + ) + common_wgrad_kwargs["input_order"] = "tensor_ragged" + # Prepare wgrad output if single_grouped_weight: # Dense mode: single (num_groups, out_features, in_features) tensor wgrad_tensor = wgrad_output.rowwise_data.view(offsets.shape[0], out_features, in_features) wgrad_kernel_fn( - a_tensor=a_tensor, - b_tensor=b_tensor, - sfa_tensor=sfa_tensor, - sfb_tensor=sfb_tensor, - offsets_tensor=offsets, + **common_wgrad_kwargs, output_mode="dense", wgrad_tensor=wgrad_tensor, - acc_dtype=torch.float32, wgrad_dtype=wgrad_tensor.dtype, - sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, - accumulate_on_output=accumulate, - current_stream=current_stream, ) else: # Discrete mode: per-expert wgrad device pointers wgrad_ptrs = tex.copy_data_ptrs_to_device(wgrad_output, wgrad_output[0].device) wgrad_kernel_fn( - a_tensor=a_tensor, - b_tensor=b_tensor, - sfa_tensor=sfa_tensor, - sfb_tensor=sfb_tensor, - offsets_tensor=offsets, + **common_wgrad_kwargs, output_mode="discrete", wgrad_ptrs=wgrad_ptrs, - acc_dtype=torch.float32, wgrad_dtype=wgrad_output[0].dtype, - sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, - accumulate_on_output=accumulate, - current_stream=current_stream, ) @@ -189,6 +239,10 @@ def _compute_grad_params( label="", *, cudnn_wgrad_kernel_fn, + use_nvfp4, + data_dtype, + scale_view_dtype, + sf_vec_size, offsets, ): """Compute weight gradients and build grad_params for a GroupedLinear layer. @@ -258,6 +312,10 @@ def _compute_grad_params( accumulate=accumulate_into_main_grad, wgrad_kernel_fn=cudnn_wgrad_kernel_fn, single_grouped_weight=fc_op.single_grouped_weight, + use_nvfp4=use_nvfp4, + data_dtype=data_dtype, + scale_view_dtype=scale_view_dtype, + sf_vec_size=sf_vec_size, current_stream=torch.cuda.current_stream().cuda_stream, ) elif ( @@ -793,7 +851,7 @@ def fuser_backward( ) # FC2 wgrad GEMM - wgrad_kernel_fn = None if use_nvfp4 else self.grouped_gemm_wgrad_kernel() + wgrad_kernel_fn = self.grouped_gemm_wgrad_kernel() fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, @@ -807,6 +865,10 @@ def fuser_backward( bias_grad_packed=fc2_bias_grad_packed, label="FC2", cudnn_wgrad_kernel_fn=wgrad_kernel_fn, + use_nvfp4=use_nvfp4, + data_dtype=data_dtype, + scale_view_dtype=scale_view_dtype, + sf_vec_size=sf_vec_size, offsets=split_points, ) @@ -947,6 +1009,10 @@ def fuser_backward( bias_grad_packed=fc1_bias_grad_packed, label="FC1", cudnn_wgrad_kernel_fn=wgrad_kernel_fn, + use_nvfp4=use_nvfp4, + data_dtype=data_dtype, + scale_view_dtype=scale_view_dtype, + sf_vec_size=sf_vec_size, offsets=split_points, )