|
| 1 | +""" |
| 2 | +Tensor Parallel Distributed Inference with Torch-TensorRT (torchrun) |
| 3 | +===================================================================== |
| 4 | +
|
| 5 | +Same model as tensor_parallel_simple_example.py but launched with |
| 6 | +torchrun / ``python -m torch_tensorrt.distributed.run`` instead of mpirun. |
| 7 | +
|
| 8 | +Usage |
| 9 | +----- |
| 10 | +.. code-block:: bash |
| 11 | +
|
| 12 | + # Single-node, 2 GPUs |
| 13 | + torchrun --nproc_per_node=2 tensor_parallel_simple_example_torchrun.py |
| 14 | +
|
| 15 | + # Two nodes, 1 GPU each — run on BOTH nodes simultaneously: |
| 16 | + # Node 0 (spirit): |
| 17 | + RANK=0 WORLD_SIZE=2 MASTER_ADDR=<spirit_ip> MASTER_PORT=29500 LOCAL_RANK=0 \\ |
| 18 | + uv run python tensor_parallel_simple_example_torchrun.py |
| 19 | +
|
| 20 | + # Node 1 (opportunity): |
| 21 | + RANK=1 WORLD_SIZE=2 MASTER_ADDR=<spirit_ip> MASTER_PORT=29500 LOCAL_RANK=0 \\ |
| 22 | + uv run python tensor_parallel_simple_example_torchrun.py |
| 23 | +
|
| 24 | + # Or via torchtrtrun (sets up NCCL library paths automatically): |
| 25 | + python -m torch_tensorrt.distributed.run --nproc_per_node=2 \\ |
| 26 | + tensor_parallel_simple_example_torchrun.py |
| 27 | +
|
| 28 | +Optional args: |
| 29 | + --mode jit_python | jit_cpp | export | load (default: jit_python) |
| 30 | + --save-path /tmp/tp_model.ep |
| 31 | + --precision FP16 | BF16 | FP32 (default: FP16) |
| 32 | + --debug |
| 33 | +""" |
| 34 | + |
import argparse
import datetime
import logging
import os
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree
from torch.distributed.device_mesh import init_device_mesh

# Register DTensorSpec as an opaque pytree constant so the tracing
# machinery does not try to flatten it (presumably required by the
# DTensor-aware export/compile path used below — confirm against the
# torch_tensorrt distributed docs).
torch.utils._pytree.register_constant(
    torch.distributed.tensor._dtensor_spec.DTensorSpec
)

# One GPU per node; LOCAL_RANK defaults to 0 for plain env-var launch.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
DEVICE = torch.device(f"cuda:{local_rank}")

# 2-hour timeout so TRT engine building doesn't trigger the NCCL watchdog.
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=2))
rank = dist.get_rank()
world_size = dist.get_world_size()

# torch_tensorrt is imported intentionally *after* CUDA device selection
# and process-group init. (A redundant top-of-file
# `from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt`
# was removed: it forced torch_tensorrt to load before this point,
# defeating the deliberate late import, and the same name is imported
# here before its only use.)
import torch_tensorrt
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt

setup_nccl_for_torch_tensorrt()

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Prefix every log line with the rank so interleaved multi-rank output
# stays readable.
logging.basicConfig(
    level=logging.INFO,
    format=f"[Rank {rank}] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)
logger.info(f"dist init OK rank={rank}/{world_size} device={DEVICE}")
| 80 | + |
| 81 | + |
class ToyModel(nn.Module):
    """Small MLP: 10 -> 3200 -> 1600 -> 500 -> 100 with ReLUs between layers."""

    def __init__(self):
        super().__init__()
        # Submodule names are load-bearing: the tensor-parallel plan in
        # get_model() shards these layers by their exact attribute names.
        # Creation order is also kept stable so seeded parameter init
        # reproduces the same weights.
        self.in_proj = nn.Linear(10, 3200)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(3200, 1600)
        self.in_proj2 = nn.Linear(1600, 500)
        self.out_proj2 = nn.Linear(500, 100)

    def forward(self, x):
        # First projection stack: Linear -> ReLU -> Linear.
        hidden = self.in_proj(x)
        hidden = self.relu(hidden)
        hidden = self.out_proj(hidden)
        # ReLU bridging the two stacks.
        hidden = self.relu(hidden)
        # Second projection stack: Linear -> ReLU -> Linear.
        hidden = self.in_proj2(hidden)
        hidden = self.relu(hidden)
        return self.out_proj2(hidden)
| 98 | + |
| 99 | + |
def get_model(device_mesh):
    """Build a ToyModel on the local device and shard it over ``device_mesh``.

    ``in_proj``/``in_proj2`` are made column-parallel and
    ``out_proj``/``out_proj2`` row-parallel, with batch-dim (``Shard(0)``)
    input/output layouts.

    Args:
        device_mesh: device mesh covering the participating ranks
            (a 1-D mesh of size ``world_size`` in this script).

    Returns:
        The ToyModel instance, moved to ``DEVICE`` and parallelized in place.

    Raises:
        ValueError: if the world size is odd. Raised explicitly rather than
            via ``assert``, since asserts are stripped under ``python -O``
            and a silently skipped check here would misconfigure sharding.
    """
    if world_size % 2 != 0:
        raise ValueError(
            f"TP examples require an even number of GPUs, got {world_size}"
        )
    model = ToyModel().to(DEVICE)
    parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj": RowwiseParallel(output_layouts=Shard(0)),
            "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
        },
    )
    logger.info("Model built and sharded across ranks.")
    return model
