MD-TRT Support, Compile/Export, C++ and Python #4183
narendasan wants to merge 15 commits into main
Conversation
- C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
- Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
- Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
- Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
- Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
- Build: conditional NCCL compilation via torch_nccl toolchain
- Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
…g and enable DTensor decomposition
…hapes
Five interconnected fixes:
1. fold_get_attr_item_calls: fold scalar param .item() calls into Python
scalars before AOT tracing. Inside FakeTensorMode, even real-tensor
.item() calls raise DataDependentOutputException.
2. backends.py: three changes:
- call fold_get_attr_item_calls before entering FakeTensorMode
- detect vmap/higher-order ops and route them through aot_autograd
instead of aot_export_joint_simple (which doesn't handle HOPs)
- on TRT build failure, strip TRT-only kwargs (use_fp32_acc) from
the fallback graph before returning it to PyTorch
3. _decompositions.py: prevent SDPA from leaking back into the decomp
table via Core ATen Interchange ops even after being removed from
TORCH_TRT_DECOMPOSITIONS.
4. partitioning/common.py: lower the default max dynamic shape from
min*2^16 to min*2^12 — 65536 is too large for TRT to find kernel
implementations for attention ops.
5. _TorchTensorRTModule.py: move CPU scalar inputs to CUDA before
execution — aot_autograd lifts scalar attributes (e.g. head_dim^-0.5)
as explicit graph inputs; TRT requires all inputs on CUDA.
Also fixes remove_sym_nodes to match tensor sources by equality rather
than local_name so that GetItemSource bases (from torch.compile
dynamic=True) are matched correctly, and updates register_sdpa.py to
handle aten.scaled_dot_product_attention.default (the form produced after
aot_autograd) in addition to the flash/efficient variants.
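The shape heuristic in point 4 is simple enough to sketch numerically. The helper below is hypothetical (the name and signature are not from partitioning/common.py); it only illustrates the min*2^exponent bound the fix tightens:

```python
def default_max_dynamic_dim(min_dim: int, exponent: int = 12) -> int:
    """Hypothetical sketch of the default dynamic-shape upper bound:
    max = min * 2**exponent. The exponent drops from 16 (min*65536,
    too large for TRT to find attention kernels) to 12 (min*4096)."""
    return min_dim * (1 << exponent)

# A dimension with min=1 now gets an upper bound of 4096 rather than 65536.
```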
Force-pushed 67134da to b5b1f5f
Force-pushed b5b1f5f to 1957cc4
Force-pushed 473cff9 to 9022e03
Force-pushed 9022e03 to e08b0c5
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
  // All inputs are expected to be on CUDA. Warn and move any that are not.
  for (auto& inp : inputs) {
I would like to remove this but didn't have time to check whether the device operations in Python suppress this correctly.
// the constructor-time bind was deferred (e.g. no collective had been issued
// at construction time, or for serialized programs loaded inline where there
// is no Python _TorchTensorRTModule.forward wrapper).
if (compiled_engine->is_md && !compiled_engine->nccl_initialized) {
Not entirely sure this is necessary.
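The deferred bind guarded above is a standard lazy-initialization pattern. A minimal Python sketch of the same control flow (hypothetical class and method names, no real NCCL calls):

```python
class LazyNcclModule:
    """Hypothetical sketch of the C++ runtime's deferred NCCL bind: if no
    communicator was bound at construction time, bind on first execution."""

    def __init__(self, is_md: bool) -> None:
        self.is_md = is_md
        self.nccl_initialized = False

    def _init_nccl_comm(self) -> None:
        # The real runtime would create the communicator via c10d here.
        self.nccl_initialized = True

    def execute(self, x):
        # Mirror of: if (compiled_engine->is_md && !compiled_engine->nccl_initialized)
        if self.is_md and not self.nccl_initialized:
            self._init_nccl_comm()
        return x
```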
// process group from the c10d registry. PyTorch assigns sequential
// numeric names ("0", "1", ...) to process groups; probe until we
// find one with an NCCL backend.
if (this->group_name.empty() && this->is_md) {
We should only do this if there is exactly one available group. If there are multiple NCCL groups available, we should tell the user to select one manually.
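The selection rule suggested here can be sketched against a plain name-to-backend mapping (a hypothetical stand-in for the c10d registry, not the actual C++ implementation):

```python
def pick_nccl_group(registry: dict) -> str:
    """Return the name of the single NCCL-backed process group; error out
    when zero or more than one exist, per the reviewer's suggestion."""
    nccl_groups = [name for name, backend in registry.items() if backend == "nccl"]
    if not nccl_groups:
        raise RuntimeError("no NCCL process group found in the c10d registry")
    if len(nccl_groups) > 1:
        raise RuntimeError(
            "multiple NCCL process groups found; set group_name explicitly"
        )
    return nccl_groups[0]
```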
def forward(self, x):
    out = self.linear(x)
    out = torch.ops._c10d_functional.all_reduce(out, "sum", self.group_name)
Let's dig into this more after the PR lands.
logger = logging.getLogger("torchtrtrun")

def _get_nccl_lib_dir() -> Optional[str]:
Move into its own file
self._nccl_comm: Optional[Any] = None
self._has_nccl_ops: bool = False
This should be set before the self.setup_engine() call.
inspector = self.engine.create_engine_inspector()
engine_json = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
self._has_nccl_ops = "NCCL" in engine_json or "AllReduce" in engine_json
Something like this works:
engine_json_lower = engine_json.lower()
self._has_nccl_ops = "dist_collective" in engine_json_lower or "nccl" in engine_json_lower or "allreduce" in engine_json_lower
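The suggested check factors naturally into a small predicate. A sketch with a hypothetical helper name, scanning the engine inspector's layer-information JSON for the tokens listed above:

```python
def engine_has_nccl_ops(engine_json: str) -> bool:
    """Case-insensitive scan of TensorRT engine layer info for collective
    layers; token list follows the reviewer's suggestion."""
    lowered = engine_json.lower()
    return any(tok in lowered for tok in ("dist_collective", "nccl", "allreduce"))
```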
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_md.py 2026-04-14 20:34:53.887235+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_md.py 2026-04-14 20:35:10.238612+00:00
@@ -97,13 +97,13 @@
x = self.out_proj2(self.relu(self.in_proj2(x)))
return x
def get_model(device_mesh):
- assert world_size % 2 == 0, (
- f"TP examples require an even number of GPUs, got {world_size}"
- )
+ assert (
+ world_size % 2 == 0
+ ), f"TP examples require an even number of GPUs, got {world_size}"
model = ToyModel().to(DEVICE)
parallelize_module(
module=model,
device_mesh=device_mesh,
parallelize_plan={
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py 2026-04-14 20:34:53.901309+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py 2026-04-14 20:35:11.188933+00:00
@@ -22,11 +22,13 @@
if ENABLED_FEATURES.native_trt_collectives:
# Use native TensorRT DistCollective API (no TensorRT-LLM dependency)
_LOGGER.info("Using native TensorRT DistCollective API for distributed operations")
- @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op, requires_multidevice=True)
+ @dynamo_tensorrt_converter(
+ tensorrt_fused_nccl_all_gather_op, requires_multidevice=True
+ )
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
@@ -39,11 +41,13 @@
SourceIR.ATEN,
name,
[args[0]],
)
- @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op, requires_multidevice=True)
+ @dynamo_tensorrt_converter(
+ tensorrt_fused_nccl_reduce_scatter_op, requires_multidevice=True
+ )
def fused_nccl_reduce_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
@@ -56,11 +60,13 @@
SourceIR.ATEN,
name,
[args[0]],
)
- @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op, requires_multidevice=True)
+ @dynamo_tensorrt_converter(
+ tensorrt_fused_nccl_all_reduce_op, requires_multidevice=True
+ )
def fused_nccl_all_reduce(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2026-04-14 20:34:53.900894+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2026-04-14 20:35:11.722086+00:00
@@ -669,11 +669,11 @@
cuda_engine,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
- self.ctx.requires_multidevice
+ self.ctx.requires_multidevice,
)
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
self._cur_node = n
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2026-04-14 20:34:53.907339+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2026-04-14 20:35:13.284044+00:00
@@ -368,12 +368,18 @@
# For engines with native NCCL collective layers, all ranks must
# have a live IExecutionContext before any rank executes a
# collective. Barrier here so a fast-compiling rank does not race
# ahead and issue an NCCL op while another rank is still inside
# deserialize_cuda_engine / create_execution_context.
- if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
- logger.debug("Barrier after execution context creation (distributed NCCL engine)")
+ if (
+ dist.is_available()
+ and dist.is_initialized()
+ and dist.get_world_size() > 1
+ ):
+ logger.debug(
+ "Barrier after execution context creation (distributed NCCL engine)"
+ )
dist.barrier()
assert self.context is not None, "Failed to create execution context"
assert self.engine.num_io_tensors == (
len(self.input_names) + len(self.output_names)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_native_nccl.py 2026-04-14 20:34:53.931747+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_native_nccl.py 2026-04-14 20:35:15.878751+00:00
@@ -1525,13 +1525,11 @@
expected = torch.full((1, 4), float(expected_sum), device=device)
_check_close(out, expected, f"context_switch sg{i+1} rank={rank}")
-def _multirank_pg_migration(
- rank: int, world_size: int, device: torch.device
-) -> None:
+def _multirank_pg_migration(rank: int, world_size: int, device: torch.device) -> None:
"""Compile with the default world group, run inference, then migrate to a new
subgroup via distributed_group(new_group, model) and verify that inference
still produces correct results — i.e. the NCCL communicator is re-bound.
Tests both the C++ runtime (set_group_name resets nccl_initialized) and the
@@ -1562,13 +1560,11 @@
def __init__(self, pg_name: str) -> None:
super().__init__()
self.pg_name = pg_name
def forward(self, x: torch.Tensor) -> torch.Tensor:
- out = torch.ops._c10d_functional.all_reduce.default(
- x, "sum", self.pg_name
- )
+ out = torch.ops._c10d_functional.all_reduce.default(x, "sum", self.pg_name)
return torch.ops._c10d_functional.wait_tensor.default(out)
inp = torch.full((1, 4), float(rank + 1), device=device)
expected_sum = world_size * (world_size + 1) // 2
expected = torch.full((1, 4), float(expected_sum), device=device)
@@ -1600,13 +1596,11 @@
# lazy setup_nccl_comm() call.
with distributed_group(subgroup, trt_model) as migrated_model:
with torch.no_grad():
out_sub = migrated_model(inp)
- _check_close(
- out_sub, expected, f"[{label}] migrated to subgroup rank={rank}"
- )
+ _check_close(out_sub, expected, f"[{label}] migrated to subgroup rank={rank}")
# ---- Step 3: set_distributed_group (persistent, outside context) ----
subgroup2 = dist.new_group(ranks=list(range(world_size)))
torch_tensorrt.distributed.set_distributed_group(trt_model, subgroup2)
# _state.pg is NOT set here — Python runtime falls back to world group
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py 2026-04-14 20:34:53.940401+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py 2026-04-14 20:35:16.561193+00:00
@@ -194,11 +194,13 @@
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
- input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+ input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(
+ DEVICE
+ )
max_len = input_ids.shape[1] + args.num_tokens
logger.info("Running uncompiled PyTorch baseline ...")
torch_tokens = generate(
model, input_ids.clone(), max_len, tokenizer.eos_token_id
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py 2026-04-14 20:34:53.940401+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py 2026-04-14 20:35:16.626752+00:00
@@ -135,16 +135,16 @@
)
cfg = model.config
parallelize_module(model, device_mesh, build_tp_plan(cfg))
cfg = model.config
- assert cfg.num_key_value_heads % world_size == 0, (
- f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
- )
- assert cfg.num_attention_heads % world_size == 0, (
- f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
- )
+ assert (
+ cfg.num_key_value_heads % world_size == 0
+ ), f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
+ assert (
+ cfg.num_attention_heads % world_size == 0
+ ), f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
# After column-sharding Q/K/V, each rank holds num_heads // world_size
# heads. Patch these so HuggingFace attention reshapes correctly.
for layer in model.model.layers:
layer.self_attn.num_heads = cfg.num_attention_heads // world_size
@@ -209,11 +209,11 @@
parser.add_argument(
"--sharded_checkpoint",
type=str,
default="",
help="Path to DCP sharded checkpoint (e.g. /mnt/cluster-shared/mixtral_sharded). "
- "If set, skips HF weight download and loads only this rank's shard.",
+ "If set, skips HF weight download and loads only this rank's shard.",
)
args = parser.parse_args()
device_mesh = init_device_mesh("cuda", (world_size,))
--- /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py 2026-04-14 20:34:53.941720+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py 2026-04-14 20:35:17.028688+00:00
@@ -163,11 +163,13 @@
# block-size padding produces s1*(8+s1-s1%8)>1 guards that the
# symbolic solver can't verify without concrete values). Without
# bounds, dynamo traces symbolically and TRT infers the profile
# from the first concrete shape it sees.
torch._dynamo.mark_dynamic(input_seq, 1)
- position_ids = torch.arange(input_seq.shape[1], device=input_seq.device).unsqueeze(0)
+ position_ids = torch.arange(
+ input_seq.shape[1], device=input_seq.device
+ ).unsqueeze(0)
if dynamic_seqlen_range is not None:
torch._dynamo.mark_dynamic(position_ids, 1)
outputs = model(input_seq, position_ids=position_ids)
logits = outputs.logits
next_token_logits = logits[:, -1, :]
Force-pushed 7506223 to bf432ad
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py 2026-04-14 21:04:44.996426+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py 2026-04-14 21:05:11.648796+00:00
@@ -135,16 +135,16 @@
)
cfg = model.config
parallelize_module(model, device_mesh, build_tp_plan(cfg))
cfg = model.config
- assert cfg.num_key_value_heads % world_size == 0, (
- f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
- )
- assert cfg.num_attention_heads % world_size == 0, (
- f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
- )
+ assert (
+ cfg.num_key_value_heads % world_size == 0
+ ), f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
+ assert (
+ cfg.num_attention_heads % world_size == 0
+ ), f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
# After column-sharding Q/K/V, each rank holds num_heads // world_size
# heads. Patch these so HuggingFace attention reshapes correctly.
for layer in model.model.layers:
layer.self_attn.num_heads = cfg.num_attention_heads // world_size
@@ -209,11 +209,11 @@
parser.add_argument(
"--sharded_checkpoint",
type=str,
default="",
help="Path to DCP sharded checkpoint (e.g. /mnt/cluster-shared/mixtral_sharded). "
- "If set, skips HF weight download and loads only this rank's shard.",
+ "If set, skips HF weight download and loads only this rank's shard.",
)
args = parser.parse_args()
device_mesh = init_device_mesh("cuda", (world_size,))
Description
Opening this to test the CI
Fixes # (issue)