From 254d93e8d8db95c6bf195a40938ce68ac708a640 Mon Sep 17 00:00:00 2001 From: haok1402 Date: Fri, 15 May 2026 23:39:50 -0400 Subject: [PATCH] add support for hsdp and configurable nccl timeout --- pithtrain/modules/distributed.py | 34 +++++++++++++++++++++++--- pithtrain/modules/shutdown.py | 6 ++++- pithtrain/modules/training.py | 41 ++++++++++++++++++++++++-------- 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/pithtrain/modules/distributed.py b/pithtrain/modules/distributed.py index 81af575..17393a2 100644 --- a/pithtrain/modules/distributed.py +++ b/pithtrain/modules/distributed.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Generator +from typing import Generator, Literal import torch @@ -49,6 +49,34 @@ class DistributedCfg(SlottedDefault): processing 2 experts. The number of experts should be divisible by the expert parallel size. """ + nccl_timeout_seconds: int = 180 + """ + Timeout for NCCL collective operations and watchdog heartbeat, in seconds. + + Bounds how long a rank waits on a hung collective before NCCL's watchdog + aborts the job. Applies both to the per-collective timeout passed to + ``torch.distributed.init_process_group`` and to ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC``. + + Scale up for large multi-node or deep-pipeline runs where first-iteration + setup (NCCL channel build, FSDP all-gather, checkpoint load) can be slow. + Keep small for single-node runs to fail fast. + """ + + sharding_strategy: Literal["fsdp", "hsdp"] = "fsdp" + """ + Sharding strategy for FSDP2. + + - ``"fsdp"`` (default): Fully Sharded Data Parallel. Shards weights, gradients, + and optimizer states across the full data-parallel mesh (``dp x cp x ep`` + for non-MoE parameters; ``dp x cp`` for MoE expert weights). Maximizes + memory savings. + + - ``"hsdp"``: Hybrid Sharded Data Parallel. Shards within the inner mesh + dimension (``cp x ep`` for non-MoE; ``cp`` for MoE) and replicates across + the ``dp`` dimension. Lower memory savings than FSDP, but uses a smaller + FSDP group. Useful when the model fits comfortably in a single DP replica. + """ + @dataclass(init=False, slots=True) class DistributedCtx: @@ -110,11 +138,11 @@ def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> Non ctx.local_rank = int(os.environ["LOCAL_RANK"]) ctx.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + shutdown.set_heartbeat_timeout(cfg.nccl_timeout_seconds) kwargs = dict() kwargs["backend"] = "nccl" kwargs["device_id"] = ctx.local_rank - # NCCL default per-collective timeout is 30 min; cap at 3 to bound worst-case hangs. - kwargs["timeout"] = timedelta(seconds=180) + kwargs["timeout"] = timedelta(seconds=cfg.nccl_timeout_seconds) torch.distributed.init_process_group(**kwargs) # See pithtrain.modules.shutdown for why os._exit(1), not destroy/abort. shutdown.install_failfast_excepthook() diff --git a/pithtrain/modules/shutdown.py b/pithtrain/modules/shutdown.py index f6bd015..f8919f2 100644 --- a/pithtrain/modules/shutdown.py +++ b/pithtrain/modules/shutdown.py @@ -23,10 +23,14 @@ def set_env_defaults() -> None: """NCCL env defaults that bound failure detection. 
Must run before init.""" os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "0") - os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "180") os.environ.setdefault("TORCH_NCCL_DUMP_ON_TIMEOUT", "1") +def set_heartbeat_timeout(seconds: int) -> None: + """Set the NCCL watchdog heartbeat timeout. Must run before init_process_group.""" + os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(seconds) + + def install_failfast_excepthook() -> None: """On uncaught exception: print, flush, os._exit(1) -- no abort/destroy.""" prev = sys.excepthook diff --git a/pithtrain/modules/training.py b/pithtrain/modules/training.py index 8448fd1..2370dee 100644 --- a/pithtrain/modules/training.py +++ b/pithtrain/modules/training.py @@ -12,6 +12,7 @@ import torch import torch.distributed.fsdp import torch.nn as nn +from torch.distributed import DeviceMesh from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR @@ -25,7 +26,7 @@ from pithtrain.modules.dataset import ConcatDataset, MemmapDataset from pithtrain.modules.load_balance import make_load_balance_loss_fn -from .distributed import DistributedCtx +from .distributed import DistributedCfg, DistributedCtx @dataclass(init=False, slots=True) @@ -228,12 +229,27 @@ def init_weights(model: nn.Module, num_layers: int, init_std: float = 0.02) -> N torch.nn.init.normal_(param, mean=0.0, std=init_std) -def apply_fsdp(model, mesh: torch.distributed.DeviceMesh): - # MoE parameters are sharded by EP. We additionally shard on the DP and CP dimension. - # CP ranks hold identical parameters, so they participate in FSDP like DP. - # For other parameters, we shard on the both CP, DP and EP dimensions. - moe_fsdp_mesh = mesh["dp", "cp"]._flatten() - other_fsdp_mesh = mesh["dp", "cp", "ep"]._flatten() +def apply_fsdp( + model, + mesh: DeviceMesh, + sharding_strategy: Literal["fsdp", "hsdp"] = "fsdp", +): + # MoE params: unique per EP rank, replicated across DP x CP. + # Non-MoE params: replicated across DP x CP x EP. + # FSDP shards along the replicated dims: + # "fsdp": 1D mesh; FSDP2 shards across all participants. + # "hsdp": 2D mesh; FSDP2 shards along the inner dim and replicates + # along the outer (dp) dim. For non-MoE, cp and ep are folded + # into a single inner shard dim via _concatenate. 
+ if sharding_strategy == "fsdp": + moe_fsdp_mesh = mesh["dp", "cp"]._flatten() + other_fsdp_mesh = mesh["dp", "cp", "ep"]._flatten() + elif sharding_strategy == "hsdp": + moe_fsdp_mesh = mesh["dp", "cp"] + cp_ep_mesh = mesh["cp", "ep"]._flatten("cp_ep") + other_fsdp_mesh = DeviceMesh._concatenate([mesh["dp"], cp_ep_mesh]) + else: + raise ValueError(f"Unknown sharding_strategy: {sharding_strategy!r}") mp = MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.float32, @@ -280,7 +296,12 @@ def apply_fsdp(model, mesh: torch.distributed.DeviceMesh): return model -def setup_model(cfg: TrainingCfg, ctx: TrainingCtx, distributed: DistributedCtx) -> None: +def setup_model( + cfg: TrainingCfg, + ctx: TrainingCtx, + distributed_cfg: DistributedCfg, + distributed: DistributedCtx, +) -> None: from pithtrain.dualpipe.utils import FP8WeightCacheControl from pithtrain.layers.factory import ModelImplMode @@ -352,7 +373,7 @@ def setup_model(cfg: TrainingCfg, ctx: TrainingCtx, distributed: DistributedCtx) init_weights(module, num_layers, cfg.init_std) modules = nn.Sequential(*modules) - apply_fsdp(modules, device_mesh) + apply_fsdp(modules, device_mesh, distributed_cfg.sharding_strategy) local_seq_len = cfg.sequence_length // cp_size # sequence_length = cfg.sequence_length, TODO this is kept here for stripe context parallelism @@ -414,7 +435,7 @@ def training_context(cfg: object, ctx: object) -> Generator[TrainingCtx, None, N np.random.seed(cfg.training.seed) torch.manual_seed(cfg.training.seed) torch.cuda.manual_seed_all(cfg.training.seed) - setup_model(cfg.training, ctx.training, ctx.distributed) + setup_model(cfg.training, ctx.training, cfg.distributed, ctx.distributed) setup_optimizer(cfg.training, ctx.training) setup_scheduler(cfg.training, ctx.training) try:
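
Usage sketch (illustrative, not part of the patch): how the two new knobs might be set on a DistributedCfg. The config plumbing that constructs DistributedCfg instances is not shown in this diff, so the direct attribute assignment below is only an assumption about how the fields get populated.

    from pithtrain.modules.distributed import DistributedCfg

    cfg = DistributedCfg()
    # Large multi-node or deep-pipeline run: give slow first-iteration setup
    # (NCCL channel build, FSDP all-gather, checkpoint load) more headroom
    # before the watchdog fires.
    cfg.nccl_timeout_seconds = 600
    # Model fits comfortably in one DP replica: replicate across dp, shard
    # only within the inner cp x ep dimensions.
    cfg.sharding_strategy = "hsdp"

One behavioural note that follows from the diff: set_heartbeat_timeout writes TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC unconditionally, whereas set_env_defaults previously used setdefault, so after this change the config value takes precedence over a value exported in the environment.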
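
For reference on the mesh shapes apply_fsdp builds: FSDP2's fully_shard fully shards parameters when given a 1D mesh and switches to hybrid sharding (replicate along dim 0, shard along dim 1) when given a 2D mesh, which is what the "fsdp" and "hsdp" branches rely on. Below is a minimal standalone sketch, assuming a torchrun launch across 8 GPUs; the mesh sizes and dim names are illustrative and unrelated to the real dp/cp/ep configuration.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")

    # "fsdp"-style: a single flat dimension -> every rank holds a 1/8 shard of
    # each parameter (what the flattened dp x cp x ep mesh achieves).
    flat_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("shard",))
    fully_shard(nn.Linear(1024, 1024, device="cuda"), mesh=flat_mesh)

    # "hsdp"-style: a 2D mesh -> 2 replicas, each sharded 4 ways along the
    # inner dimension, mirroring the (dp, inner) meshes built for "hsdp".
    hybrid_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
    fully_shard(nn.Linear(1024, 1024, device="cuda"), mesh=hybrid_mesh)

    dist.destroy_process_group()

With the 2D mesh, gradient reduction becomes a reduce-scatter within each shard group followed by an all-reduce across the replicas, which keeps the sharding collectives within the smaller group that the sharding_strategy docstring mentions.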