
Commit bf432ad

fix: thread the MD-TRT requirement through the conversion system

1 parent 754b62b

16 files changed: 543 additions & 80 deletions

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 0 deletions
@@ -459,6 +459,7 @@ std::string TRTEngine::to_str() const {
   ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   ss << " Target Platform: " << target_platform << std::endl;
   ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
+  ss << " Multi-Device Engine: " << (is_md) << std::endl;
   // clang-format on
   return ss.str();
 }

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 5 additions & 0 deletions
@@ -121,6 +121,7 @@ def forward(self, x):
         logger.info(f"Loading from {args.save_path}")
         loaded_program = torch_tensorrt.load(args.save_path)
         output = loaded_program.module()(inp)
+        dist.barrier()
         assert (python_result - output).std() < 0.01, "Result mismatch"
         logger.info("Load successful!")

@@ -137,6 +138,8 @@ def forward(self, x):
             },
         )
         output = trt_model(inp)
+        dist.barrier()
+
         assert (python_result - output).std() < 0.01, "Result mismatch"
         logger.info("JIT compile successful!")

@@ -153,6 +156,7 @@ def forward(self, x):
             },
         )
         output = trt_model(inp)
+        dist.barrier()
         assert (python_result - output).std() < 0.01, "Result mismatch"
         logger.info("JIT compile successful!")

@@ -169,6 +173,7 @@ def forward(self, x):
             use_distributed_mode_trace=True,
         )
         output = trt_model(inp)
+        dist.barrier()
         assert (python_result - output).std() < 0.01, "Result mismatch"

         # Save per-rank: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep
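
Each of these hunks adds the same synchronization step: the first forward call on a rank can trigger a long TensorRT engine build, and a rank that finishes early must not enter the next NCCL collective while its peers are still compiling. A minimal sketch of the pattern, assuming an initialized NCCL process group and an already-compiled trt_model:

    import torch.distributed as dist

    # First call may trigger a per-rank TRT engine build (potentially minutes).
    output = trt_model(inp)

    # Keep fast ranks from entering the next NCCL collective while slower
    # ranks are still building their engines.
    dist.barrier()

    assert (python_result - output).std() < 0.01, "Result mismatch"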
examples/distributed_inference/tensor_parallel_simple_example_torchrun.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
"""
Tensor Parallel Distributed Inference with Torch-TensorRT (torchrun)
=====================================================================

Same model as tensor_parallel_simple_example.py but launched with
torchrun / ``python -m torch_tensorrt.distributed.run`` instead of mpirun.

Usage
-----
.. code-block:: bash

    # Single-node, 2 GPUs
    torchrun --nproc_per_node=2 tensor_parallel_simple_example_torchrun.py

    # Two nodes, 1 GPU each — run on BOTH nodes simultaneously:
    # Node 0 (spirit):
    RANK=0 WORLD_SIZE=2 MASTER_ADDR=<spirit_ip> MASTER_PORT=29500 LOCAL_RANK=0 \\
        uv run python tensor_parallel_simple_example_torchrun.py

    # Node 1 (opportunity):
    RANK=1 WORLD_SIZE=2 MASTER_ADDR=<spirit_ip> MASTER_PORT=29500 LOCAL_RANK=0 \\
        uv run python tensor_parallel_simple_example_torchrun.py

    # Or via torchtrtrun (sets up NCCL library paths automatically):
    python -m torch_tensorrt.distributed.run --nproc_per_node=2 \\
        tensor_parallel_simple_example_torchrun.py

Optional args:
    --mode jit_python | jit_cpp | export | load (default: jit_python)
    --save-path /tmp/tp_model.ep
    --precision FP16 | BF16 | FP32 (default: FP16)
    --debug
"""

import argparse
import datetime
import logging
import os
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree
from torch.distributed.device_mesh import init_device_mesh

torch.utils._pytree.register_constant(
    torch.distributed.tensor._dtensor_spec.DTensorSpec
)

# One GPU per node; LOCAL_RANK defaults to 0 for plain env-var launch.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
DEVICE = torch.device(f"cuda:{local_rank}")

# 2-hour timeout so TRT engine building doesn't trigger the NCCL watchdog.
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=2))
rank = dist.get_rank()
world_size = dist.get_world_size()

import torch_tensorrt
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt

setup_nccl_for_torch_tensorrt()

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

logging.basicConfig(
    level=logging.INFO,
    format=f"[Rank {rank}] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)
logger.info(f"dist init OK rank={rank}/{world_size} device={DEVICE}")


class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 3200)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(3200, 1600)
        self.in_proj2 = nn.Linear(1600, 500)
        self.out_proj2 = nn.Linear(500, 100)

    def forward(self, x):
        x = self.out_proj(self.relu(self.in_proj(x)))
        x = self.relu(x)
        x = self.out_proj2(self.relu(self.in_proj2(x)))
        return x


def get_model(device_mesh):
    assert (
        world_size % 2 == 0
    ), f"TP examples require an even number of GPUs, got {world_size}"
    model = ToyModel().to(DEVICE)
    parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj": RowwiseParallel(output_layouts=Shard(0)),
            "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
        },
    )
    logger.info("Model built and sharded across ranks.")
    return model


def compile_torchtrt(model, args):
    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float16}
        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_explicit_typing = True
    else:
        enabled_precisions = {torch.float32}
        use_explicit_typing = True

    use_python_runtime = args.mode == "jit_python"

    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_model = torch.compile(
            model,
            backend="torch_tensorrt",
            dynamic=False,
            options={
                "enabled_precisions": enabled_precisions,
                "use_explicit_typing": use_explicit_typing,
                "use_fp32_acc": use_fp32_acc,
                "device": DEVICE,
                "disable_tf32": True,
                "use_python_runtime": use_python_runtime,
                "debug": args.debug,
                "min_block_size": 1,
                "use_distributed_mode_trace": True,
            },
        )
    return trt_model


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tensor Parallel Simple Example (torchrun)"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["jit_python", "jit_cpp", "export", "load"],
        default="jit_python",
    )
    parser.add_argument("--save-path", type=str, default="/tmp/tp_model.ep")
    parser.add_argument(
        "--precision",
        default="FP16",
        choices=["FP16", "BF16", "FP32"],
    )
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    device_mesh = init_device_mesh("cuda", (world_size,))

    with torch.inference_mode():
        model = get_model(device_mesh)

        torch.manual_seed(0)
        inp = torch.rand(20, 10, device=DEVICE)
        python_result = model(inp)

        if args.mode == "load":
            logger.info(f"Loading from {args.save_path}")
            loaded_program = torch_tensorrt.load(args.save_path)
            output = loaded_program.module()(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("Load successful!")

        elif args.mode in ("jit_python", "jit_cpp"):
            trt_model = compile_torchtrt(model, args)

            # Warmup: trigger engine build on all ranks, then barrier so no
            # rank races ahead to the next NCCL collective before others finish.
            logger.info("Warming up (triggering TRT engine build)...")
            _ = trt_model(inp)
            dist.barrier()
            logger.info("All ranks compiled. Running inference...")

            output = trt_model(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("JIT compile successful!")

        elif args.mode == "export":
            exported_program = torch.export.export(model, (inp,), strict=False)
            trt_model = torch_tensorrt.dynamo.compile(
                exported_program,
                inputs=[inp],
                use_explicit_typing=True,
                use_fp32_acc=True,
                device=DEVICE,
                disable_tf32=True,
                use_python_runtime=False,
                min_block_size=1,
                use_distributed_mode_trace=True,
                assume_dynamic_shape_support=True,
            )
            output = trt_model(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            save_path = torch_tensorrt.save(trt_model, args.save_path, inputs=[inp])
            logger.info(f"Saved to {save_path}")
            dist.barrier()

    dist.destroy_process_group()
    logger.info("Done!")

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ class ConversionContext:
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
+    requires_multidevice: bool = False
     weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
     cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)

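The new field mirrors requires_output_allocator: per-converter metadata gets aggregated onto the shared ConversionContext during conversion. A minimal sketch of how it might be set and consumed (the construction call and the branch site are assumptions for illustration, not part of this diff):

    from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

    ctx = ConversionContext(net)  # net: the TensorRT network under construction

    # A converter lowering a collective op can mark the build as multi-device:
    ctx.requires_multidevice = True

    # Downstream, the engine builder can branch on the aggregated flag, e.g. to
    # request a multi-device (MD-TRT) engine rather than a single-device one.
    if ctx.requires_multidevice:
        ...  # hypothetical: configure the builder for MD-TRT
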
py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 9 additions & 1 deletion
@@ -18,14 +18,15 @@
     cast,
 )

-import tensorrt as trt
 import torch
 from torch import SymBool, SymFloat, SymInt
 from torch._ops import OpOverloadPacket
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

+import tensorrt as trt
+
 logger = logging.getLogger(__name__)

 LegacyConverterImplSignature = Callable[

@@ -88,6 +89,7 @@ class ConverterSupport:
     )
     supports_dynamic_shapes: bool = False
     requires_output_allocator: bool = False
+    requires_multidevice: bool = False


 # Dictionary representing Dynamo aten-only converters

@@ -205,6 +207,7 @@ def dynamo_tensorrt_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    requires_multidevice: bool = False,
 ) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
     """Decorator for Dynamo TensorRT Converter

@@ -222,6 +225,7 @@ def dynamo_tensorrt_converter(
             same target
         supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic shapes.
         requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators).
+        requires_multidevice: Boolean flag indicating if the converter creates operators which require native TensorRT multi device collectives.
     Returns:
         The converter being decorated
     """

@@ -236,6 +240,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
                 converter_implementation=converter,
                 supports_dynamic_shapes=supports_dynamic_shapes,
                 requires_output_allocator=requires_output_allocator,
+                requires_multidevice=requires_multidevice,
             )
         else:
             assert callable(

@@ -246,6 +251,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
                 capability_validator=capability_validator,
                 supports_dynamic_shapes=supports_dynamic_shapes,
                 requires_output_allocator=requires_output_allocator,
+                requires_multidevice=requires_multidevice,
             )

     # OpOverloadPackets are only valid if they have a single overload, or

@@ -477,6 +483,7 @@ def __getitem__(
                 {
                     "supports_dynamic_shapes": candidate.supports_dynamic_shapes,
                     "requires_output_allocator": candidate.requires_output_allocator,
+                    "requires_multidevice": candidate.requires_multidevice,
                 },
             )
         else:

@@ -493,6 +500,7 @@ def __getitem__(
                 {
                     "supports_dynamic_shapes": False,
                     "requires_output_allocator": False,
+                    "requires_multidevice": False,
                 },
             )
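With the registry threaded through, a converter can declare the requirement at registration time, alongside supports_dynamic_shapes and requires_output_allocator. A sketch of the decorator usage (the all_reduce target and the empty body are illustrative placeholders, not converters added by this commit):

    import torch

    from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
        dynamo_tensorrt_converter,
    )


    @dynamo_tensorrt_converter(
        torch.ops._c10d_functional.all_reduce.default,  # illustrative target
        supports_dynamic_shapes=True,
        requires_multidevice=True,  # lowering emits native TRT multi-device collectives
    )
    def aten_ops_all_reduce(ctx, target, args, kwargs, name):
        # Converters registered with requires_multidevice=True surface the flag
        # via ConverterSupport, so the interpreter can set
        # ctx.requires_multidevice on the ConversionContext when this
        # converter is used.
        ...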