
MD-TRT Support, Compile/Export, C++ and Python #4183

Open
narendasan wants to merge 15 commits into main from push-vqqzkszwrvyx

Conversation

@narendasan
Collaborator

Description

Opening this to test the CI

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

apbose and others added 11 commits April 12, 2026 11:41
- C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
- Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
- Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
- Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
- Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
- Build: conditional NCCL compilation via torch_nccl toolchain
- Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
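The graph-lowering bullet above (fusing c10d_functional collectives with their trailing wait_tensor) can be sketched on a simplified node list. This is a hedged illustration: the Node class, the node-list representation, and the fused op name are invented stand-ins for the real torch.fx pass in this PR.

```python
# Sketch of the collective + wait_tensor fusion idea from the lowering pass.
# Uses a simplified node list instead of a real torch.fx graph; the op names
# mirror the c10d_functional ops, but this Node class is hypothetical.
from dataclasses import dataclass


@dataclass
class Node:
    op: str
    args: tuple = ()


def fuse_collectives(nodes):
    """Replace each (all_reduce, wait_tensor) pair with one fused op."""
    fused, i = [], 0
    while i < len(nodes):
        n = nodes[i]
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if n.op == "all_reduce" and nxt is not None and nxt.op == "wait_tensor":
            fused.append(Node("fused_all_reduce", n.args))
            i += 2  # consume both the collective and its wait
        else:
            fused.append(n)
            i += 1
    return fused
```

A single fused node is what the converter registry can then match against, instead of pattern-matching two separate ops at conversion time.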
…hapes

Five interconnected fixes:

1. fold_get_attr_item_calls: fold scalar param .item() calls into Python
   scalars before AOT tracing. Inside FakeTensorMode, even real-tensor
   .item() calls raise DataDependentOutputException.

2. backends.py: three changes:
   - call fold_get_attr_item_calls before entering FakeTensorMode
   - detect vmap/higher-order ops and route them through aot_autograd
     instead of aot_export_joint_simple (which doesn't handle HOPs)
   - on TRT build failure, strip TRT-only kwargs (use_fp32_acc) from
     the fallback graph before returning it to PyTorch

3. _decompositions.py: prevent SDPA from leaking back into the decomp
   table via Core ATen Interchange ops even after being removed from
   TORCH_TRT_DECOMPOSITIONS.

4. partitioning/common.py: lower the default max dynamic shape from
   min*2^16 to min*2^12 — 65536 is too large for TRT to find kernel
   implementations for attention ops.

5. _TorchTensorRTModule.py: move CPU scalar inputs to CUDA before
   execution — aot_autograd lifts scalar attributes (e.g. head_dim^-0.5)
   as explicit graph inputs; TRT requires all inputs on CUDA.

Also fixes remove_sym_nodes to match tensor sources by equality rather
than local_name so that GetItemSource bases (from torch.compile
dynamic=True) are matched correctly, and updates register_sdpa.py to
handle aten.scaled_dot_product_attention.default (the form produced after
aot_autograd) in addition to the flash/efficient variants.
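The first fix above (folding scalar-parameter .item() calls before tracing) can be illustrated with a toy pass over a simplified node list. This is a hedged sketch, not the actual fold_get_attr_item_calls implementation: nodes are plain (op, arg) tuples and the parameter name is invented.

```python
# Toy illustration of folding scalar .item() calls into Python constants
# before tracing. The real pass operates on an fx.GraphModule; here nodes
# are (op, arg) tuples and params maps attribute names to scalar values.
def fold_item_calls(nodes, params):
    folded = []
    for op, arg in nodes:
        if op == "item" and arg in params:
            # Bake the scalar in now, so tracing under FakeTensorMode never
            # sees a data-dependent .item() call.
            folded.append(("constant", params[arg]))
        else:
            folded.append((op, arg))
    return folded
```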
@meta-cla meta-cla bot added the cla signed label Apr 12, 2026
@github-actions github-actions bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: build system Issues re: Build system component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: torch_compile labels Apr 12, 2026
@github-actions github-actions bot requested a review from zewenli98 April 12, 2026 19:09
github-actions[bot]

This comment was marked as outdated.


@narendasan narendasan force-pushed the push-vqqzkszwrvyx branch 5 times, most recently from 473cff9 to 9022e03 on April 13, 2026 01:14

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
// All inputs are expected to be on CUDA. Warn and move any that are not.
for (auto& inp : inputs) {
narendasan (Collaborator, Author):

I would like to remove this, but I didn't have time to check whether the device operations in Python suppress this correctly.

// the constructor-time bind was deferred (e.g. no collective had been issued
// at construction time, or for serialized programs loaded inline where there
// is no Python _TorchTensorRTModule.forward wrapper).
if (compiled_engine->is_md && !compiled_engine->nccl_initialized) {
narendasan (Collaborator, Author):

Not entirely sure this is necessary

// process group from the c10d registry. PyTorch assigns sequential
// numeric names ("0", "1", ...) to process groups; probe until we
// find one with an NCCL backend.
if (this->group_name.empty() && this->is_md) {
narendasan (Collaborator, Author):

We should only do this if there is exactly one available group. If there are multiple NCCL groups available, we should tell the user to select one manually.
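The rule proposed in this comment can be sketched as a small helper. Everything here is hypothetical for illustration: `registry` stands in for the real c10d process-group registry as a plain mapping of group name to backend string, and the function name is invented.

```python
# Sketch of "auto-select only when unambiguous": pick a process group
# automatically only if exactly one NCCL-backed group exists; otherwise
# force an explicit choice. `registry` maps group name -> backend string
# and is a stand-in for the real c10d registry.
def pick_nccl_group(registry):
    nccl_groups = [name for name, backend in registry.items() if backend == "nccl"]
    if len(nccl_groups) == 1:
        return nccl_groups[0]
    if len(nccl_groups) > 1:
        raise RuntimeError(
            f"Multiple NCCL process groups found ({nccl_groups}); "
            "set group_name explicitly."
        )
    return None  # no NCCL group registered
```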


def forward(self, x):
out = self.linear(x)
out = torch.ops._c10d_functional.all_reduce(out, "sum", self.group_name)
narendasan (Collaborator, Author):

Let's dig into this more after the PR lands.

logger = logging.getLogger("torchtrtrun")


def _get_nccl_lib_dir() -> Optional[str]:
narendasan (Collaborator, Author):

Move this into its own file.
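As a sketch of what such a standalone detection module might contain: the function below assumes the directory layout used by the nvidia-nccl-cu* pip wheels (site-packages/nvidia/nccl/lib), and NCCL_LIB_DIR is an invented environment-variable override. It is not the PR's implementation.

```python
# Hedged sketch of NCCL shared-library auto-detection. Assumes the layout
# of the nvidia-nccl-cu* pip wheels; NCCL_LIB_DIR is an invented override
# for illustration. Returns None when nothing is found.
import importlib.util
import os
from typing import Optional


def find_nccl_lib_dir() -> Optional[str]:
    spec = importlib.util.find_spec("nvidia")
    if spec is not None and spec.submodule_search_locations:
        for base in spec.submodule_search_locations:
            candidate = os.path.join(base, "nccl", "lib")
            if os.path.isdir(candidate):
                return candidate  # wheel-installed NCCL found
    return os.environ.get("NCCL_LIB_DIR")  # explicit override, or None
```

Returning None (rather than raising) keeps the fallback chain described in the commit message graceful: callers can try the system loader next.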


self._nccl_comm: Optional[Any] = None
self._has_nccl_ops: bool = False

A collaborator commented:

This should be set before the self.setup_engine() call.

inspector = self.engine.create_engine_inspector()
engine_json = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
self._has_nccl_ops = "NCCL" in engine_json or "AllReduce" in engine_json

A collaborator commented:

Something like this works:

engine_json_lower = engine_json.lower()
self._has_nccl_ops = (
    "dist_collective" in engine_json_lower
    or "nccl" in engine_json_lower
    or "allreduce" in engine_json_lower
)
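Wrapped as a self-contained helper (the function name is invented), the suggested case-insensitive check looks like this:

```python
# Case-insensitive check for collective layers in TensorRT engine-inspector
# JSON, following the suggestion above. Helper name is invented.
def has_nccl_ops(engine_json: str) -> bool:
    text = engine_json.lower()
    return any(m in text for m in ("dist_collective", "nccl", "allreduce"))
```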


@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_md.py	2026-04-14 20:34:53.887235+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_md.py	2026-04-14 20:35:10.238612+00:00
@@ -97,13 +97,13 @@
        x = self.out_proj2(self.relu(self.in_proj2(x)))
        return x


def get_model(device_mesh):
-    assert world_size % 2 == 0, (
-        f"TP examples require an even number of GPUs, got {world_size}"
-    )
+    assert (
+        world_size % 2 == 0
+    ), f"TP examples require an even number of GPUs, got {world_size}"
    model = ToyModel().to(DEVICE)
    parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py	2026-04-14 20:34:53.901309+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py	2026-04-14 20:35:11.188933+00:00
@@ -22,11 +22,13 @@

if ENABLED_FEATURES.native_trt_collectives:
    # Use native TensorRT DistCollective API (no TensorRT-LLM dependency)
    _LOGGER.info("Using native TensorRT DistCollective API for distributed operations")

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op, requires_multidevice=True)
+    @dynamo_tensorrt_converter(
+        tensorrt_fused_nccl_all_gather_op, requires_multidevice=True
+    )
    def fused_nccl_gather(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
@@ -39,11 +41,13 @@
            SourceIR.ATEN,
            name,
            [args[0]],
        )

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op, requires_multidevice=True)
+    @dynamo_tensorrt_converter(
+        tensorrt_fused_nccl_reduce_scatter_op, requires_multidevice=True
+    )
    def fused_nccl_reduce_scatter(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
@@ -56,11 +60,13 @@
            SourceIR.ATEN,
            name,
            [args[0]],
        )

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op, requires_multidevice=True)
+    @dynamo_tensorrt_converter(
+        tensorrt_fused_nccl_all_reduce_op, requires_multidevice=True
+    )
    def fused_nccl_all_reduce(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2026-04-14 20:34:53.900894+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2026-04-14 20:35:11.722086+00:00
@@ -669,11 +669,11 @@
            cuda_engine,
            self._input_names,
            self._output_names,
            self.weight_name_map,
            self.ctx.requires_output_allocator,
-            self.ctx.requires_multidevice
+            self.ctx.requires_multidevice,
        )

    def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
        self._cur_node_name = get_node_name(n)
        self._cur_node = n
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2026-04-14 20:34:53.907339+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2026-04-14 20:35:13.284044+00:00
@@ -368,12 +368,18 @@
            # For engines with native NCCL collective layers, all ranks must
            # have a live IExecutionContext before any rank executes a
            # collective. Barrier here so a fast-compiling rank does not race
            # ahead and issue an NCCL op while another rank is still inside
            # deserialize_cuda_engine / create_execution_context.
-            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
-                logger.debug("Barrier after execution context creation (distributed NCCL engine)")
+            if (
+                dist.is_available()
+                and dist.is_initialized()
+                and dist.get_world_size() > 1
+            ):
+                logger.debug(
+                    "Barrier after execution context creation (distributed NCCL engine)"
+                )
                dist.barrier()

        assert self.context is not None, "Failed to create execution context"
        assert self.engine.num_io_tensors == (
            len(self.input_names) + len(self.output_names)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_native_nccl.py	2026-04-14 20:34:53.931747+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_native_nccl.py	2026-04-14 20:35:15.878751+00:00
@@ -1525,13 +1525,11 @@

        expected = torch.full((1, 4), float(expected_sum), device=device)
        _check_close(out, expected, f"context_switch sg{i+1} rank={rank}")


-def _multirank_pg_migration(
-    rank: int, world_size: int, device: torch.device
-) -> None:
+def _multirank_pg_migration(rank: int, world_size: int, device: torch.device) -> None:
    """Compile with the default world group, run inference, then migrate to a new
    subgroup via distributed_group(new_group, model) and verify that inference
    still produces correct results — i.e. the NCCL communicator is re-bound.

    Tests both the C++ runtime (set_group_name resets nccl_initialized) and the
@@ -1562,13 +1560,11 @@
        def __init__(self, pg_name: str) -> None:
            super().__init__()
            self.pg_name = pg_name

        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            out = torch.ops._c10d_functional.all_reduce.default(
-                x, "sum", self.pg_name
-            )
+            out = torch.ops._c10d_functional.all_reduce.default(x, "sum", self.pg_name)
            return torch.ops._c10d_functional.wait_tensor.default(out)

    inp = torch.full((1, 4), float(rank + 1), device=device)
    expected_sum = world_size * (world_size + 1) // 2
    expected = torch.full((1, 4), float(expected_sum), device=device)
@@ -1600,13 +1596,11 @@
        # lazy setup_nccl_comm() call.
        with distributed_group(subgroup, trt_model) as migrated_model:
            with torch.no_grad():
                out_sub = migrated_model(inp)

-        _check_close(
-            out_sub, expected, f"[{label}] migrated to subgroup rank={rank}"
-        )
+        _check_close(out_sub, expected, f"[{label}] migrated to subgroup rank={rank}")

        # ---- Step 3: set_distributed_group (persistent, outside context) ----
        subgroup2 = dist.new_group(ranks=list(range(world_size)))
        torch_tensorrt.distributed.set_distributed_group(trt_model, subgroup2)
        # _state.pg is NOT set here — Python runtime falls back to world group
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py	2026-04-14 20:34:53.940401+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py	2026-04-14 20:35:16.561193+00:00
@@ -194,11 +194,13 @@

            tokenizer = AutoTokenizer.from_pretrained(args.model)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

-            input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+            input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(
+                DEVICE
+            )
            max_len = input_ids.shape[1] + args.num_tokens

            logger.info("Running uncompiled PyTorch baseline ...")
            torch_tokens = generate(
                model, input_ids.clone(), max_len, tokenizer.eos_token_id
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py	2026-04-14 20:34:53.940401+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py	2026-04-14 20:35:16.626752+00:00
@@ -135,16 +135,16 @@
            )
        cfg = model.config
        parallelize_module(model, device_mesh, build_tp_plan(cfg))

    cfg = model.config
-    assert cfg.num_key_value_heads % world_size == 0, (
-        f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
-    )
-    assert cfg.num_attention_heads % world_size == 0, (
-        f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
-    )
+    assert (
+        cfg.num_key_value_heads % world_size == 0
+    ), f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
+    assert (
+        cfg.num_attention_heads % world_size == 0
+    ), f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"

    # After column-sharding Q/K/V, each rank holds num_heads // world_size
    # heads. Patch these so HuggingFace attention reshapes correctly.
    for layer in model.model.layers:
        layer.self_attn.num_heads = cfg.num_attention_heads // world_size
@@ -209,11 +209,11 @@
    parser.add_argument(
        "--sharded_checkpoint",
        type=str,
        default="",
        help="Path to DCP sharded checkpoint (e.g. /mnt/cluster-shared/mixtral_sharded). "
-             "If set, skips HF weight download and loads only this rank's shard.",
+        "If set, skips HF weight download and loads only this rank's shard.",
    )
    args = parser.parse_args()

    device_mesh = init_device_mesh("cuda", (world_size,))

--- /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2026-04-14 20:34:53.941720+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2026-04-14 20:35:17.028688+00:00
@@ -163,11 +163,13 @@
            # block-size padding produces s1*(8+s1-s1%8)>1 guards that the
            # symbolic solver can't verify without concrete values).  Without
            # bounds, dynamo traces symbolically and TRT infers the profile
            # from the first concrete shape it sees.
            torch._dynamo.mark_dynamic(input_seq, 1)
-        position_ids = torch.arange(input_seq.shape[1], device=input_seq.device).unsqueeze(0)
+        position_ids = torch.arange(
+            input_seq.shape[1], device=input_seq.device
+        ).unsqueeze(0)
        if dynamic_seqlen_range is not None:
            torch._dynamo.mark_dynamic(position_ids, 1)
        outputs = model(input_seq, position_ids=position_ids)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]

@github-actions github-actions bot left a comment noting the same style-guideline failures for tools/llm/tensor_parallel_mixtral_llm.py as above.
