5 changes: 4 additions & 1 deletion pithtrain/config.py
@@ -1,6 +1,7 @@
"""PithTrain base classes."""

from dataclasses import MISSING, asdict, fields
from datetime import timedelta
from pathlib import Path


@@ -28,9 +29,11 @@ def to_json_dict(self) -> dict:

@staticmethod
def _make_json_serializable(obj):
"""Recursively convert non-serializable types (e.g. Path) to strings."""
"""Recursively convert non-serializable types into JSON primitives."""
if isinstance(obj, dict):
return {k: SlottedDefault._make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, Path):
return str(obj)
elif isinstance(obj, timedelta):
return obj.total_seconds()
return obj
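
For reference, a standalone sketch of the conversion this hunk adds; the config dict below is hypothetical, and only the Path and timedelta handling mirrors the diff:

from datetime import timedelta
from pathlib import Path


def make_json_serializable(obj):
    """Recursively convert non-serializable types into JSON primitives."""
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, Path):
        return str(obj)
    elif isinstance(obj, timedelta):
        return obj.total_seconds()
    return obj


# Hypothetical config dict: Path becomes a string, timedelta becomes seconds (float).
cfg = {"save_dir": Path("/tmp/run"), "timeout": timedelta(minutes=15)}
print(make_json_serializable(cfg))  # {'save_dir': '/tmp/run', 'timeout': 900.0}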
179 changes: 94 additions & 85 deletions pithtrain/modules/distributed.py
@@ -2,6 +2,8 @@

import atexit
import os
import sys
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
@@ -10,125 +12,107 @@
import torch

from pithtrain.config import SlottedDefault
from pithtrain.modules import shutdown


@dataclass(init=False, slots=True)
class DistributedCfg(SlottedDefault):
"""Configuration for distributed runtime."""
"""
Configuration for distributed runtime.

pipeline_parallel_size: int = 1
Parallelism degrees (PP, CP, EP), FSDP2 sharding strategy, and operation timeout. DP is
inferred from the world size.
"""
Degree of pipeline parallelism.

Pipeline Parallelism (PP) is a technique that assigns consecutive layers or segments of a neural network
to different GPUs. This division allows each GPU to process different stages of the network sequentially.
pipeline_parallel_size: int = 1
"""
Degree of pipeline parallelism (PP).

For example, if a model has 12 layers and the pipeline_parallel_size is set to 4, then each GPU will
handle 3 layers.
Partition the model layers across ranks; each rank holds a consecutive slice. Forward and
backward execution is scheduled by DualPipeV.
"""

context_parallel_size: int = 1
"""
Degree of context parallelism.
Degree of context parallelism (CP).

Context Parallelism (CP) splits the sequence dimension across GPUs. Each GPU processes a chunk
of the full sequence. Ring attention is used to compute full causal attention across the
distributed sequence chunks, passing K/V around a ring of CP ranks.
Shard the sequence dimension across CP ranks. K/V exchange uses ring attention with a zigzag
token layout.
"""

expert_parallel_size: int = 1
"""
Degree of expert parallelism.

Expert Parallelism (EP) is a type of model parallelism that distributes experts of an MoE across GPUs.
Unlike other model-parallel techniques, EP is applied to only the expert layers thus does not impact
the parallel mapping of the rest of the layers.
Degree of expert parallelism (EP).

For example, if the model has 8 experts, then setting expert_parallel_size to 4 results in each GPU
processing 2 experts. The number of experts should be divisible by the expert parallel size.
Distribute the MoE experts across ranks; non-expert layers are unaffected. Token routing uses
EP dispatch and combine kernels with token deduplication.
"""

nccl_timeout_seconds: int = 180
timeout: timedelta = timedelta(minutes=15)
"""
Timeout for NCCL collective operations and watchdog heartbeat, in seconds.
Timeout for distributed operations.

Bounds how long a rank waits on a hung collective before NCCL's watchdog
aborts the job. Applies both to the per-collective timeout passed to
``torch.distributed.init_process_group`` and to ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC``.

Scale up for large multi-node or deep-pipeline runs where first-iteration
setup (NCCL channel build, FSDP all-gather, checkpoint load) can be slow.
Keep small for single-node runs to fail fast.
Applied to NCCL collectives and the watchdog heartbeat. Scale up for multi-node runs; keep
small to fail fast.
"""

sharding_strategy: Literal["fsdp", "hsdp"] = "fsdp"
"""
Sharding strategy for FSDP2.

- ``"fsdp"`` (default): Fully Sharded Data Parallel. Shards weights, gradients,
and optimizer states across the full data-parallel mesh (``dp x cp x ep``
for non-MoE parameters; ``dp x cp`` for MoE expert weights). Maximizes
memory savings.
FSDP2 sharding strategy.

- ``"hsdp"``: Hybrid Sharded Data Parallel. Shards within the inner mesh
dimension (``cp x ep`` for non-MoE; ``cp`` for MoE) and replicates across
the ``dp`` dimension. Lower memory savings than FSDP, but uses a smaller
FSDP group. Useful when the model fits comfortably in a single DP replica.
- "fsdp": shard parameters across the full FSDP mesh (dp x cp x ep for non-MoE; dp x cp
for MoE experts). Lowest memory.
- "hsdp": shard within the inner mesh (cp x ep for non-MoE; cp for MoE) and replicate
across dp. Pick when one DP replica fits the model.
"""


@dataclass(init=False, slots=True)
class DistributedCtx:
"""Context for distributed runtime."""
"""
Context for distributed runtime.

Hold the torchrun ranks alongside the (PP, DP, CP, EP) device mesh, providing a single source
of truth that the training loop, model constructors, and collectives reference.
"""

rank: int
"""Global rank of the current process."""
"""Global worker rank."""

world_size: int
"""Total number of workers in the distributed job."""
"""Total number of workers."""

local_rank: int
"""Local rank of the current process on the node."""
"""Worker rank within the node."""

local_world_size: int
"""Number of workers on the current node."""
"""Number of workers on the node."""

dp_rank: int
"""Rank of the current process in the data parallel group."""

dp_size: int
"""Size of the data parallel group."""
device_mesh: torch.distributed.DeviceMesh
"""4D mesh over (PP, DP, CP, EP) axes."""

pp_rank: int
"""Rank of the current process in the pipeline parallel group."""

pp_size: int
"""Size of the pipeline parallel group."""

dp_rank: int
dp_size: int
cp_rank: int
"""Rank of the current process in the context parallel group."""

cp_size: int
"""Size of the context parallel group."""

ep_rank: int
"""Rank of the current process in the expert parallel group."""

ep_size: int
"""Size of the expert parallel group."""

device_mesh: torch.distributed.DeviceMesh
"""Collection of process groups for multi-dimensional parallelism."""

def setup_torch_runtime() -> None:
"""Apply torch runtime tuning: enable TF32 matmul and raise the dynamo recompile cap."""
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch._dynamo.config.recompile_limit = 64


def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> None:
"""
Setup the default process group.
Initialize the default process group from torchrun environment variables.

This function initializes the default process group using environment variables by torchrun.
It also sets the current CUDA device based on the LOCAL_RANK environment variable. A cleanup
function is registered to destroy the process group at program exit.
Read global/local rank info into ctx, apply NCCL env tuning, register cleanup at exit, and set
the current CUDA device from the local rank.
"""
assert torch.cuda.is_available(), "CUDA is not available."
assert "TORCHELASTIC_RUN_ID" in os.environ, "Not launched with torchrun."
@@ -138,36 +122,62 @@ def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> Non
ctx.local_rank = int(os.environ["LOCAL_RANK"])
ctx.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

shutdown.set_heartbeat_timeout(cfg.nccl_timeout_seconds)
os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "0")
os.environ.setdefault("TORCH_NCCL_DUMP_ON_TIMEOUT", "1")
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(int(cfg.timeout.total_seconds()))

kwargs = dict()
kwargs["backend"] = "nccl"
kwargs["device_id"] = ctx.local_rank
kwargs["timeout"] = timedelta(seconds=cfg.nccl_timeout_seconds)
kwargs["timeout"] = cfg.timeout
torch.distributed.init_process_group(**kwargs)
# See pithtrain.modules.shutdown for why os._exit(1), not destroy/abort.
shutdown.install_failfast_excepthook()
atexit.register(torch.distributed.destroy_process_group)
torch.cuda.set_device(ctx.local_rank)


def setup_device_mesh(cfg: DistributedCfg, ctx: DistributedCtx) -> None:
def setup_failfast_excepthook() -> None:
"""
Setup the device mesh.
Install a fail-fast excepthook that bypasses the NCCL drain on uncaught exceptions.

Process groups are created in the following order. EP and CP are the inner-most
dimensions to keep their frequent communications within the NVLink domain.
    The default torch.distributed shutdown can hang indefinitely while draining in-flight NCCL work
    that peers will never satisfy. Hard-exiting bypasses the drain so NCCL work on other ranks
    fails fast instead of hanging.
"""
original = sys.excepthook

Mesh shape: ``(PP, DP, CP, EP)``
def excepthook(exc_type, exc_value, exc_tb, *_):
try:
original(exc_type, exc_value, exc_tb)
except Exception:
pass
try:
sys.stdout.flush()
sys.stderr.flush()
except Exception:
pass
os._exit(1)

1. Pipeline Parallel (PP) - outermost
2. Data Parallel (DP)
3. Context Parallel (CP) - ring attention KV exchange
4. Expert Parallel (EP) - innermost, MoE all-to-all
sys.excepthook = excepthook
threading.excepthook = lambda args: excepthook(*args)


def setup_device_mesh(cfg: DistributedCfg, ctx: DistributedCtx) -> None:
"""
Build the (PP, DP, CP, EP) device mesh and read per-axis ranks and sizes into ctx.

Mesh dimensions go outer-to-inner: PP, DP, CP, EP. CP and EP sit innermost so frequent
collectives (ring K/V exchange, MoE all-to-all) stay within the NVLink domain.
"""
ctx.ep_size = cfg.expert_parallel_size
ctx.pp_size = cfg.pipeline_parallel_size
ctx.cp_size = cfg.context_parallel_size
ctx.dp_size = ctx.world_size // (ctx.ep_size * ctx.pp_size * ctx.cp_size)
ctx.pp_size = pp_size = cfg.pipeline_parallel_size
ctx.cp_size = cp_size = cfg.context_parallel_size
ctx.ep_size = ep_size = cfg.expert_parallel_size
world_size = ctx.world_size

divisor = pp_size * cp_size * ep_size
if world_size % divisor != 0:
raise RuntimeError(f"{world_size=} not divisible by {pp_size=} * {cp_size=} * {ep_size=}")
ctx.dp_size = world_size // divisor

kwargs = dict()
kwargs["device_type"] = "cuda"
@@ -186,9 +196,8 @@ def distributed_context(cfg: object, ctx: object) -> Generator[DistributedCtx, N
"""Context manager for distributed runtime."""
assert hasattr(cfg, "distributed") and isinstance(cfg.distributed, DistributedCfg)
assert hasattr(ctx, "distributed") and isinstance(ctx.distributed, DistributedCtx)
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch._dynamo.config.recompile_limit = 64
setup_torch_runtime()
setup_default_process_group(cfg.distributed, ctx.distributed)
setup_failfast_excepthook()
setup_device_mesh(cfg.distributed, ctx.distributed)
yield ctx.distributed
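
The body of setup_device_mesh is collapsed in this view. As a rough illustration of the (PP, DP, CP, EP) ordering its docstring describes; the mesh_dim_names and rank queries below are assumptions for illustration, not necessarily the repo's exact calls:

import torch
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical degrees; inside a torchrun launch, world_size must equal
# pp_size * dp_size * cp_size * ep_size (2 * 4 * 2 * 2 = 32 here).
pp_size, dp_size, cp_size, ep_size = 2, 4, 2, 2

# 4D mesh, outer-to-inner: PP, DP, CP, EP, so CP/EP collectives stay innermost.
mesh = init_device_mesh(
    "cuda",
    (pp_size, dp_size, cp_size, ep_size),
    mesh_dim_names=("pp", "dp", "cp", "ep"),
)

# Per-axis coordinate of the current rank and the matching process group.
dp_rank = mesh.get_local_rank("dp")
dp_group = mesh.get_group("dp")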
57 changes: 0 additions & 57 deletions pithtrain/modules/shutdown.py

This file was deleted.

4 changes: 2 additions & 2 deletions pithtrain/tasks/pretrain_language_model.py
Expand Up @@ -21,11 +21,11 @@
set_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from pithtrain.config import SlottedDefault
from pithtrain.modules import shutdown
from pithtrain.modules.checkpoint import (
to_canonical_model,
to_canonical_optim,
@@ -481,7 +481,7 @@ def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) ->
gc.collect()


@shutdown.record
@record
def launch(cfg: PretrainLanguageModelCfg) -> None:
"""Launch the pretraining of a language model."""
with ExitStack() as stack:
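On the decorator swap above: torch elastic's record captures an uncaught exception from the decorated entrypoint and writes it to the error file torchrun sets up (via TORCHELASTIC_ERROR_FILE) before re-raising, which is presumably what the deleted shutdown.record provided. A minimal usage sketch; the entrypoint below is hypothetical:

from torch.distributed.elastic.multiprocessing.errors import record


@record
def main() -> None:
    # An uncaught exception here is recorded to the torchrun error file, then re-raised.
    raise RuntimeError("boom")


if __name__ == "__main__":
    main()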
4 changes: 2 additions & 2 deletions tests/test_fsdp.py
@@ -12,6 +12,7 @@
import torch.distributed.fsdp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from transformers import AutoConfig

@@ -21,7 +22,6 @@
from pithtrain.models.deepseek_v2_lite import DeepseekV2LiteModel, DeepseekV2LiteMoEGate
from pithtrain.models.gpt_oss import GptOssExperts, GptOssModel, GptOssTopKRouter
from pithtrain.models.qwen3_30b_a3b import Qwen3MoeGate, Qwen3MoeModel
from pithtrain.modules import shutdown
from pithtrain.modules.distributed import DistributedCfg, DistributedCtx, distributed_context


@@ -387,7 +387,7 @@ def main(ctx: DistributedCtx, model_name: str):
torch.distributed.barrier()


@shutdown.record
@record
def _entry() -> None:
models = []
models.append("examples/pretrain_language_model/deepseek-v2-lite/config.json")