34 changes: 31 additions & 3 deletions pithtrain/modules/distributed.py
@@ -5,7 +5,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Generator
from typing import Generator, Literal

import torch

@@ -49,6 +49,34 @@ class DistributedCfg(SlottedDefault):
processing 2 experts. The number of experts should be divisible by the expert parallel size.
"""

nccl_timeout_seconds: int = 180
"""
Timeout for NCCL collective operations and watchdog heartbeat, in seconds.

Bounds how long a rank waits on a hung collective before NCCL's watchdog
aborts the job. Applies both to the per-collective timeout passed to
``torch.distributed.init_process_group`` and to ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC``.

Scale up for large multi-node or deep-pipeline runs where first-iteration
setup (NCCL channel build, FSDP all-gather, checkpoint load) can be slow.
Keep small for single-node runs to fail fast.
"""

sharding_strategy: Literal["fsdp", "hsdp"] = "fsdp"
"""
Sharding strategy for FSDP2.

- ``"fsdp"`` (default): Fully Sharded Data Parallel. Shards weights, gradients,
and optimizer states across the full data-parallel mesh (``dp x cp x ep``
for non-MoE parameters; ``dp x cp`` for MoE expert weights). Maximizes
memory savings.

- ``"hsdp"``: Hybrid Sharded Data Parallel. Shards within the inner mesh
dimension (``cp x ep`` for non-MoE; ``cp`` for MoE) and replicates across
the ``dp`` dimension. Lower memory savings than FSDP, but uses a smaller
FSDP group. Useful when the model fits comfortably in a single DP replica.
"""


@dataclass(init=False, slots=True)
class DistributedCtx:
@@ -110,11 +138,11 @@ def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> Non
ctx.local_rank = int(os.environ["LOCAL_RANK"])
ctx.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

shutdown.set_heartbeat_timeout(cfg.nccl_timeout_seconds)
kwargs = dict()
kwargs["backend"] = "nccl"
kwargs["device_id"] = ctx.local_rank
# NCCL default per-collective timeout is 30 min; cap at 3 to bound worst-case hangs.
kwargs["timeout"] = timedelta(seconds=180)
kwargs["timeout"] = timedelta(seconds=cfg.nccl_timeout_seconds)
torch.distributed.init_process_group(**kwargs)
# See pithtrain.modules.shutdown for why os._exit(1), not destroy/abort.
shutdown.install_failfast_excepthook()
6 changes: 5 additions & 1 deletion pithtrain/modules/shutdown.py
@@ -23,10 +23,14 @@ def set_env_defaults() -> None:
"""NCCL env defaults that bound failure detection. Must run before init."""
os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "0")
os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "180")
os.environ.setdefault("TORCH_NCCL_DUMP_ON_TIMEOUT", "1")


def set_heartbeat_timeout(seconds: int) -> None:
"""Set the NCCL watchdog heartbeat timeout. Must run before init_process_group."""
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(seconds)


def install_failfast_excepthook() -> None:
"""On uncaught exception: print, flush, os._exit(1) -- no abort/destroy."""
prev = sys.excepthook
41 changes: 31 additions & 10 deletions pithtrain/modules/training.py
@@ -12,6 +12,7 @@
import torch
import torch.distributed.fsdp
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR
@@ -25,7 +26,7 @@
from pithtrain.modules.dataset import ConcatDataset, MemmapDataset
from pithtrain.modules.load_balance import make_load_balance_loss_fn

from .distributed import DistributedCtx
from .distributed import DistributedCfg, DistributedCtx


@dataclass(init=False, slots=True)
@@ -228,12 +229,27 @@ def init_weights(model: nn.Module, num_layers: int, init_std: float = 0.02) -> N
torch.nn.init.normal_(param, mean=0.0, std=init_std)


def apply_fsdp(model, mesh: torch.distributed.DeviceMesh):
# MoE parameters are sharded by EP. We additionally shard on the DP and CP dimension.
# CP ranks hold identical parameters, so they participate in FSDP like DP.
# For other parameters, we shard on the both CP, DP and EP dimensions.
moe_fsdp_mesh = mesh["dp", "cp"]._flatten()
other_fsdp_mesh = mesh["dp", "cp", "ep"]._flatten()
def apply_fsdp(
model,
mesh: DeviceMesh,
sharding_strategy: Literal["fsdp", "hsdp"] = "fsdp",
):
# MoE params: unique per EP rank, replicated across DP x CP.
# Non-MoE params: replicated across DP x CP x EP.
# FSDP shards along the replicated dims:
# "fsdp": 1D mesh; FSDP2 shards across all participants.
# "hsdp": 2D mesh; FSDP2 shards along the inner dim and replicates
# along the outer (dp) dim. For non-MoE, cp and ep are folded
# into a single inner shard dim via _concatenate.
if sharding_strategy == "fsdp":
moe_fsdp_mesh = mesh["dp", "cp"]._flatten()
other_fsdp_mesh = mesh["dp", "cp", "ep"]._flatten()
elif sharding_strategy == "hsdp":
moe_fsdp_mesh = mesh["dp", "cp"]
cp_ep_mesh = mesh["cp", "ep"]._flatten("cp_ep")
other_fsdp_mesh = DeviceMesh._concatenate([mesh["dp"], cp_ep_mesh])
else:
raise ValueError(f"Unknown sharding_strategy: {sharding_strategy!r}")
mp = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
@@ -280,7 +296,12 @@ def apply_fsdp(model, mesh: torch.distributed.DeviceMesh):
return model
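
As a usage sketch (not part of this diff): the mesh is built with ``init_device_mesh`` using the dimension names this function indexes; the (4, 2, 2) shape is hypothetical and assumes 16 ranks with the process group already initialized, and the two-layer model is a stand-in for the real module list built in ``setup_model``:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 16-GPU layout: 4-way data parallel, 2-way context parallel,
# 2-way expert parallel.
mesh = init_device_mesh("cuda", (4, 2, 2), mesh_dim_names=("dp", "cp", "ep"))

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))  # stand-in model
model = apply_fsdp(model, mesh, sharding_strategy="hsdp")
# "fsdp": parameters are sharded across all 16 ranks (flattened 1D mesh).
# "hsdp": non-MoE parameters shard across the 4 ranks of each cp x ep group
#         and replicate across the 4 dp groups.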


def setup_model(cfg: TrainingCfg, ctx: TrainingCtx, distributed: DistributedCtx) -> None:
def setup_model(
cfg: TrainingCfg,
ctx: TrainingCtx,
distributed_cfg: DistributedCfg,
distributed: DistributedCtx,
) -> None:
from pithtrain.dualpipe.utils import FP8WeightCacheControl
from pithtrain.layers.factory import ModelImplMode

@@ -352,7 +373,7 @@ def setup_model(cfg: TrainingCfg, ctx: TrainingCtx, distributed: DistributedCtx)
init_weights(module, num_layers, cfg.init_std)

modules = nn.Sequential(*modules)
apply_fsdp(modules, device_mesh)
apply_fsdp(modules, device_mesh, distributed_cfg.sharding_strategy)

local_seq_len = cfg.sequence_length // cp_size
# sequence_length = cfg.sequence_length, TODO this is kept here for stripe context parallelism
@@ -414,7 +435,7 @@ def training_context(cfg: object, ctx: object) -> Generator[TrainingCtx, None, N
np.random.seed(cfg.training.seed)
torch.manual_seed(cfg.training.seed)
torch.cuda.manual_seed_all(cfg.training.seed)
setup_model(cfg.training, ctx.training, ctx.distributed)
setup_model(cfg.training, ctx.training, cfg.distributed, ctx.distributed)
setup_optimizer(cfg.training, ctx.training)
setup_scheduler(cfg.training, ctx.training)
try: