From 718f0d0a00963cc343f836cd9eb4f0110af2c2c0 Mon Sep 17 00:00:00 2001 From: haok1402 Date: Sat, 16 May 2026 11:09:26 -0400 Subject: [PATCH 1/4] fold shutdown into distributed module --- pithtrain/modules/distributed.py | 32 ++++++++++-- pithtrain/modules/shutdown.py | 57 ---------------------- pithtrain/tasks/pretrain_language_model.py | 4 +- tests/test_fsdp.py | 4 +- 4 files changed, 32 insertions(+), 65 deletions(-) delete mode 100644 pithtrain/modules/shutdown.py diff --git a/pithtrain/modules/distributed.py b/pithtrain/modules/distributed.py index 17393a2..ccc2c61 100644 --- a/pithtrain/modules/distributed.py +++ b/pithtrain/modules/distributed.py @@ -2,6 +2,8 @@ import atexit import os +import sys +import threading from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta @@ -10,7 +12,6 @@ import torch from pithtrain.config import SlottedDefault -from pithtrain.modules import shutdown @dataclass(init=False, slots=True) @@ -138,14 +139,37 @@ def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> Non ctx.local_rank = int(os.environ["LOCAL_RANK"]) ctx.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - shutdown.set_heartbeat_timeout(cfg.nccl_timeout_seconds) + os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") + os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "0") + os.environ.setdefault("TORCH_NCCL_DUMP_ON_TIMEOUT", "1") + os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(cfg.nccl_timeout_seconds) + kwargs = dict() kwargs["backend"] = "nccl" kwargs["device_id"] = ctx.local_rank kwargs["timeout"] = timedelta(seconds=cfg.nccl_timeout_seconds) torch.distributed.init_process_group(**kwargs) - # See pithtrain.modules.shutdown for why os._exit(1), not destroy/abort. - shutdown.install_failfast_excepthook() + + # Fail-fast on uncaught exceptions: destroy/abort_process_group drain in-flight + # NCCL work that peers will never satisfy, so the rank hangs and torchrun never + # sees the death. os._exit(1) bypasses the drain; peers' NCCL ops fail fast. + original = sys.excepthook + + def excepthook(exc_type, exc_value, exc_tb, *_): + try: + original(exc_type, exc_value, exc_tb) + except Exception: + pass + try: + sys.stdout.flush() + sys.stderr.flush() + except Exception: + pass + os._exit(1) + + sys.excepthook = excepthook + threading.excepthook = lambda args: excepthook(*args) + atexit.register(torch.distributed.destroy_process_group) torch.cuda.set_device(ctx.local_rank) diff --git a/pithtrain/modules/shutdown.py b/pithtrain/modules/shutdown.py deleted file mode 100644 index f8919f2..0000000 --- a/pithtrain/modules/shutdown.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Fast-fail shutdown when any rank raises. - -Default torch.distributed shutdown can hang indefinitely: atexit -destroy_process_group drains in-flight NCCL work that peers will never -satisfy, so the failing rank hangs in atexit, torchrun never sees the -death, and peers spin until something external kills the job. - -_abort_process_group has the same drain-deadlock in NCCL 2.28 (despite -docs claiming non-blocking). The only escape is os._exit(1) from a -sys.excepthook -- the kernel reaps FDs/sockets/CUDA IPC, peers' NCCL -ops fail in milliseconds, and torchrun SIGTERMs the rest. -""" - -import os -import sys -import threading - -import torch # noqa: F401 -from torch.distributed.elastic.multiprocessing.errors import record # noqa: F401 - - -def set_env_defaults() -> None: - """NCCL env defaults that bound failure detection. Must run before init.""" - os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "0") - os.environ.setdefault("TORCH_NCCL_DUMP_ON_TIMEOUT", "1") - - -def set_heartbeat_timeout(seconds: int) -> None: - """Set the NCCL watchdog heartbeat timeout. Must run before init_process_group.""" - os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(seconds) - - -def install_failfast_excepthook() -> None: - """On uncaught exception: print, flush, os._exit(1) -- no abort/destroy.""" - prev = sys.excepthook - - def hook(exc_type, exc_value, exc_tb): - try: - prev(exc_type, exc_value, exc_tb) - except Exception: - pass - try: - sys.stdout.flush() - sys.stderr.flush() - except Exception: - pass - # No abort/destroy: both block on the drain we're escaping. - os._exit(1) - - sys.excepthook = hook - # Background threads (e.g. logging workers) would otherwise just print and - # let the main thread continue into the next collective and hang. - threading.excepthook = lambda args: hook(args.exc_type, args.exc_value, args.exc_traceback) - - -set_env_defaults() diff --git a/pithtrain/tasks/pretrain_language_model.py b/pithtrain/tasks/pretrain_language_model.py index c6bedd9..36ee2ca 100644 --- a/pithtrain/tasks/pretrain_language_model.py +++ b/pithtrain/tasks/pretrain_language_model.py @@ -21,11 +21,11 @@ set_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.elastic.multiprocessing.errors import record from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from pithtrain.config import SlottedDefault -from pithtrain.modules import shutdown from pithtrain.modules.checkpoint import ( to_canonical_model, to_canonical_optim, @@ -481,7 +481,7 @@ def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) -> gc.collect() -@shutdown.record +@record def launch(cfg: PretrainLanguageModelCfg) -> None: """Launch the pretraining of a language model.""" with ExitStack() as stack: diff --git a/tests/test_fsdp.py b/tests/test_fsdp.py index 6e6bf34..d5cdafe 100644 --- a/tests/test_fsdp.py +++ b/tests/test_fsdp.py @@ -12,6 +12,7 @@ import torch.distributed.fsdp import torch.nn as nn import torch.nn.functional as F +from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from transformers import AutoConfig @@ -21,7 +22,6 @@ from pithtrain.models.deepseek_v2_lite import DeepseekV2LiteModel, DeepseekV2LiteMoEGate from pithtrain.models.gpt_oss import GptOssExperts, GptOssModel, GptOssTopKRouter from pithtrain.models.qwen3_30b_a3b import Qwen3MoeGate, Qwen3MoeModel -from pithtrain.modules import shutdown from pithtrain.modules.distributed import DistributedCfg, DistributedCtx, distributed_context @@ -387,7 +387,7 @@ def main(ctx: DistributedCtx, model_name: str): torch.distributed.barrier() -@shutdown.record +@record def _entry() -> None: models = [] models.append("examples/pretrain_language_model/deepseek-v2-lite/config.json") From 235d557c0352ed47433d42a75b1c1c8886688d03 Mon Sep 17 00:00:00 2001 From: haok1402 Date: Sat, 16 May 2026 11:36:00 -0400 Subject: [PATCH 2/4] switch timeout to timedelta --- pithtrain/config.py | 5 ++++- pithtrain/modules/distributed.py | 13 ++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pithtrain/config.py b/pithtrain/config.py index 70ec6e0..e363f62 100644 --- a/pithtrain/config.py +++ b/pithtrain/config.py @@ -1,6 +1,7 @@ """PithTrain base classes.""" from dataclasses import MISSING, asdict, fields +from datetime import timedelta from pathlib import Path @@ -28,9 +29,11 @@ def to_json_dict(self) -> dict: @staticmethod def _make_json_serializable(obj): - """Recursively convert non-serializable types (e.g. Path) to strings.""" + """Recursively convert non-serializable types into JSON primitives.""" if isinstance(obj, dict): return {k: SlottedDefault._make_json_serializable(v) for k, v in obj.items()} elif isinstance(obj, Path): return str(obj) + elif isinstance(obj, timedelta): + return obj.total_seconds() return obj diff --git a/pithtrain/modules/distributed.py b/pithtrain/modules/distributed.py index ccc2c61..3e3c551 100644 --- a/pithtrain/modules/distributed.py +++ b/pithtrain/modules/distributed.py @@ -50,13 +50,12 @@ class DistributedCfg(SlottedDefault): processing 2 experts. The number of experts should be divisible by the expert parallel size. """ - nccl_timeout_seconds: int = 180 + timeout: timedelta = timedelta(seconds=180) """ - Timeout for NCCL collective operations and watchdog heartbeat, in seconds. + Timeout for distributed operations. - Bounds how long a rank waits on a hung collective before NCCL's watchdog - aborts the job. Applies both to the per-collective timeout passed to - ``torch.distributed.init_process_group`` and to ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC``. + Currently applied to NCCL collective operations and the watchdog heartbeat + (``torch.distributed.init_process_group`` timeout + ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC``). Scale up for large multi-node or deep-pipeline runs where first-iteration setup (NCCL channel build, FSDP all-gather, checkpoint load) can be slow. @@ -142,12 +141,12 @@ def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> Non os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "0") os.environ.setdefault("TORCH_NCCL_DUMP_ON_TIMEOUT", "1") - os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(cfg.nccl_timeout_seconds) + os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(int(cfg.timeout.total_seconds())) kwargs = dict() kwargs["backend"] = "nccl" kwargs["device_id"] = ctx.local_rank - kwargs["timeout"] = timedelta(seconds=cfg.nccl_timeout_seconds) + kwargs["timeout"] = cfg.timeout torch.distributed.init_process_group(**kwargs) # Fail-fast on uncaught exceptions: destroy/abort_process_group drain in-flight From d066629c644013d8ec65275f7bc38f0ef589559a Mon Sep 17 00:00:00 2001 From: haok1402 Date: Sat, 16 May 2026 11:45:07 -0400 Subject: [PATCH 3/4] update the docstring for distributed module. --- pithtrain/modules/distributed.py | 146 +++++++++++++++---------------- 1 file changed, 71 insertions(+), 75 deletions(-) diff --git a/pithtrain/modules/distributed.py b/pithtrain/modules/distributed.py index 3e3c551..90f79bf 100644 --- a/pithtrain/modules/distributed.py +++ b/pithtrain/modules/distributed.py @@ -16,65 +16,53 @@ @dataclass(init=False, slots=True) class DistributedCfg(SlottedDefault): - """Configuration for distributed runtime.""" + """ + Configuration for distributed runtime. - pipeline_parallel_size: int = 1 + Parallelism degrees (PP, CP, EP), FSDP2 sharding strategy, and operation timeout. DP is + inferred from the world size. """ - Degree of pipeline parallelism. - Pipeline Parallelism (PP) is a technique that assigns consecutive layers or segments of a neural network - to different GPUs. This division allows each GPU to process different stages of the network sequentially. + pipeline_parallel_size: int = 1 + """ + Degree of pipeline parallelism (PP). - For example, if a model has 12 layers and the pipeline_parallel_size is set to 4, then each GPU will - handle 3 layers. + Partition the model layers across ranks; each rank holds a consecutive slice. Forward and + backward execution is scheduled by DualPipeV. """ context_parallel_size: int = 1 """ - Degree of context parallelism. + Degree of context parallelism (CP). - Context Parallelism (CP) splits the sequence dimension across GPUs. Each GPU processes a chunk - of the full sequence. Ring attention is used to compute full causal attention across the - distributed sequence chunks, passing K/V around a ring of CP ranks. + Shard the sequence dimension across CP ranks. K/V exchange uses ring attention with a zigzag + token layout. """ expert_parallel_size: int = 1 """ - Degree of expert parallelism. + Degree of expert parallelism (EP). - Expert Parallelism (EP) is a type of model parallelism that distributes experts of an MoE across GPUs. - Unlike other model-parallel techniques, EP is applied to only the expert layers thus does not impact - the parallel mapping of the rest of the layers. - - For example, if the model has 8 experts, then setting expert_parallel_size to 4 results in each GPU - processing 2 experts. The number of experts should be divisible by the expert parallel size. + Distribute the MoE experts across ranks; non-expert layers are unaffected. Token routing uses + EP dispatch and combine kernels with token deduplication. """ - timeout: timedelta = timedelta(seconds=180) + timeout: timedelta = timedelta(minutes=15) """ Timeout for distributed operations. - Currently applied to NCCL collective operations and the watchdog heartbeat - (``torch.distributed.init_process_group`` timeout + ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC``). - - Scale up for large multi-node or deep-pipeline runs where first-iteration - setup (NCCL channel build, FSDP all-gather, checkpoint load) can be slow. - Keep small for single-node runs to fail fast. + Applied to NCCL collectives and the watchdog heartbeat. Scale up for multi-node runs; keep + small to fail fast. """ sharding_strategy: Literal["fsdp", "hsdp"] = "fsdp" """ - Sharding strategy for FSDP2. - - - ``"fsdp"`` (default): Fully Sharded Data Parallel. Shards weights, gradients, - and optimizer states across the full data-parallel mesh (``dp x cp x ep`` - for non-MoE parameters; ``dp x cp`` for MoE expert weights). Maximizes - memory savings. + FSDP2 sharding strategy. - - ``"hsdp"``: Hybrid Sharded Data Parallel. Shards within the inner mesh - dimension (``cp x ep`` for non-MoE; ``cp`` for MoE) and replicates across - the ``dp`` dimension. Lower memory savings than FSDP, but uses a smaller - FSDP group. Useful when the model fits comfortably in a single DP replica. + - "fsdp": shard parameters across the full FSDP mesh (dp x cp x ep for non-MoE; dp x cp + for MoE experts). Lowest memory. + - "hsdp": shard within the inner mesh (cp x ep for non-MoE; cp for MoE) and replicate + across dp. Pick when one DP replica fits the model. """ @@ -83,52 +71,58 @@ class DistributedCtx: """Context for distributed runtime.""" rank: int - """Global rank of the current process.""" + """Global rank of this process.""" world_size: int - """Total number of workers in the distributed job.""" + """Total number of processes.""" local_rank: int - """Local rank of the current process on the node.""" + """Local rank on the node.""" local_world_size: int - """Number of workers on the current node.""" + """Number of processes on the node.""" dp_rank: int - """Rank of the current process in the data parallel group.""" + """Rank in the DP group.""" dp_size: int - """Size of the data parallel group.""" + """Size of the DP group.""" pp_rank: int - """Rank of the current process in the pipeline parallel group.""" + """Rank in the PP group.""" pp_size: int - """Size of the pipeline parallel group.""" + """Size of the PP group.""" cp_rank: int - """Rank of the current process in the context parallel group.""" + """Rank in the CP group.""" cp_size: int - """Size of the context parallel group.""" + """Size of the CP group.""" ep_rank: int - """Rank of the current process in the expert parallel group.""" + """Rank in the EP group.""" ep_size: int - """Size of the expert parallel group.""" + """Size of the EP group.""" device_mesh: torch.distributed.DeviceMesh - """Collection of process groups for multi-dimensional parallelism.""" + """Device mesh over PP/DP/CP/EP.""" + + +def setup_torch_runtime() -> None: + """Apply torch runtime tuning: enable TF32 matmul and raise the dynamo recompile cap.""" + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("high") + torch._dynamo.config.recompile_limit = 64 def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> None: """ - Setup the default process group. + Initialize the default process group from torchrun environment variables. - This function initializes the default process group using environment variables by torchrun. - It also sets the current CUDA device based on the LOCAL_RANK environment variable. A cleanup - function is registered to destroy the process group at program exit. + Read global/local rank info into ctx, apply NCCL env tuning, register cleanup at exit, and set + the current CUDA device from the local rank. """ assert torch.cuda.is_available(), "CUDA is not available." assert "TORCHELASTIC_RUN_ID" in os.environ, "Not launched with torchrun." @@ -148,10 +142,18 @@ def setup_default_process_group(cfg: DistributedCfg, ctx: DistributedCtx) -> Non kwargs["device_id"] = ctx.local_rank kwargs["timeout"] = cfg.timeout torch.distributed.init_process_group(**kwargs) + atexit.register(torch.distributed.destroy_process_group) + torch.cuda.set_device(ctx.local_rank) - # Fail-fast on uncaught exceptions: destroy/abort_process_group drain in-flight - # NCCL work that peers will never satisfy, so the rank hangs and torchrun never - # sees the death. os._exit(1) bypasses the drain; peers' NCCL ops fail fast. + +def setup_failfast_excepthook() -> None: + """ + Install a fail-fast excepthook that bypasses the NCCL drain on uncaught exceptions. + + Default torch.distributed shutdown can hang indefinitely while draining in-flight NCCL work + that peers will never satisfy. Hard-exiting bypasses the drain so NCCL wor on other ranks + fail fast instead of hanging. + """ original = sys.excepthook def excepthook(exc_type, exc_value, exc_tb, *_): @@ -169,28 +171,23 @@ def excepthook(exc_type, exc_value, exc_tb, *_): sys.excepthook = excepthook threading.excepthook = lambda args: excepthook(*args) - atexit.register(torch.distributed.destroy_process_group) - torch.cuda.set_device(ctx.local_rank) - def setup_device_mesh(cfg: DistributedCfg, ctx: DistributedCtx) -> None: """ - Setup the device mesh. + Build the (PP, DP, CP, EP) device mesh and read per-axis ranks and sizes into ctx. - Process groups are created in the following order. EP and CP are the inner-most - dimensions to keep their frequent communications within the NVLink domain. - - Mesh shape: ``(PP, DP, CP, EP)`` - - 1. Pipeline Parallel (PP) - outermost - 2. Data Parallel (DP) - 3. Context Parallel (CP) - ring attention KV exchange - 4. Expert Parallel (EP) - innermost, MoE all-to-all + Mesh dimensions go outer-to-inner: PP, DP, CP, EP. CP and EP sit innermost so frequent + collectives (ring K/V exchange, MoE all-to-all) stay within the NVLink domain. """ - ctx.ep_size = cfg.expert_parallel_size - ctx.pp_size = cfg.pipeline_parallel_size - ctx.cp_size = cfg.context_parallel_size - ctx.dp_size = ctx.world_size // (ctx.ep_size * ctx.pp_size * ctx.cp_size) + ctx.pp_size = pp_size = cfg.pipeline_parallel_size + ctx.cp_size = cp_size = cfg.context_parallel_size + ctx.ep_size = ep_size = cfg.expert_parallel_size + world_size = ctx.world_size + + divisor = pp_size * cp_size * ep_size + if world_size % divisor != 0: + raise RuntimeError(f"{world_size=} not divisible by {pp_size=} * {cp_size=} * {ep_size=}") + ctx.dp_size = world_size // divisor kwargs = dict() kwargs["device_type"] = "cuda" @@ -209,9 +206,8 @@ def distributed_context(cfg: object, ctx: object) -> Generator[DistributedCtx, N """Context manager for distributed runtime.""" assert hasattr(cfg, "distributed") and isinstance(cfg.distributed, DistributedCfg) assert hasattr(ctx, "distributed") and isinstance(ctx.distributed, DistributedCtx) - torch.backends.cuda.matmul.allow_tf32 = True - torch.set_float32_matmul_precision("high") - torch._dynamo.config.recompile_limit = 64 + setup_torch_runtime() setup_default_process_group(cfg.distributed, ctx.distributed) + setup_failfast_excepthook() setup_device_mesh(cfg.distributed, ctx.distributed) yield ctx.distributed From e724bd4d882dbb24900aaec9e76e1c4c96b5d0c5 Mon Sep 17 00:00:00 2001 From: haok1402 Date: Sun, 17 May 2026 11:12:27 -0400 Subject: [PATCH 4/4] drop the trivial docstring with {pp,dp,cp,ep}_{rank,size} --- pithtrain/modules/distributed.py | 38 ++++++++++++-------------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/pithtrain/modules/distributed.py b/pithtrain/modules/distributed.py index 90f79bf..fcf6cf0 100644 --- a/pithtrain/modules/distributed.py +++ b/pithtrain/modules/distributed.py @@ -68,46 +68,36 @@ class DistributedCfg(SlottedDefault): @dataclass(init=False, slots=True) class DistributedCtx: - """Context for distributed runtime.""" + """ + Context for distributed runtime. + + Hold the torchrun ranks alongside the (PP, DP, CP, EP) device mesh, providing a single source + of truth that the training loop, model constructors, and collectives reference. + """ rank: int - """Global rank of this process.""" + """Global worker rank.""" world_size: int - """Total number of processes.""" + """Total number of workers.""" local_rank: int - """Local rank on the node.""" + """Worker rank within the node.""" local_world_size: int - """Number of processes on the node.""" - - dp_rank: int - """Rank in the DP group.""" + """Number of workers on the node.""" - dp_size: int - """Size of the DP group.""" + device_mesh: torch.distributed.DeviceMesh + """4D mesh over (PP, DP, CP, EP) axes.""" pp_rank: int - """Rank in the PP group.""" - pp_size: int - """Size of the PP group.""" - + dp_rank: int + dp_size: int cp_rank: int - """Rank in the CP group.""" - cp_size: int - """Size of the CP group.""" - ep_rank: int - """Rank in the EP group.""" - ep_size: int - """Size of the EP group.""" - - device_mesh: torch.distributed.DeviceMesh - """Device mesh over PP/DP/CP/EP.""" def setup_torch_runtime() -> None: