Minimal parallel-GPT training playground in pure PyTorch.
heiretsu implements composable parallelism for transformer training:
- DP (data parallel): manual gradient averaging.
- TP (tensor parallel): column/row-sharded linears for attention + MLP.
- PP (pipeline parallel): GPipe-style stage split + microbatching.
- EP (expert parallel): Mixture-of-Experts expert sharding/routing.
- Mixed precision (`--amp fp16|bf16`) and optional `torch.compile`.
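The column/row split used by TP can be checked on a toy two-layer MLP. The sketch below is pure Python (no PyTorch) so the arithmetic is exact; it is not `tp_linear.py`, and the helper names (`split_cols`, `split_rows`, `tp_mlp`) are hypothetical. The first weight is split by columns, so each shard computes its own slice of the hidden layer with no communication. The second weight is split by rows, so each shard produces a partial output, and summing the partials plays the role of the all-reduce. ReLU is used here to keep integer arithmetic exact; GELU is also elementwise, so the same identity holds.

```python
def matmul(a, b):
    """Dense matrix product of nested lists: (n x k) @ (k x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def split_cols(w, n):
    """Column shards of w (one per TP rank), as in a column-parallel linear."""
    if len(w[0]) % n:
        raise ValueError("output dim must be divisible by tp")
    k = len(w[0]) // n
    return [[row[i * k:(i + 1) * k] for row in w] for i in range(n)]


def split_rows(w, n):
    """Row shards of w (one per TP rank), as in a row-parallel linear."""
    if len(w) % n:
        raise ValueError("input dim must be divisible by tp")
    k = len(w) // n
    return [w[i * k:(i + 1) * k] for i in range(n)]


def tp_mlp(x, w1, w2, tp):
    """Simulate a TP=tp MLP on one process: relu(x @ w1) @ w2."""
    partials = []
    for w1_shard, w2_shard in zip(split_cols(w1, tp), split_rows(w2, tp)):
        hidden = relu(matmul(x, w1_shard))         # local hidden slice, no comms
        partials.append(matmul(hidden, w2_shard))  # partial sum of the output
    # Summing the partials stands in for all_reduce(SUM) across the TP group.
    out = partials[0]
    for p in partials[1:]:
        out = [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(out, p)]
    return out
```

In a real multi-process run, the final loop would roughly correspond to `torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=tp_group)` on each TP rank.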
The topology is 4D, and the world size must equal the product of the four parallelism degrees:
```
world_size = dp * ep * tp * pp
```
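One way to realize this 4D layout is to treat each global rank as a coordinate in a `dp × ep × tp × pp` grid and build one process group per line along each axis. The sketch below is a hypothetical illustration, not `topo.py`: it assumes a row-major layout over `(dp, pp, ep, tp)` with `tp` varying fastest, so that tensor-parallel peers (which communicate most) get adjacent ranks. The actual ordering in the repo may differ.

```python
DIMS = ("dp", "ep", "tp", "pp")


def rank_to_coords(rank, dp, ep, tp, pp):
    """Map a global rank to its (dp, ep, tp, pp) coordinates.

    Assumed layout: row-major over (dp, pp, ep, tp), i.e. tp varies fastest.
    """
    world_size = dp * ep * tp * pp
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return {
        "tp": rank % tp,
        "ep": (rank // tp) % ep,
        "pp": (rank // (tp * ep)) % pp,
        "dp": rank // (tp * ep * pp),
    }


def groups_along(dim, dp, ep, tp, pp):
    """Rank lists that differ only in coordinate `dim` (e.g. every TP group)."""
    groups = {}
    for rank in range(dp * ep * tp * pp):
        coords = rank_to_coords(rank, dp, ep, tp, pp)
        key = tuple(coords[d] for d in DIMS if d != dim)
        groups.setdefault(key, []).append(rank)
    return list(groups.values())


if __name__ == "__main__":
    # The 8-process DP=2, EP=2, TP=2 layout from the examples.
    for dim in DIMS:
        print(dim, groups_along(dim, dp=2, ep=2, tp=2, pp=1))
```

Each rank list could then be passed to `torch.distributed.new_group(ranks)`. Note that `new_group` must be called by every process in the same order, even for groups a given rank does not belong to.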
Repository layout:
- `train.py`: main trainer and CLI.
- `gpt_model.py`: GPT blocks + optional MoE blocks.
- `topo.py`: 4D process-group topology helpers.
- `tp_linear.py`: tensor-parallel linears.
- `pipeline.py`: stage wrapper + GPipe engine.
- `moe.py`, `ep_comm.py`: MoE routing + expert comms.
- `tests/`: forward/grad parity and smoke tests.
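`pipeline.py` is described as a GPipe engine, and `--grad_accum_steps` sets the microbatch count. The sketch below uses hypothetical helper names and is not the repo's code: it lays out the forward fill/drain timeline for `pp` stages and `m` microbatches and computes the idle "bubble" fraction, `(pp - 1) / (m + pp - 1)`. That formula assumes every stage takes the same time per microbatch; GPipe's backward pass mirrors the forward, so the same fraction applies to it.

```python
def gpipe_forward_schedule(pp, num_micro):
    """For each clock tick, the microbatch index each stage runs (None = idle).

    Stage s can start microbatch j only after stage s-1 has finished it, so
    with uniform per-stage cost, stage s runs microbatch (tick - s) at each tick.
    """
    ticks = num_micro + pp - 1
    return [
        [t - s if 0 <= t - s < num_micro else None for s in range(pp)]
        for t in range(ticks)
    ]


def bubble_fraction(pp, num_micro):
    """Fraction of stage-ticks spent idle during fill/drain of one pass."""
    return (pp - 1) / (num_micro + pp - 1)


if __name__ == "__main__":
    for tick, row in enumerate(gpipe_forward_schedule(pp=4, num_micro=4)):
        print(tick, ["." if m is None else m for m in row])
    print(f"bubble with m=2: {bubble_fraction(4, 2):.2f}")
    print(f"bubble with m=8: {bubble_fraction(4, 8):.2f}")
```

With 4 stages, 2 microbatches leave 60% of stage time idle, while 8 microbatches cut that to about 27%. This is why the microbatch count matters most when `--pp` is greater than 1.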
Setup:
```bash
cd heiretsu
uv venv
source .venv/bin/activate
uv pip install -r requirements.txt
```

Download the FineWeb GPT-2 token bins:
```bash
python data/data_bins.py 1
```

This downloads the validation data and 1 training shard into `data/fineweb10B`. Increase the argument to download more training shards.
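The on-disk format of the token bins is not described here. Assuming they follow the llm.c / modded-nanogpt convention commonly used for FineWeb GPT-2 shards (a 256-slot int32 header holding magic number 20240520, version 1, and the token count, followed by little-endian uint16 token IDs), a stdlib-only reader and random batch sampler might look like the sketch below. `encode_shard`, `parse_shard`, `load_shard`, and `get_batch` are hypothetical names; check `data/data_bins.py` and `train.py` for the actual layout and sampling strategy before relying on this.

```python
import array
import random
import sys

# Assumed llm.c-style header: 256 int32 slots, [0]=magic, [1]=version, [2]=num tokens.
HEADER_INTS = 256
MAGIC, VERSION = 20240520, 1
if array.array("i").itemsize != 4:
    raise RuntimeError("this sketch needs a 4-byte C int")


def _fix_endianness(a):
    """Shard bytes are assumed little-endian; swap in place on big-endian hosts."""
    if sys.byteorder == "big":
        a.byteswap()
    return a


def encode_shard(tokens):
    """Serialize token IDs (< 65536) in the assumed shard format."""
    header = array.array("i", [0] * HEADER_INTS)
    header[0], header[1], header[2] = MAGIC, VERSION, len(tokens)
    body = array.array("H", tokens)
    return _fix_endianness(header).tobytes() + _fix_endianness(body).tobytes()


def parse_shard(data):
    """Validate the header and return the uint16 token array."""
    hdr_bytes = HEADER_INTS * 4
    header = _fix_endianness(array.array("i", data[:hdr_bytes]))
    if header[0] != MAGIC or header[1] != VERSION:
        raise ValueError("unexpected shard header")
    n = header[2]
    tokens = _fix_endianness(array.array("H", data[hdr_bytes:hdr_bytes + 2 * n]))
    if len(tokens) != n:
        raise ValueError("truncated shard")
    return tokens


def load_shard(path):
    with open(path, "rb") as f:
        return parse_shard(f.read())


def get_batch(tokens, batch_size, block_size, rng=random):
    """Sample random windows; targets are the inputs shifted by one token."""
    max_start = len(tokens) - block_size - 1
    if max_start < 0:
        raise ValueError("shard shorter than block_size + 1")
    xs, ys = [], []
    for _ in range(batch_size):
        i = rng.randint(0, max_start)
        xs.append(list(tokens[i:i + block_size]))
        ys.append(list(tokens[i + 1:i + block_size + 1]))
    return xs, ys
```

A trainer would typically convert `xs` and `ys` to `torch.long` tensors of shape `(batch_size, block_size)` before the forward pass.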
Single-process run:
```bash
python train.py \
  --device auto \
  --data_dir data/fineweb10B \
  --max_iters 50 \
  --batch_size 8 \
  --block_size 256
```

Data parallel (DP=4) on 4 GPUs:
```bash
torchrun --standalone --nproc_per_node=4 train.py \
  --data_dir data/fineweb10B \
  --dp 4 --ep 1 --tp 1 --pp 1 \
  --batch_size 8 --grad_accum_steps 2 --amp bf16
```

Data + tensor parallel (DP=2, TP=2) on 4 GPUs:
```bash
torchrun --standalone --nproc_per_node=4 train.py \
  --data_dir data/fineweb10B \
  --dp 2 --ep 1 --tp 2 --pp 1 \
  --batch_size 8 --grad_accum_steps 2 --amp bf16
```

Data + expert + tensor parallel with MoE (DP=2, EP=2, TP=2) on 8 GPUs:
```bash
torchrun --standalone --nproc_per_node=8 train.py \
  --data_dir data/fineweb10B \
  --dp 2 --ep 2 --tp 2 --pp 1 \
  --num_experts 8 --top_k 2 --moe_freq 2 --aux_loss_coef 0.01 \
  --batch_size 8 --grad_accum_steps 4 --amp bf16
```

Or use the quick-run script:
```bash
bash run_train_quick.sh
```

Key flags:
- `--dp`, `--ep`, `--tp`, `--pp`: parallelism degrees.
- `--grad_accum_steps`: microbatch count (especially important for PP).
- `--num_experts`, `--top_k`, `--moe_freq`: enable/configure MoE.
- `--wandb`: optional experiment logging.
- `--compile`: use `torch.compile`.
- `--dist_backend gloo`: useful for CPU-only debugging.
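The MoE flags suggest a router that scores each token against `--num_experts` experts and keeps the `--top_k` best. The sketch below is a pure-Python illustration, not `moe.py`, and rests on three assumptions: softmax routing with the kept gates renormalized to sum to 1; a Switch-Transformer-style load-balancing loss `num_experts * sum_i f_i * P_i`, where `f_i` is the share of routed assignments landing on expert `i` and `P_i` is the mean router probability for expert `i`; and experts sharded in contiguous blocks across EP ranks. `route_top_k` and `expert_owner` are hypothetical names.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def route_top_k(logits, top_k):
    """logits: one list of num_experts scores per token.

    Returns per-token [(expert, gate), ...] and the load-balancing aux loss.
    """
    num_tokens, num_experts = len(logits), len(logits[0])
    probs = [softmax(row) for row in logits]
    assignments = []
    for p in probs:
        top = sorted(range(num_experts), key=lambda e: p[e], reverse=True)[:top_k]
        norm = sum(p[e] for e in top)  # renormalize kept gates (an assumption)
        assignments.append([(e, p[e] / norm) for e in top])
    counts = [0] * num_experts
    for chosen in assignments:
        for e, _ in chosen:
            counts[e] += 1
    f = [c / (num_tokens * top_k) for c in counts]  # fraction of assignments
    P = [sum(p[e] for p in probs) / num_tokens for e in range(num_experts)]
    aux = num_experts * sum(fi * pi for fi, pi in zip(f, P))
    return assignments, aux


def expert_owner(expert, num_experts, ep):
    """EP rank hosting `expert`, assuming contiguous blocks of num_experts // ep."""
    if num_experts % ep:
        raise ValueError("num_experts must be divisible by ep")
    return expert // (num_experts // ep)
```

The auxiliary loss is about 1.0 when routing is balanced and grows toward `num_experts` as traffic collapses onto one expert; `--aux_loss_coef` presumably scales this term before it is added to the language-modeling loss.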
Run the full parallel test suite:
```bash
bash tests/run_full_suite.sh
```

Run a single TP parity check:
```bash
torchrun --standalone --nproc_per_node=2 tests/tests_equiv.py --tp 2
```

Implemented:
- Composable 4D process topology (DP/EP/TP/PP).
- GPT training loop with gradient accumulation + AMP.
- MoE expert routing, load-balancing aux loss, and EP comm path.
- Manual distributed control path for learning/debugging.
- Parity and smoke tests for major parallel configurations.

Roadmap:
- Activation checkpointing for deeper models.
- ZeRO-style optimizer/state sharding.
- Better checkpoint format for resuming across topology changes.
- Schedules for `PP > 1` beyond simple GPipe fill/drain (e.g., 1F1B).
- CUDA graph capture and fused kernels for throughput.
- Config files (YAML/TOML) + launch presets.
- Multi-node launcher support and networking docs.
- Richer monitoring dashboards (per-rank throughput, comm overlap).

Notes:
- For best performance, use CUDA + NCCL.
- Keep `nproc_per_node == dp*ep*tp*pp`.
- `wandb` is optional; training runs without it.
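The notes above can be turned into a simple start-up check. This sketch assumes the trainer is launched with `torchrun`, which exports `WORLD_SIZE` to every worker; a plain `python train.py` run has no `WORLD_SIZE`, which the sketch treats as 1. `check_launch` is a hypothetical helper, and its NCCL-on-CUDA, gloo-otherwise choice mirrors the notes and the `--dist_backend gloo` flag rather than the trainer's actual default.

```python
import os


def check_launch(dp, ep, tp, pp, env=None, cuda_available=False):
    """Check that the launched world matches dp*ep*tp*pp and pick a backend.

    Hypothetical helper. `env` defaults to os.environ; torchrun sets WORLD_SIZE
    for every worker, and a bare `python train.py` run is treated as size 1.
    """
    env = os.environ if env is None else env
    expected = dp * ep * tp * pp
    world_size = int(env.get("WORLD_SIZE", "1"))
    if world_size != expected:
        raise ValueError(
            f"WORLD_SIZE={world_size} but dp*ep*tp*pp={expected}; "
            f"launch with --nproc_per_node={expected} on a single node"
        )
    # NCCL for CUDA runs; gloo for CPU-only debugging.
    backend = "nccl" if cuda_available else "gloo"
    return world_size, backend
```

Inside the trainer this might be called as `check_launch(args.dp, args.ep, args.tp, args.pp, cuda_available=torch.cuda.is_available())` before `torch.distributed.init_process_group`, so that a mismatched `--nproc_per_node` fails fast with a readable message.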