Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions deepspeed/runtime/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage
from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized
from deepspeed.runtime.torch_autocast import get_comm_dtype, has_comm_dtype
from deepspeed.runtime.utils import maybe_loss_for_backward


Expand Down Expand Up @@ -354,7 +354,10 @@ def report_ipg_memory_usage(self, tag, param_elems, dtype=None):
)

def get_param_comm_dtype(self, param):
    """Return the dtype to use when communicating gradients for *param*.

    Prefers the per-parameter ``comm_dtype`` attribute stamped by
    ``init_autocast_params()``. Each engine stamps its own parameters, so
    multiple engines with different autocast configs are naturally isolated
    without a shared global state. Falls back to the optimizer-wide
    ``self.communication_data_type`` when the parameter was never stamped.
    """
    # Diff residue removed: the stale `if is_autocast_initialized():` line
    # (deleted in this change) was still present above the new condition.
    if has_comm_dtype(param):
        return get_comm_dtype(param)
    return self.communication_data_type
Expand Down
37 changes: 25 additions & 12 deletions deepspeed/runtime/torch_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

# DeepSpeed Team

from typing import Iterable, Set, List, Union
from dataclasses import dataclass, field
from typing import Iterable, Optional, Set, List, Union
import importlib
from contextlib import contextmanager

Expand All @@ -22,9 +23,17 @@
PARAM_COMM_DTYPE_ATTR_NAME = "comm_dtype"
_WARNED_NESTED_AUTOCAST = False

# TODO: Avoid using global variables
TORCH_AUTOCAST_INITIALIZED = False
TORCH_AUTOCAST_DTYPE = None

@dataclass
class _AutocastState:
"""Holds torch-autocast initialization state for one DeepSpeed engine instance.

Storing this object on the engine (``engine._autocast_state``) rather than as
a module-level singleton allows multiple engine instances to carry independent
autocast configurations without interfering with each other.
"""
initialized: bool = False
dtype: Optional[torch.dtype] = field(default=None)


def _validate_auto_cast_settings(engine):
Expand Down Expand Up @@ -56,22 +65,26 @@ def init_autocast_params(engine, dtype: torch.dtype,
for p in module.parameters(recurse=False):
setattr(p, PARAM_COMM_DTYPE_ATTR_NAME, dtype)

global TORCH_AUTOCAST_INITIALIZED
TORCH_AUTOCAST_INITIALIZED = True
global TORCH_AUTOCAST_DTYPE
TORCH_AUTOCAST_DTYPE = dtype
engine._autocast_state = _AutocastState(initialized=True, dtype=dtype)


def is_autocast_initialized(engine) -> bool:
    """Return True if torch autocast was initialized for *this* engine instance.

    Accepts the engine as an argument so that multiple DeepSpeed engines can
    carry independent ``_autocast_state`` objects without sharing a
    module-level singleton. An engine that was never initialized has no
    ``_autocast_state`` attribute; the default state reports False.
    """
    # Diff residue removed: the superseded zero-argument definition that
    # returned the global TORCH_AUTOCAST_INITIALIZED flag was still present
    # above this one.
    return getattr(engine, '_autocast_state', _AutocastState()).initialized


def get_default_autocast_lower_precision_modules() -> List[str]:
    """Return fully-qualified class names of the default lower-precision-safe modules."""
    return [".".join((module_cls.__module__, module_cls.__name__)) for module_cls in LOWER_PRECISION_SAFE_MODULES]


def get_autocast_dtype(engine) -> Optional[torch.dtype]:
    """Return the autocast dtype configured for *this* engine instance.

    Returns None when autocast was never initialized for the engine (the
    default ``_AutocastState`` carries ``dtype=None``), hence the Optional
    return annotation.
    """
    # Diff residue removed: the superseded zero-argument definition that
    # returned the global TORCH_AUTOCAST_DTYPE was still present above.
    return getattr(engine, '_autocast_state', _AutocastState()).dtype


def has_comm_dtype(param: torch.nn.Parameter) -> bool:
Expand Down
9 changes: 5 additions & 4 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from deepspeed.utils import logger
from deepspeed.utils.torch import register_grad_hook, required_torch_version
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
from deepspeed.runtime.torch_autocast import get_all_comm_dtypes, has_comm_dtype, sort_dtypes
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward
from deepspeed.runtime.zero.partition_parameters import *
Expand Down Expand Up @@ -435,9 +435,10 @@ def _enforce_optimizer_offload():
self.is_param_in_current_partition = {}

self.torch_autocast_gradscaler = None
if is_autocast_initialized():
comm_dtypes = get_all_comm_dtypes([p for params in self.fp16_groups for p in params])
if get_autocast_dtype() == torch.float16:
all_params = [p for params in self.fp16_groups for p in params]
if any(has_comm_dtype(p) for p in all_params):
comm_dtypes = get_all_comm_dtypes(all_params)
if torch.float16 in comm_dtypes:
self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())
else:
comm_dtypes = {self.communication_data_type}
Expand Down
9 changes: 5 additions & 4 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from deepspeed.runtime.zero.offload_states import offload_optimizer_states, reload_optimizer_states
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
from deepspeed.runtime.torch_autocast import get_all_comm_dtypes, has_comm_dtype, sort_dtypes
from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace,
count_used_parameters_in_backward)
Expand Down Expand Up @@ -518,9 +518,10 @@ def _enforce_cpu_offload():
self.is_param_in_current_partition = {}

self.torch_autocast_gradscaler = None
if is_autocast_initialized():
comm_dtypes = get_all_comm_dtypes([p for params in self.bit16_groups for p in params])
if get_autocast_dtype() == torch.float16:
all_params = [p for params in self.bit16_groups for p in params]
if any(has_comm_dtype(p) for p in all_params):
comm_dtypes = get_all_comm_dtypes(all_params)
if torch.float16 in comm_dtypes:
self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())
else:
comm_dtypes = {self.communication_data_type}
Expand Down
Loading