Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 95 additions & 29 deletions transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def _cudnn_compute_wgrad(
accumulate: bool,
wgrad_kernel_fn,
single_grouped_weight: bool,
use_nvfp4: bool,
data_dtype: torch.dtype,
scale_view_dtype: torch.dtype,
sf_vec_size: int,
current_stream=None,
):
"""Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel.
Expand All @@ -100,8 +104,6 @@ def _cudnn_compute_wgrad(
out_features, in_features = weight_shape
total_tokens = grouped_dy.logical_shape[0]

fp8_dtype = torch.float8_e4m3fn

sfa_leading_dim = round_up_to_nearest_multiple(out_features, 128)
sfb_leading_dim = round_up_to_nearest_multiple(in_features, 128)

Expand All @@ -110,68 +112,116 @@ def _cudnn_compute_wgrad(
# Even for this case, CuTe DSL still imposes the same
# stride requirements on the input and scale tensors.
device = grouped_dy.columnwise_data.device
a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device)
a_tensor = torch.empty_strided(
(out_features, 0),
(16, 1),
dtype=data_dtype,
device=device,
)
b_tensor = torch.empty_strided(
(0, in_features), (in_features, 1), dtype=fp8_dtype, device=device
(0, in_features),
(in_features, 1),
dtype=data_dtype,
device=device,
)
sfa_tensor = torch.empty_strided(
(sfa_leading_dim, 0),
(16, 1),
dtype=torch.float8_e8m0fnu,
dtype=scale_view_dtype,
device=device,
)
sfb_tensor = torch.empty_strided(
(sfb_leading_dim, 0),
(16, 1),
dtype=torch.float8_e8m0fnu,
dtype=scale_view_dtype,
device=device,
)
elif use_nvfp4:
# NVFP4 columnwise data is stored expert-major as per-expert
# (logical_K, group_M / 2) chunks. cuDNN consumes that layout directly
# with input_order="tensor_ragged".
a_tensor = grouped_dy.columnwise_data.view(dtype=data_dtype).view(
out_features,
total_tokens // 2,
)
b_tensor = (
grouped_x.columnwise_data.view(dtype=data_dtype)
.view(
in_features,
total_tokens // 2,
)
.T
)
sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view(
dtype=scale_view_dtype
)
sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view(
dtype=scale_view_dtype
)
else:
a_tensor = (
grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T
grouped_dy.columnwise_data.view(dtype=data_dtype).view(total_tokens, out_features).T
)
b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features)
b_tensor = grouped_x.columnwise_data.view(dtype=data_dtype).view(total_tokens, in_features)
sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view(
dtype=torch.float8_e8m0fnu
dtype=scale_view_dtype
)
sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view(
dtype=torch.float8_e8m0fnu
dtype=scale_view_dtype
)

common_wgrad_kwargs = {
"a_tensor": a_tensor,
"b_tensor": b_tensor,
"sfa_tensor": sfa_tensor,
"sfb_tensor": sfb_tensor,
"offsets_tensor": offsets,
"acc_dtype": torch.float32,
"sf_vec_size": sf_vec_size,
"accumulate_on_output": accumulate,
"current_stream": current_stream,
}
if use_nvfp4:
global_scale_denom = 448.0 * 6.0
if total_tokens == 0:
global_scale_shape = (offsets.shape[0],)
common_wgrad_kwargs["global_scale_a"] = torch.zeros(
global_scale_shape,
dtype=torch.float32,
device=device,
)
common_wgrad_kwargs["global_scale_b"] = torch.zeros(
global_scale_shape,
dtype=torch.float32,
device=device,
)
else:
common_wgrad_kwargs["global_scale_a"] = (
_nvfp4_amax(grouped_dy, columnwise=True).to(torch.float32) / global_scale_denom
)
common_wgrad_kwargs["global_scale_b"] = (
_nvfp4_amax(grouped_x, columnwise=True).to(torch.float32) / global_scale_denom
)
common_wgrad_kwargs["input_order"] = "tensor_ragged"

# Prepare wgrad output
if single_grouped_weight:
# Dense mode: single (num_groups, out_features, in_features) tensor
wgrad_tensor = wgrad_output.rowwise_data.view(offsets.shape[0], out_features, in_features)
wgrad_kernel_fn(
a_tensor=a_tensor,
b_tensor=b_tensor,
sfa_tensor=sfa_tensor,
sfb_tensor=sfb_tensor,
offsets_tensor=offsets,
**common_wgrad_kwargs,
output_mode="dense",
wgrad_tensor=wgrad_tensor,
acc_dtype=torch.float32,
wgrad_dtype=wgrad_tensor.dtype,
sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
accumulate_on_output=accumulate,
current_stream=current_stream,
)
else:
# Discrete mode: per-expert wgrad device pointers
wgrad_ptrs = tex.copy_data_ptrs_to_device(wgrad_output, wgrad_output[0].device)
wgrad_kernel_fn(
a_tensor=a_tensor,
b_tensor=b_tensor,
sfa_tensor=sfa_tensor,
sfb_tensor=sfb_tensor,
offsets_tensor=offsets,
**common_wgrad_kwargs,
output_mode="discrete",
wgrad_ptrs=wgrad_ptrs,
acc_dtype=torch.float32,
wgrad_dtype=wgrad_output[0].dtype,
sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
accumulate_on_output=accumulate,
current_stream=current_stream,
)


Expand All @@ -189,6 +239,10 @@ def _compute_grad_params(
label="",
*,
cudnn_wgrad_kernel_fn,
use_nvfp4,
data_dtype,
scale_view_dtype,
sf_vec_size,
offsets,
):
"""Compute weight gradients and build grad_params for a GroupedLinear layer.
Expand Down Expand Up @@ -258,6 +312,10 @@ def _compute_grad_params(
accumulate=accumulate_into_main_grad,
wgrad_kernel_fn=cudnn_wgrad_kernel_fn,
single_grouped_weight=fc_op.single_grouped_weight,
use_nvfp4=use_nvfp4,
data_dtype=data_dtype,
scale_view_dtype=scale_view_dtype,
sf_vec_size=sf_vec_size,
current_stream=torch.cuda.current_stream().cuda_stream,
)
elif (
Expand Down Expand Up @@ -793,7 +851,7 @@ def fuser_backward(
)

# FC2 wgrad GEMM
wgrad_kernel_fn = None if use_nvfp4 else self.grouped_gemm_wgrad_kernel()
wgrad_kernel_fn = self.grouped_gemm_wgrad_kernel()
fc2_grad_params = _compute_grad_params(
fc_op=fc2_op,
ctx=fc2_ctx,
Expand All @@ -807,6 +865,10 @@ def fuser_backward(
bias_grad_packed=fc2_bias_grad_packed,
label="FC2",
cudnn_wgrad_kernel_fn=wgrad_kernel_fn,
use_nvfp4=use_nvfp4,
data_dtype=data_dtype,
scale_view_dtype=scale_view_dtype,
sf_vec_size=sf_vec_size,
offsets=split_points,
)

Expand Down Expand Up @@ -947,6 +1009,10 @@ def fuser_backward(
bias_grad_packed=fc1_bias_grad_packed,
label="FC1",
cudnn_wgrad_kernel_fn=wgrad_kernel_fn,
use_nvfp4=use_nvfp4,
data_dtype=data_dtype,
scale_view_dtype=scale_view_dtype,
sf_vec_size=sf_vec_size,
offsets=split_points,
)

Expand Down
Loading