Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def __init__(self,

self.group_paddings = []
self.graph_harvesting = graph_harvesting

# Parameters marked ds_fp32_pinned bypass the bf16 path and remain fp32 throughout
# training. They are tracked separately so flattening and dtype checks do not fail.
self.fp32_pinned_groups = []

if self.using_real_optimizer:
self._setup_for_real_optimizer()

Expand Down Expand Up @@ -143,7 +148,13 @@ def _setup_for_real_optimizer(self):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])

# grab the original list
trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
all_trainable = [param for param in param_group['params'] if param.requires_grad]

# Separate params that must stay in fp32 from those that go through the bf16 path.
fp32_pinned_params = [p for p in all_trainable if getattr(p, 'ds_fp32_pinned', False)]
trainable_parameters = [p for p in all_trainable if not getattr(p, 'ds_fp32_pinned', False)]

self.fp32_pinned_groups.append(fp32_pinned_params)
self.bf16_groups.append(trainable_parameters)

# create flat bf16 params
Expand Down Expand Up @@ -207,6 +218,14 @@ def _setup_for_real_optimizer(self):

see_memory_usage(f'after initializing group {i}', force=True)

# Add all fp32-pinned params as an additional group in the base optimizer so they
# receive fp32 optimizer-state updates alongside the regular groups.
all_fp32_pinned = [p for grp in self.fp32_pinned_groups for p in grp]
if all_fp32_pinned:
self.optimizer.param_groups.append({'params': all_fp32_pinned, '_fp32_pinned': True})
# Keep real_dp_process_group list aligned with optimizer param_groups length.
self.real_dp_process_group.append(self.real_dp_process_group[0])

self._grad_acc_hooks = []
if self.immediate_grad_update:
self.create_grad_acc_hooks()
Expand Down Expand Up @@ -391,7 +410,10 @@ def get_grads_for_norm(self, for_clipping=False):
all_grads_for_clip = []

tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
assert len(self.bf16_groups) == len(self.optimizer.param_groups)
# fp32_pinned params are added to the base optimizer as an extra group; exclude that
# extra group from the assertion so the count still matches the bf16 groups.
num_bf16_groups = len(self.bf16_groups)
assert num_bf16_groups <= len(self.optimizer.param_groups)
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if not for_clipping:
Expand All @@ -417,6 +439,16 @@ def get_grads_for_norm(self, for_clipping=False):
non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
else:
all_grads_for_clip.append(self.fp32_groups_gradients[i][j])

# Include gradients from fp32-pinned params (already fp32, no hp/lp split needed).
for grp in self.fp32_pinned_groups:
for p in grp:
if p.grad is not None:
if not for_clipping:
non_expert_grads_for_norm.append(p.grad)
else:
all_grads_for_clip.append(p.grad)

if not for_clipping:
return non_expert_grads_for_norm, expert_grads_for_norm
return all_grads_for_clip
Expand Down Expand Up @@ -448,7 +480,8 @@ def clear_lp_grads(self, set_to_none=False):
assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"

zero_grads_list = []
for group in self.bf16_groups:
all_lp_groups = list(self.bf16_groups) + list(self.fp32_pinned_groups)
for group in all_lp_groups:
for param in group:
if set_to_none:
param.grad = None
Expand Down
6 changes: 6 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,12 @@ def _configure_bf16_optimizer(self, optimizer):

log_dist('Creating BF16 optimizer', ranks=[0])

fp32_pinned = self._config.zero_config.fp32_pinned_parameters if self.zero_optimization() else []
if fp32_pinned:
log_dist(
f'BF16 optimizer: the following parameter name patterns will be kept in FP32: {fp32_pinned}',
ranks=[0])

timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
optimizer = BF16_Optimizer(optimizer,
self.param_names,
Expand Down
18 changes: 17 additions & 1 deletion deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# DeepSpeed Team

import sys
from typing import Optional, Dict, Any
from typing import List, Optional, Dict, Any
from enum import Enum
from pydantic import Field, model_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
Expand Down Expand Up @@ -346,6 +346,22 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""
Use memory efficient linear implementation, for Stage 3.
"""

fp32_pinned_parameters: List[str] = Field(default_factory=list)
"""
List of parameter name patterns (sub-strings) whose data should be kept in
FP32 even when the engine is configured with bfloat16 or fp16. Useful for
MoE router weights and other precision-sensitive parameters that must not be
cast to lower precision.

Example: ``"fp32_pinned_parameters": ["router.weight", "gate"]``

Any parameter whose fully-qualified name contains at least one of the listed
sub-strings will have ``param.ds_fp32_pinned = True`` set after model
initialization, and its data will be kept in (or re-cast to) ``torch.float32``.
The bf16/fp16 optimizer will place these parameters in a separate fp32 param
group so they are updated without mixed-precision loss.
"""
"""
Whether force load checkpoint in pipeline mode, current only for Stage 3.
"""
Expand Down
20 changes: 20 additions & 0 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,11 @@ def __init__(self,
if _ds_config is not None and _ds_config.zero_config.zero_quantized_nontrainable_weights and not self.quantized_nontrainable_weights:
self.quantized_nontrainable_weights = _ds_config.zero_config.zero_quantized_nontrainable_weights

# List of parameter name patterns that must remain in fp32 even when bf16/fp16 is enabled.
self.fp32_pinned_patterns: list = []
if _ds_config is not None:
self.fp32_pinned_patterns = _ds_config.zero_config.fp32_pinned_parameters

self.module = module
if (self.quantized_weights or self.quantized_nontrainable_weights):
self.quantizer_module = CUDAQuantizer()
Expand Down Expand Up @@ -1128,6 +1133,15 @@ def _convert_to_zero_parameters(self, param_list):
continue

param.data = param.data.to(self.local_device)

# If this parameter is pinned to fp32 (e.g. a MoE router weight),
# re-cast its data to float32 after it has been moved to device.
# The tensor-creation wrappers active inside the Init context may
# have already cast it to bf16/fp16; undoing that cast here ensures
# the parameter retains full precision when ds_fp32_pinned is set.
if getattr(param, "ds_fp32_pinned", False):
param.data = param.data.to(torch.float32)

self._zero_init_param(param)

def _validate_remote_device(self, remote_device, ds_config):
Expand Down Expand Up @@ -1158,6 +1172,12 @@ def _post_init_method(self, module):
if not get_accelerator().on_accelerator(param):
param.data = param.data.to(self.local_device)

# Mark parameters whose names match any fp32_pinned_parameters pattern so
# that downstream code (partition, optimizer) can skip dtype downcasting.
if any(pattern in name for pattern in self.fp32_pinned_patterns):
param.ds_fp32_pinned = True
param.data = param.data.to(torch.float32)

if name == 'weight' and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS:
_quantize_param(param, self.quantized_initialization)

Expand Down