
[BUG] Fix: Fix gradient norm calculation and dynamic shape blocking in PP+ZeRO1 collective communication #7847

Open
Thinksky5124 wants to merge 6 commits into deepspeedai:master from Thinksky5124:master

Conversation

@Thinksky5124

Describe the bug

This commit fixes gradient normalization bugs when using DeepSpeed Pipeline Parallel (pp) together with ZeRO Stage 1 (zero1), including the following aspects:

  1. PipelineEngine Buffer Type Consistency with Dynamic Shapes
    In deepspeed/runtime/pipe/engine.py, the activation buffer previously did not enforce dtype conversion, which could lead to inconsistent types and subsequent calculation errors. The return value is now explicitly cast to the target dtype, ensuring type consistency.

  2. ZeRO Stage 1/2 Gradient Normalization Logic Correction
    In deepspeed/runtime/zero/stage_1_and_2.py, for both CPU-offload and regular scenarios, the previous gradient normalization involved redundant communication and incorrect normalization:

complete_grad_norm_calculation_for_cpu_offload now only computes the local squared L2 norm without cross-rank communication, avoiding redundant normalization and double counting.

get_grad_norm_direct only supports L2 norm, directly accumulates the local gradient squared sum, and avoids double counting for pipeline parallel parameters.

scaled_global_norm unifies the normalization process: first accumulates the squared sum of gradients for all groups locally, then uses all_reduce to aggregate, and finally takes the square root to obtain the global L2 norm, ensuring consistency with ZeRO design.

  3. Code Redundancy and Exception Handling Optimization
    Removed the inf constant and unnecessary norm_type branches, now only supporting L2 norm to simplify the logic. Stricter skipping of None gradients and pipeline-replicated parameters improves robustness.
    After these fixes, gradient normalization in pp+zero1 scenarios is more accurate, avoiding double counting, type inconsistency, and redundant communication, thus improving training stability and performance. Reviewers are advised to focus on the correctness and compatibility of gradient normalization code to ensure consistent behavior across different parallel/offload scenarios.
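The buffer-dtype fix in item 1 can be sketched as follows. This is a minimal illustration, not DeepSpeed's actual engine code: the function name, caching scheme, and the use of numpy in place of torch tensors are all assumptions made for the example.

```python
import numpy as np

def get_recv_buffer(shape, target_dtype=np.float16, _cache={}):
    """Illustrative sketch (hypothetical helper, not DeepSpeed's API): return a
    reusable activation buffer, explicitly cast to the target dtype so a buffer
    allocated under one dtype is never handed back with another."""
    key = tuple(shape)
    buf = _cache.get(key)
    if buf is None or buf.shape != tuple(shape):
        buf = np.empty(shape, dtype=target_dtype)
    # Explicit cast: a no-op when dtypes already match, otherwise converts,
    # which is the consistency guarantee the fix adds.
    if buf.dtype != target_dtype:
        buf = buf.astype(target_dtype)
    _cache[key] = buf
    return buf
```

Without the explicit cast, a buffer cached for an fp16 send could be returned for an fp32 receive under dynamic shapes, producing the type mismatch described above.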

DeepSpeed Config

```json
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 4,
    "steps_per_print": 1,
    "zero_optimization": {
        "stage": 1
    }
}
```


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 99697687fc


Comment on lines +1997 to +1998

```diff
 if self.has_moe_layers:
-    self._average_expert_grad_norms(norm_groups)
+    self._average_expert_grad_norms(local_total_sq_norm)
```

P1: Preserve per-group norms when averaging MoE expert grads

When has_moe_layers is true, _average_expert_grad_norms expects a per-parameter-group collection (it iterates by group index and checks self.is_moe_param_group[i]), but this call now passes a single accumulated tensor. In runs with multiple optimizer groups, that means MoE scaling is applied to at most index 0 (or to the already-mixed total), so expert and non-expert contributions are mis-scaled before clipping/overflow logic, producing incorrect global grad norms.


Author

Already fixed:

```diff
-self._average_expert_grad_norms(norm_groups)
+self._average_expert_grad_norms(local_total_sq_norm)
```

```python
local_total_sq_norm = local_total_sq_norm.to(torch.cuda.current_device())
```

P1: Avoid hard-coding CUDA device for norm reduction

This forces scaled_global_norm onto torch.cuda.current_device() even though the optimizer code is written against DeepSpeed's accelerator abstraction. On non-CUDA backends (or CPU-only execution), this line raises before all_reduce, so gradient norm computation and optimizer step fail outright; the tensor should stay on self.device or use get_accelerator().current_device_name().


Author

Already fixed:

…ice move in scaled_global_norm

Signed-off-by: Thinksky5124 <40914433+Thinksky5124@users.noreply.github.com>
Collaborator

@tohtana tohtana left a comment


Hi @Thinksky5124,

Thank you for the PR! However, since this contains significant changes to the core DeepSpeed engine, it is difficult to ensure that existing behavior won't be impacted. Given that you’ve listed three distinct changes, could you please separate them into individual PRs?

Additionally, I’m not entirely clear on what is currently incorrect regarding the gradient norm. While the reduction (dist.all_reduce(total_norm, ...)) seems redundant, the result itself appears to be correct.

```diff
 if param_id in self.norm_for_param_grads:
     param_norm = self.norm_for_param_grads[param_id]
-    total_norm += param_norm.item()**2
+    local_sq_norm += param_norm**2
```
Collaborator

With offloading enabled, local_sq_norm is on CPU and param_norm is on the accelerator.

Author

Already fixed:

```diff
 assert norm_type == 2, "only L2 norm supported"

-self._model_parallel_all_reduce(tensor=total_norm, op=dist.ReduceOp.SUM)
+local_sq_norm = torch.zeros(1, device=self.device, dtype=self.gradient_accumulation_dtype)
```
Collaborator

self.gradient_accumulation_dtype can be set to fp16/bf16. Accumulating the squared norm in that lower precision is not safe.
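The precision concern is easy to demonstrate numerically. The sketch below uses numpy's float16 as a stand-in for fp16 tensor accumulation: once the running sum reaches 2048, the spacing between representable float16 values is 2.0, so adding 1.0 is rounded away and the sum stalls.

```python
import numpy as np

# Accumulate 10,000 squared-gradient contributions of 1.0 each.
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(10_000):
    # In float16 the sum grows exactly until 2048, then 2048 + 1.0
    # rounds back to 2048 (round-to-nearest-even with spacing 2.0).
    acc16 = np.float16(acc16 + np.float16(1.0))
    acc32 = np.float32(acc32 + np.float32(1.0))

print(float(acc16))  # 2048.0 -- badly wrong
print(float(acc32))  # 10000.0 -- exact
```

This is why accumulating the squared norm in float32 (as adopted below) is the safe choice.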

Author

Changed to torch.float32 for the accumulation.

@Thinksky5124
Author

Thinksky5124 commented Mar 27, 2026

> Hi @Thinksky5124,
>
> Thank you for the PR! However, since this contains significant changes to the core DeepSpeed engine, it is difficult to ensure that existing behavior won't be impacted. Given that you’ve listed three distinct changes, could you please separate them into individual PRs?
>
> Additionally, I’m not entirely clear on what is currently incorrect regarding the gradient norm. While the reduction (dist.all_reduce(total_norm, ...)) seems redundant, the result itself appears to be correct.

@tohtana Thank you for your patient review!

I don't think we need to separate this PR. Below is a complete, integrated description of the bug and the corresponding fix, with full technical detail.


Variable and Symbol Definitions

To clearly illustrate the differences before and after the fix, we first define the dimensions used in a distributed training cluster:

  • $N_{dp}$: Number of GPUs within a Data Parallel (DP / ZeRO) group.
  • $N_{pp}$: Number of pipeline stages within a Pipeline Parallel (PP) group (i.e., how many segments the model is split into).
  • $k_i$: Index of a parameter group on rank $i$. With PP, the set of parameter groups may differ across PP ranks.
  • $j$: Index of a specific parameter.
  • $\tilde{g}_{k, j}^{(dp, pp)}$: The loss-scaled gradient of parameter $j$ in group $k$, computed on a specific DP and PP rank.

Overall Objective of the Function

This function computes the global L2 norm of all loss-scaled gradients across both Data Parallel (DP) and Pipeline Parallel (PP) dimensions for all parameter groups.


Before the Fix

The original implementation only performed aggregation across DP ranks and ignored the fact that different PP ranks may hold different numbers of parameters, which leads to inconsistent communication patterns.


Step 1: Local Norm Contribution per Parameter Group

(Inside complete_grad_norm_calculation_for_cpu_offload)

For all parameters in group $k$ that are not replicated by pipeline parallelism and belong to the current model-parallel (MP) rank, compute the sum of squared L2 norms:

$$ s_{k_i}^{(\text{local})} = \sum_{j} \left\| \tilde{g}_{k_i,j} \right\|_2^2 $$


Step 2: All-Reduce Across DP Process Group

Perform dist.all_reduce(..., group=dp_process_group) to sum $s_k^{(\text{local})}$ across all $N_{dp}$ GPUs in the DP group.
However, this step ignores PP partitioning: different PP ranks may hold different numbers of parameter groups, so the number of per-group all_reduce calls differs across ranks, which can lead to deadlocks:

$$ s_{k_i}^{(\text{dp,pp})} = \sum_{\text{dp}=1}^{N_{dp}} s_{k_i}^{(\text{local, dp})} $$
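The deadlock risk can be sketched with a toy counting model (rank counts and the helper are hypothetical, purely for illustration): a collective only completes when every rank in the group enters it, so any rank that issues more per-group all_reduce calls than its peers blocks forever.

```python
def unmatched_collectives(calls_per_rank):
    """Count how many all_reduce invocations on each rank lack a partner.
    A collective completes only when *all* ranks in the group enter it,
    so only min(calls) rounds can ever finish."""
    completed = min(calls_per_rank.values())
    return {rank: n - completed for rank, n in calls_per_rank.items()}

# Hypothetical uneven pipeline split: PP rank 0 owns 3 parameter groups,
# PP rank 1 owns 2, and the old code issues one all_reduce per group.
groups_per_pp_rank = {0: 3, 1: 2}
stuck = unmatched_collectives(groups_per_pp_rank)
print(stuck)  # {0: 1, 1: 0} -> rank 0's third all_reduce never completes
```

The fixed implementation avoids this by reducing a single accumulated scalar, so every rank issues exactly one collective regardless of its group count.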


Step 3: Second All-Reduce Across Model Parallel (MP) Dimension

Use _model_parallel_all_reduce to aggregate contributions across tensor-parallel (TP) ranks:

$$ s_{k_i}^{(\text{global})} = \sum_{\text{pp}=1}^{N_{pp}} s_{k_i}^{(\text{dp, pp})} $$

Expanding fully:

$$ s_{k_i}^{(\text{global})} = \sum_{\text{dp}=1}^{N_{dp}} \sum_{\text{pp}=1}^{N_{pp}} \sum_{j} \left\| \tilde{g}_{k_i,j}^{(\text{dp,pp})} \right\|_2^2 $$


Step 4: Compute Per-Group L2 Norm Scalar

Take the square root of the total to obtain the global L2 norm for group $k$:

$$ \text{norm}_{k_i} = \sqrt{s_{k_i}^{(\text{global})}} = \left( \sum_{\text{dp}=1}^{N_{dp}} \sum_{\text{pp}=1}^{N_{pp}} \sum_{j} \left\| \tilde{g}_{k_i,j}^{(\text{dp,pp})} \right\|_2^2 \right)^{1/2} $$

If the result is $\pm\infty$ or NaN, it is replaced with $-1$.
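The sentinel behavior can be expressed as a small standalone helper (illustrative, not the actual DeepSpeed code; note that `x != x` is the standard NaN test, since NaN compares unequal to itself):

```python
import math

def sanitize_norm(total_norm: float) -> float:
    # Replace +/-inf or NaN with the -1.0 sentinel, mirroring the
    # check in the old implementation.
    if total_norm == math.inf or total_norm == -math.inf or total_norm != total_norm:
        return -1.0
    return total_norm
```

Downstream overflow logic treats the -1.0 sentinel as "skip this step".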


Step 5: Final Formula for scaled_global_norm

Stack all $\text{norm}_k$ values into a vector and compute its L2 norm using torch.linalg.vector_norm(..., ord=2), which is equivalent to taking the square root of the sum of squared group norms:

$$ \boxed{ \mathcal{G} = \left\| \begin{pmatrix} \text{norm}_0 \\ \text{norm}_1 \\ \vdots \\ \text{norm}_{K-1} \end{pmatrix} \right\|_2 = \sqrt{\sum_{k=0}^{K-1} \text{norm}_k^2} = \sqrt{\sum_{k=0}^{K-1} \sum_{\text{dp}=1}^{N_{dp}} \sum_{\text{pp}=1}^{N_{pp}} \sum_{j} \left\| \tilde{g}_{k,j}^{(\text{dp,pp})} \right\|_2^2} } $$

Here, $K$ is the total number of parameter groups (len(bit16_groups)), which may differ across PP ranks. The value $\mathcal{G}$ is the global L2 norm of all loss-scaled gradients across the entire parallel topology.


The original (pre-fix) implementation:

```python
def scaled_global_norm(self, norm_type=2):
    assert norm_type == 2, "only L2 norm supported"
    norm_groups = []
    for i, group in enumerate(self.bit16_groups):  # Note: len(self.bit16_groups) may differ across PP ranks
        if self.cpu_offload:
            norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])
            norm_groups.append(norm)
        else:
            norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))

    if self.has_moe_layers:
        self._average_expert_grad_norms(norm_groups)

    # Step 5
    return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)


def complete_grad_norm_calculation_for_cpu_offload(self, params):
    total_norm = 0.0
    norm_type = 2.0
    for p in params:
        # Pipeline parallelism may replicate parameters. Avoid double-counting.
        if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
            continue

        if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
            param_id = self.get_param_id(p)
            # Step 1
            if param_id in self.norm_for_param_grads:
                param_norm = self.norm_for_param_grads[param_id]
                total_norm += param_norm.item()**2
            else:
                ...

    # Step 2
    total_dev_norm = get_accelerator().FloatTensor([float(total_norm)])
    dist.all_reduce(total_dev_norm, op=dist.ReduceOp.SUM, group=self.dp_process_group)

    # Step 3
    self._model_parallel_all_reduce(tensor=total_dev_norm, op=dist.ReduceOp.SUM)

    # Step 4
    total_norm = total_dev_norm[0].item()**(1. / norm_type)

    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = -1.0

    return torch.tensor(total_norm, device=self.device, dtype=torch.float)
```

Step-by-Step Breakdown of the Fixed Implementation


Step 1: Local Squared Norm per Group

complete_grad_norm_calculation_for_cpu_offload computes, for group $k_i$, the local sum of squared L2 norms on the current rank:

$$ s_{k_i}^{(\text{local})} = \sum_{j} \left\| \tilde{g}_{k_i,j} \right\|_2^2 $$

No inter-rank communication occurs here.


Step 2: Accumulate Local Squared Norms Across All Groups

In scaled_global_norm, sum the local squared norms from all $K$ groups:

$$ S^{(\text{local})} = \sum_{k=0}^{K-1} s_{k_i}^{(\text{local})} $$

Unlike the old version, no cross-process communication is performed at this stage—only local accumulation.


Step 3: MoE Correction (Optional)

If MoE layers exist, _average_expert_grad_norms performs an in-place average of expert-group norms across Expert Parallel (EP) process groups. The total is then recomputed:

$$ S^{(\text{local})} \leftarrow \sum_{k=0}^{K-1} \hat{s}_{k_i}^{(\text{local})} $$

where $\hat{s}_k^{(\text{local})}$ denotes the corrected squared norm for group $k$.
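The correction in Step 3 can be sketched with a toy model. The helper name and the averaging-by-division are illustrative simplifications of the real expert-parallel all-reduce; the point is that the per-group structure must be preserved so only expert groups are rescaled.

```python
def average_expert_sq_norms(sq_norms, is_moe_group, ep_world_size):
    """Toy sketch: average the squared norms of expert groups over the
    expert-parallel world size, leaving non-expert groups untouched."""
    return [
        sq / ep_world_size if is_moe else sq
        for sq, is_moe in zip(sq_norms, is_moe_group)
    ]

sq_norms = [4.0, 9.0, 16.0]          # per-group local squared norms
is_moe_group = [False, True, False]  # group 1 holds expert parameters
corrected = average_expert_sq_norms(sq_norms, is_moe_group, ep_world_size=3)
total = sum(corrected)               # recomputed S^(local) after correction
print(corrected)  # [4.0, 3.0, 16.0]
print(total)      # 23.0
```

Passing a single pre-summed scalar instead of the per-group list (the bug Codex flagged) would make it impossible to rescale only the expert groups.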


Step 4: Single Global All-Reduce

Instead of performing $K$ separate reductions (as in the old version), the new implementation reduces the single accumulated scalar once across both DP and MP dimensions:

$$ S^{(\text{global})} = \sum_{\text{dp}=1}^{N_{dp}} \sum_{\text{pp}=1}^{N_{pp}} S^{(\text{local, dp, pp})} $$

Expanded:

$$ S^{(\text{global})} = \sum_{k=0}^{K-1} \sum_{\text{dp}=1}^{N_{dp}} \sum_{\text{pp}=1}^{N_{pp}} \sum_{j} \left\| \tilde{g}_{k,j}^{(\text{dp,pp})} \right\|_2^2 $$


Step 5: Final Result

Take the square root of the global squared norm:

$$ \boxed{ \mathcal{G} = \sqrt{S^{(\text{global})}} = \sqrt{\sum_{k=0}^{K-1} \sum_{\text{dp}=1}^{N_{dp}} \sum_{\text{pp}=1}^{N_{pp}} \sum_{j} \left\| \tilde{g}_{k,j}^{(\text{dp,pp})} \right\|_2^2} } $$

If the result is $\pm\infty$ or NaN, mask_nan_or_inf_with_val_inplace replaces it with a safe value (typically $0$ or $-1$).

The fixed implementation (with the accumulation dtype switched to torch.float32 and the squared param norms moved to the accumulator's device, per the review comments above):

```python
def scaled_global_norm(self, norm_type=2):
    assert norm_type == 2, "only L2 norm supported"
    group_sq_norms = []
    # Accumulate in float32: accumulating squared norms in fp16/bf16 is not safe.
    local_total_sq_norm = torch.zeros(1, device=self.device, dtype=torch.float32)
    for i, _ in enumerate(self.bit16_groups):
        if self.cpu_offload:
            group_sq_norm = self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])
        else:
            group_sq_norm = self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])
        group_sq_norms.append(group_sq_norm)
        # Step 2
        local_total_sq_norm += group_sq_norm

    if self.has_moe_layers:
        # _average_expert_grad_norms expects an indexable collection of per-group
        # norms and updates them in-place for MoE groups. Pass the list instead of
        # the single accumulated tensor so expert and non-expert groups are scaled
        # correctly prior to global reduction.
        self._average_expert_grad_norms(group_sq_norms)

        # Recompute the total from possibly-updated per-group norms to reflect
        # any MoE-specific averaging that occurred.
        local_total_sq_norm = torch.zeros(1, device=self.device, dtype=torch.float32)
        # Step 3
        for g in group_sq_norms:
            # ensure device/dtype compatibility when summing
            local_total_sq_norm += g.to(local_total_sq_norm.device)

    # Step 4: use the accelerator abstraction, not a hard-coded CUDA device
    local_total_sq_norm = local_total_sq_norm.to(get_accelerator().current_device_name())
    dist.all_reduce(
        local_total_sq_norm,
        op=dist.ReduceOp.SUM,
        group=self.dp_process_group,
    )
    self._model_parallel_all_reduce(
        tensor=local_total_sq_norm,
        op=dist.ReduceOp.SUM,
    )

    # Step 5
    total_norm = torch.sqrt(local_total_sq_norm)
    mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
    return total_norm


def complete_grad_norm_calculation_for_cpu_offload(self, params):
    """
    Compute the local squared L2 norm of gradients for CPU-offloaded parameters.
    No cross-rank communication is performed here.
    """
    local_sq_norm = torch.zeros(1, device=self.device, dtype=torch.float32)
    for p in params:
        # Pipeline parallelism may replicate parameters. Avoid double-counting.
        if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
            continue

        if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
            param_id = self.get_param_id(p)
            # Step 1
            if param_id in self.norm_for_param_grads:
                # With offloading, param_norm may live on the accelerator while
                # local_sq_norm is on CPU; move it before accumulating.
                param_norm = self.norm_for_param_grads[param_id]
                local_sq_norm += param_norm.to(local_sq_norm.device)**2
            else:
                # Unused parameters may be unexpected; raise an explicit error
                # unless the user opted in to ignoring them.
                assert self.ignore_unused_parameters, """
                    This assert indicates that your module has parameters that
                    were not used in producing loss.
                    You can avoid this assert by
                    (1) enable ignore_unused_parameters option in zero_optimization config;
                    (2) making sure all trainable parameters and `forward` function
                        outputs participate in calculating loss.
                """
    return local_sq_norm
```

Key Differences Between Old and New Implementations

| Aspect | Old Version | New Version |
| --- | --- | --- |
| Return value of complete_grad_norm_calculation_for_cpu_offload | $\sqrt{s_k}$ (L2 norm) | $s_k$ (squared L2 norm) |
| Per-group communication | all_reduce inside the function for each group | None; deferred to a single final global reduction |
| Number of all_reduce calls | $K$ (risk of deadlock if PP ranks have different group counts) | 1 (after summing all groups) |
| Final square root | Applied to the vector of group norms | Applied to the global scalar sum |

The mathematical result is identical, but the new version **eliminates the risk of communication deadlock caused by inconsistent all_reduce counts** and significantly reduces communication overhead by cutting the number of collective operations from $K$ to 1.
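The claimed equivalence is easy to check numerically with plain Python (the gradient values are arbitrary toy data):

```python
import math

# Old formulation: per-group L2 norms, then the L2 norm of the stacked
# group norms. New formulation: one square root over the summed squares.
group_grads = [
    [0.5, -1.5, 2.0],    # group 0
    [3.0, 0.25],         # group 1
    [-0.75, 1.0, 1.25],  # group 2
]

group_norms = [math.sqrt(sum(g * g for g in grads)) for grads in group_grads]
old_result = math.sqrt(sum(n * n for n in group_norms))

new_result = math.sqrt(sum(g * g for grads in group_grads for g in grads))

print(abs(old_result - new_result) < 1e-12)  # True
```

Both reduce to the square root of the sum of every squared gradient element; only the communication pattern differs.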