| 117 | + |
| 118 | + |
def compile_torchtrt(model, args):
    """Wrap *model* with ``torch.compile`` using the torch_tensorrt backend.

    Precision flags keep the original CLI contract: FP16 requests FP32
    accumulation, while explicit typing is enabled for every precision.
    ``--mode jit_python`` selects the Python TRT runtime; ``jit_cpp`` the
    C++ one.

    Args:
        model: module to compile (already sharded by get_model).
        args: parsed CLI namespace (``precision``, ``mode``, ``debug``).

    Returns:
        The torch.compile-wrapped model (engines build lazily on first call).
    """
    # Map the CLI precision to TRT-enabled dtypes; unknown values fall
    # back to FP32 exactly like the original else-branch.
    dtype_table = {
        "FP16": {torch.float16},
        "BF16": {torch.bfloat16},
    }
    enabled_precisions = dtype_table.get(args.precision, {torch.float32})
    # FP32 accumulation is requested only for FP16.
    use_fp32_acc = args.precision == "FP16"

    # Verbose TRT logging only when --debug was passed.
    debug_ctx = torch_tensorrt.logging.debug() if args.debug else nullcontext()
    with debug_ctx:
        compiled = torch.compile(
            model,
            backend="torch_tensorrt",
            dynamic=False,
            options={
                "enabled_precisions": enabled_precisions,
                "use_explicit_typing": True,
                "use_fp32_acc": use_fp32_acc,
                "device": DEVICE,
                "disable_tf32": True,
                "use_python_runtime": args.mode == "jit_python",
                "debug": args.debug,
                "min_block_size": 1,
                "use_distributed_mode_trace": True,
            },
        )
    return compiled
| 153 | + |
| 154 | + |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tensor Parallel Simple Example (torchrun)"
    )
    # jit_python / jit_cpp: torch.compile JIT path (Python vs C++ TRT runtime);
    # export: AOT torch.export + dynamo.compile then save to --save-path;
    # load: reload a program previously saved by --mode export.
    parser.add_argument(
        "--mode",
        type=str,
        choices=["jit_python", "jit_cpp", "export", "load"],
        default="jit_python",
    )
    parser.add_argument("--save-path", type=str, default="/tmp/tp_model.ep")
    parser.add_argument(
        "--precision",
        default="FP16",
        choices=["FP16", "BF16", "FP32"],
    )
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # 1-D mesh over all ranks: pure tensor parallelism, no data-parallel dim.
    device_mesh = init_device_mesh("cuda", (world_size,))

    with torch.inference_mode():
        model = get_model(device_mesh)

        # Same seed on every rank so all ranks feed identical input tensors.
        torch.manual_seed(0)
        inp = torch.rand(20, 10, device=DEVICE)
        # Eager (non-TRT) reference output for the accuracy checks below.
        # NOTE(review): these accuracy asserts are stripped under `python -O`.
        python_result = model(inp)

        if args.mode == "load":
            logger.info(f"Loading from {args.save_path}")
            loaded_program = torch_tensorrt.load(args.save_path)
            output = loaded_program.module()(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("Load successful!")

        elif args.mode in ("jit_python", "jit_cpp"):
            trt_model = compile_torchtrt(model, args)

            # Warmup: trigger engine build on all ranks, then barrier so no
            # rank races ahead to the next NCCL collective before others finish.
            logger.info("Warming up (triggering TRT engine build)...")
            _ = trt_model(inp)
            dist.barrier()
            logger.info("All ranks compiled. Running inference...")

            # Second call runs the already-built engines.
            output = trt_model(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("JIT compile successful!")

        elif args.mode == "export":
            # AOT path: export the sharded module, compile the exported
            # program, validate, then serialize for --mode load.
            exported_program = torch.export.export(model, (inp,), strict=False)
            trt_model = torch_tensorrt.dynamo.compile(
                exported_program,
                inputs=[inp],
                use_explicit_typing=True,
                use_fp32_acc=True,
                device=DEVICE,
                disable_tf32=True,
                use_python_runtime=False,
                min_block_size=1,
                use_distributed_mode_trace=True,
                assume_dynamic_shape_support=True,
            )
            output = trt_model(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            # NOTE(review): torch_tensorrt.save may return None — verify the
            # logged value; every rank writes to the same args.save_path.
            save_path = torch_tensorrt.save(trt_model, args.save_path, inputs=[inp])
            logger.info(f"Saved to {save_path}")
            # Keep ranks together until all have finished saving.
            dist.barrier()

    # Clean NCCL shutdown before process exit.
    dist.destroy_process_group()
    logger.info("Done!")
0 commit comments