Skip to content
15 changes: 15 additions & 0 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@
from ..utils import get_default_init_method


def set_quantizer_amax_reduction_group(quantizer, amax_reduction_group) -> None:
    """Configure which process group a quantizer reduces its amax over.

    If ``quantizer`` exposes a ``parent_quantizer`` attribute (as
    ``DebugQuantizer`` does, delegating the actual quantization to it), the
    settings are applied to that parent instead of the wrapper.

    The call does nothing when ``quantizer`` is ``None``, when the resolved
    quantizer is ``None``, or when the resolved quantizer has no
    ``with_amax_reduction`` attribute.

    Args:
        quantizer: Quantizer (or wrapper around one) to configure; may be
            ``None``.
        amax_reduction_group: Group to reduce amax over. ``None`` turns amax
            reduction off (``with_amax_reduction`` becomes ``False``).
    """
    if quantizer is None:
        return

    # Resolve the object that really quantizes: wrappers carry it in
    # ``parent_quantizer``; plain quantizers are used as-is.
    try:
        effective = quantizer.parent_quantizer
    except AttributeError:
        effective = quantizer

    if effective is None:
        return
    if not hasattr(effective, "with_amax_reduction"):
        # Quantizer type without amax-reduction support: leave it untouched.
        return

    reduction_enabled = amax_reduction_group is not None
    effective.with_amax_reduction = reduction_enabled
    effective.amax_reduction_group = amax_reduction_group


def _get_normalization_func(normalization: str, forward: bool):
fwd_normalization_funcs = {
"LayerNorm": tex.layernorm_fwd,
Expand Down
63 changes: 20 additions & 43 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@
from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ._common import (
apply_normalization,
noop_cat,
set_quantizer_amax_reduction_group,
WeightGradStore,
)
from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Expand Down Expand Up @@ -216,6 +221,11 @@ def forward(
if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Amax reduction group for the input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
input_quantizer,
tp_group if (sequence_parallel and parallel_mode == "column") else None,
)

# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
Expand Down Expand Up @@ -691,6 +701,15 @@ def backward(
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Amax reduction group for grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
quantizer,
(
ctx.tp_group
if (ctx.sequence_parallel and ctx.parallel_mode == "row")
else None
),
)

# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
Expand Down Expand Up @@ -1555,8 +1574,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
Expand Down Expand Up @@ -1918,15 +1935,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
self.quantizers["scaling_bwd"][
Expand All @@ -1935,37 +1943,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
Expand Down
53 changes: 11 additions & 42 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ._common import apply_normalization, set_quantizer_amax_reduction_group, WeightGradStore
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
Expand Down Expand Up @@ -400,6 +400,11 @@ def _forward(
if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Amax reduction group for the FC1 input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
fc1_input_quantizer,
tp_group if (sequence_parallel and set_parallel_mode) else None,
)

# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
Expand Down Expand Up @@ -1139,6 +1144,11 @@ def backward(
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Amax reduction group for FC2 grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
quantizer,
ctx.tp_group if (ctx.sequence_parallel and ctx.set_parallel_mode) else None,
)

# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
Expand Down Expand Up @@ -2166,8 +2176,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
Expand Down Expand Up @@ -2677,15 +2685,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM2_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][
Expand All @@ -2701,36 +2700,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
Expand Down
68 changes: 25 additions & 43 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, WeightGradStore
from ._common import noop_cat, set_quantizer_amax_reduction_group, WeightGradStore
from ..quantization import FP8GlobalStateManager, QuantizerRole
from ..utils import (
cast_if_needed,
Expand Down Expand Up @@ -306,6 +306,12 @@ def _linear_forward_impl(
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)

# Amax reduction group for the input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
input_quantizer,
tp_group if (sequence_parallel and parallel_mode == "column") else None,
)

# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
Expand Down Expand Up @@ -748,6 +754,24 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], ..
grad_weight_quantizer = args.grad_weight_quantizer
grad_output_quantizer = args.grad_output_quantizer

# Amax reduction groups (sequence parallel): input for column-parallel, grad output for row-parallel
set_quantizer_amax_reduction_group(
input_quantizer,
(
bwd_args.tp_group
if (bwd_args.sequence_parallel and bwd_args.parallel_mode == "column")
else None
),
)
set_quantizer_amax_reduction_group(
grad_output_quantizer,
(
bwd_args.tp_group
if (bwd_args.sequence_parallel and bwd_args.parallel_mode == "row")
else None
),
)

# NVTX label for profiling
nvtx_label = "transformer_engine._Linear.backward"
if bwd_args.ub_name is not None:
Expand Down Expand Up @@ -1746,8 +1770,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
Expand Down Expand Up @@ -2116,15 +2138,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "column":
# customize input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# set grad_output_quantizer with amax epsilon and power_2_scale
self.quantizers["scaling_bwd"][
Expand All @@ -2133,37 +2146,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# customize input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
Expand Down
30 changes: 13 additions & 17 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...module._common import set_quantizer_amax_reduction_group
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
Expand Down Expand Up @@ -401,23 +402,6 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
if recipe.nvfp4():
if getattr(self, "sequence_parallel", False):
tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
if tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
elif tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
Expand Down Expand Up @@ -544,6 +528,10 @@ def _functional_forward(
rowwise=True,
columnwise=weight_requires_grad and backward_override is None,
)
# Amax reduction group for the input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
input_quantizer, tensor_parallel_group if with_x_all_gather else None
)
if with_x_all_gather:
input_quantizer.set_usage(columnwise=False)
x, x_async = gather_along_first_dim(
Expand Down Expand Up @@ -788,6 +776,10 @@ def _functional_backward(
rowwise=input_requires_grad,
columnwise=weight_requires_grad,
)
# Amax reduction group for grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
grad_output_quantizer, tensor_parallel_group if with_dy_all_gather else None
)
if with_dy_all_gather:
dy, dy_async = gather_along_first_dim(
dy_local,
Expand Down Expand Up @@ -828,6 +820,10 @@ def _functional_backward(
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=False, columnwise=True)
# Amax reduction group for the input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
input_quantizer, tensor_parallel_group if with_x_all_gather else None
)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
Expand Down
Loading
Loading