From 808c88b91c2b4bba71f8f3e6ee5bcc4f4d16a1aa Mon Sep 17 00:00:00 2001 From: harshang03 Date: Sun, 22 Feb 2026 16:22:03 +0530 Subject: [PATCH] fix: keep fp32-pinned parameters out of the bf16 cast path in ZeRO-3 When bf16 is enabled with ZeRO stage 3, all model parameters were cast to bfloat16 inside the `Init` context via the global tensor-creation wrappers. This caused MoE router weights (and any other parameters that require full fp32 precision) to silently lose precision, leading to incorrect routing decisions and training instability. A new ZeRO config field `fp32_pinned_parameters` (list of name-pattern strings) lets users designate parameters that must remain in fp32: "zero_optimization": { "stage": 3, "fp32_pinned_parameters": ["router.weight", "gate."] } Changes: - `config.py`: add `fp32_pinned_parameters` field to `ZeroConfig` - `partition_parameters.py`: in `_post_init_method` and `_convert_to_zero_parameters`, mark matching params with `ds_fp32_pinned = True` and re-cast their data to float32 after the bf16 tensor-creation wrappers would have downcast them. - `bf16_optimizer.py`: in `_setup_for_real_optimizer`, separate fp32-pinned params from the normal bf16 groups; add them as a dedicated fp32 group in the base optimizer so their states are kept in fp32. Include their gradients in norm/clip computation (`get_grads_for_norm`) and clear them alongside bf16 params (`clear_lp_grads`). - `engine.py`: log the active fp32-pinned patterns at BF16 optimizer creation time for easier debugging. Fixes #7747 --- deepspeed/runtime/bf16_optimizer.py | 39 +++++++++++++++++-- deepspeed/runtime/engine.py | 6 +++ deepspeed/runtime/zero/config.py | 18 ++++++++- .../runtime/zero/partition_parameters.py | 20 ++++++++++ 4 files changed, 79 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index cc5f9959f57a..fda93fce87e8 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -102,6 +102,11 @@ def __init__(self, self.group_paddings = [] self.graph_harvesting = graph_harvesting + + # Parameters marked ds_fp32_pinned bypass the bf16 path and remain fp32 throughout + # training. They are tracked separately so flattening and dtype checks do not fail. + self.fp32_pinned_groups = [] + if self.using_real_optimizer: self._setup_for_real_optimizer() @@ -143,7 +148,13 @@ def _setup_for_real_optimizer(self): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) # grab the original list - trainable_parameters = [param for param in param_group['params'] if param.requires_grad] + all_trainable = [param for param in param_group['params'] if param.requires_grad] + + # Separate params that must stay in fp32 from those that go through the bf16 path. + fp32_pinned_params = [p for p in all_trainable if getattr(p, 'ds_fp32_pinned', False)] + trainable_parameters = [p for p in all_trainable if not getattr(p, 'ds_fp32_pinned', False)] + + self.fp32_pinned_groups.append(fp32_pinned_params) self.bf16_groups.append(trainable_parameters) # create flat bf16 params @@ -207,6 +218,14 @@ def _setup_for_real_optimizer(self): see_memory_usage(f'after initializing group {i}', force=True) + # Add all fp32-pinned params as an additional group in the base optimizer so they + # receive fp32 optimizer-state updates alongside the regular groups. + all_fp32_pinned = [p for grp in self.fp32_pinned_groups for p in grp] + if all_fp32_pinned: + self.optimizer.param_groups.append({'params': all_fp32_pinned, '_fp32_pinned': True}) + # Keep real_dp_process_group list aligned with optimizer param_groups length. + self.real_dp_process_group.append(self.real_dp_process_group[0]) + self._grad_acc_hooks = [] if self.immediate_grad_update: self.create_grad_acc_hooks() @@ -391,7 +410,10 @@ def get_grads_for_norm(self, for_clipping=False): all_grads_for_clip = [] tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - assert len(self.bf16_groups) == len(self.optimizer.param_groups) + # fp32_pinned params are added to the base optimizer as an extra group; exclude that + # extra group from the assertion so the count still matches the bf16 groups. + num_bf16_groups = len(self.bf16_groups) + assert num_bf16_groups <= len(self.optimizer.param_groups) for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if not for_clipping: @@ -417,6 +439,16 @@ def get_grads_for_norm(self, for_clipping=False): non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j]) else: all_grads_for_clip.append(self.fp32_groups_gradients[i][j]) + + # Include gradients from fp32-pinned params (already fp32, no hp/lp split needed). + for grp in self.fp32_pinned_groups: + for p in grp: + if p.grad is not None: + if not for_clipping: + non_expert_grads_for_norm.append(p.grad) + else: + all_grads_for_clip.append(p.grad) + if not for_clipping: return non_expert_grads_for_norm, expert_grads_for_norm return all_grads_for_clip @@ -448,7 +480,8 @@ def clear_lp_grads(self, set_to_none=False): assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None" zero_grads_list = [] - for group in self.bf16_groups: + all_lp_groups = list(self.bf16_groups) + list(self.fp32_pinned_groups) + for group in all_lp_groups: for param in group: if set_to_none: param.grad = None diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e6d838df5adf..a0eee5e27ab2 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1838,6 +1838,12 @@ def _configure_bf16_optimizer(self, optimizer): log_dist('Creating BF16 optimizer', ranks=[0]) + fp32_pinned = self._config.zero_config.fp32_pinned_parameters if self.zero_optimization() else [] + if fp32_pinned: + log_dist( + f'BF16 optimizer: the following parameter name patterns will be kept in FP32: {fp32_pinned}', + ranks=[0]) + timers = self.timers if self.wall_clock_breakdown() else NoopTimer() optimizer = BF16_Optimizer(optimizer, self.param_names, diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index def8d1db5653..f91371ff14dc 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -4,7 +4,7 @@ # DeepSpeed Team import sys -from typing import Optional, Dict, Any +from typing import List, Optional, Dict, Any from enum import Enum from pydantic import Field, model_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel @@ -346,6 +346,22 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): """ Use memory efficient linear implementation, for Stage 3. """ + + fp32_pinned_parameters: List[str] = Field(default_factory=list) + """ + List of parameter name patterns (sub-strings) whose data should be kept in + FP32 even when the engine is configured with bfloat16 or fp16. Useful for + MoE router weights and other precision-sensitive parameters that must not be + cast to lower precision. + + Example: ``"fp32_pinned_parameters": ["router.weight", "gate"]`` + + Any parameter whose fully-qualified name contains at least one of the listed + sub-strings will have ``param.ds_fp32_pinned = True`` set after model + initialisation, and its data will be kept (or re-cast) to ``torch.float32``. + The bf16/fp16 optimizer will place these parameters in a separate fp32 param + group so they are updated without mixed-precision loss. + """ """ Whether force load checkpoint in pipeline mode, current only for Stage 3. """ diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index de7b2c11aed8..72409b036a80 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1066,6 +1066,11 @@ def __init__(self, if _ds_config is not None and _ds_config.zero_config.zero_quantized_nontrainable_weights and not self.quantized_nontrainable_weights: self.quantized_nontrainable_weights = _ds_config.zero_config.zero_quantized_nontrainable_weights + # List of parameter name patterns that must remain in fp32 even when bf16/fp16 is enabled. + self.fp32_pinned_patterns: list = [] + if _ds_config is not None: + self.fp32_pinned_patterns = _ds_config.zero_config.fp32_pinned_parameters + self.module = module if (self.quantized_weights or self.quantized_nontrainable_weights): self.quantizer_module = CUDAQuantizer() @@ -1128,6 +1133,15 @@ def _convert_to_zero_parameters(self, param_list): continue param.data = param.data.to(self.local_device) + + # If this parameter is pinned to fp32 (e.g. a MoE router weight), + # re-cast its data to float32 after it has been moved to device. + # The tensor-creation wrappers active inside the Init context may + # have already cast it to bf16/fp16; undoing that cast here ensures + # the parameter retains full precision when ds_fp32_pinned is set. + if getattr(param, "ds_fp32_pinned", False): + param.data = param.data.to(torch.float32) + self._zero_init_param(param) def _validate_remote_device(self, remote_device, ds_config): @@ -1158,6 +1172,12 @@ def _post_init_method(self, module): if not get_accelerator().on_accelerator(param): param.data = param.data.to(self.local_device) + # Mark parameters whose names match any fp32_pinned_parameters pattern so + # that downstream code (partition, optimizer) can skip dtype downcasting. + if any(pattern in name for pattern in self.fp32_pinned_patterns): + param.ds_fp32_pinned = True + param.data = param.data.to(torch.float32) + if name == 'weight' and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS: _quantize_param(param, self.quantized_initialization)