cache shape expressions for reexport #4079

Open
narendasan wants to merge 1 commit into main from narendasan/push-knqwnzwpomoz

Conversation

@narendasan
Collaborator

Description

Adds functionality to store the shape expressions for each compiled subgraph in the metadata pickle. At re-serialization time, these shape expressions are applied to the input FakeTensors to describe the output shapes in terms of symbolic shapes.
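As a rough illustration of the mechanism, here is a minimal sketch of the symbol-substitution pattern (mirroring the pattern visible in the register_meta_ops.py diff below; the function and variable names are illustrative, not this PR's API):

import sympy

def concretize_dim(expr: sympy.Expr, hints: dict) -> int:
    # Substitute every free symbol in a cached shape expression with its
    # concrete hint value, then evaluate the result to a plain int.
    subs_dict = {sym: hints[sym] for sym in expr.free_symbols}
    return int(expr.subs(subs_dict))

batch = sympy.Symbol("batch")
out_dim = batch * 10  # expression cached at compile time for one output dim
print(concretize_dim(out_dim, {batch: 8}))  # -> 80

When a symbol has no concrete binding, the substitution can instead leave it symbolic, which is how the output shape stays expressed in terms of the input symbols.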

Fixes N/A

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@meta-cla meta-cla bot added the cla signed label Feb 12, 2026
@github-actions github-actions bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: build system Issues re: Build system component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Feb 12, 2026
@github-actions github-actions bot requested a review from zewenli98 February 12, 2026 07:44
@narendasan narendasan changed the title from "Narendasan/push knqwnzwpomoz" to "cache shape expressions for reexport" on Feb 12, 2026

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_both_methods.py	2026-02-12 07:44:49.746067+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_both_methods.py	2026-02-12 07:45:24.467335+00:00
@@ -21,10 +21,11 @@
import tempfile

import torch
import torch.nn as nn
import torch_tensorrt
+

# %%
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
@@ -60,11 +61,14 @@
# Compile with TensorRT
trt_module_method1 = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=[
        torch_tensorrt.Input(
-            min_shape=(1, 10), opt_shape=(8, 10), max_shape=(32, 10), dtype=torch.float32
+            min_shape=(1, 10),
+            opt_shape=(8, 10),
+            max_shape=(32, 10),
+            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float32},
    min_block_size=1,
)
@@ -154,11 +158,12 @@

print("\n" + "=" * 60)
print("Summary")
print("=" * 60)

-print("""
+print(
+    """
Method 1 (Explicit torch.export.Dim):
  ✓ More control over dimension naming
  ✓ Familiar to torch.export users
  ✗ Requires specifying dynamic_shapes twice (export and save)
  ✗ More verbose
@@ -170,11 +175,12 @@
  ✓ RECOMMENDED for most use cases
  ✗ Less control over Dim naming (auto-generated)

**Recommendation**: Use Method 2 (torch_tensorrt.Input) unless you need
fine-grained control over dimension names for specific torch.export use cases.
-""")
+"""
+)

# %%
# Multiple Dynamic Dimensions Example
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_example.py	2026-02-12 07:44:49.746067+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_example.py	2026-02-12 07:45:24.468461+00:00
@@ -20,10 +20,11 @@
import tempfile

import torch
import torch.nn as nn
import torch_tensorrt
+

# %%
# Define a simple model that we'll compile with dynamic batch size
class MyModel(nn.Module):
    def __init__(self):
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py	2026-02-12 07:44:49.765067+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py	2026-02-12 07:45:27.776410+00:00
@@ -806,13 +806,13 @@
    Copy the metadata from anchor node to the replacement node. This should be used
    if the anchor node is replaced with only a single replacement node i.e one-one replacement.
    """
    for match_and_replacement in match_and_replacements:
        anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
-        assert len(match_and_replacement.replacements) == 1, (
-            "Found more than 1 replacements for the anchor node."
-        )
+        assert (
+            len(match_and_replacement.replacements) == 1
+        ), "Found more than 1 replacements for the anchor node."
        replacement_node = match_and_replacement.replacements[0]
        replacement_node.meta = anchor_node.meta


def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_reexport.py	2026-02-12 07:44:49.789067+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_reexport.py	2026-02-12 07:45:33.790813+00:00
@@ -1312,13 +1312,11 @@
            return self.linear(x)

    model = SimpleModel().eval().cuda()

    # Static Input (single shape, not min/opt/max)
-    compile_inputs = [
-        torchtrt.Input(shape=(4, 10), dtype=torch.float32, name="x")
-    ]
+    compile_inputs = [torchtrt.Input(shape=(4, 10), dtype=torch.float32, name="x")]

    compile_spec = {
        "inputs": compile_inputs,
        "ir": ir,
        "min_block_size": 1,

@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from c4a5c3d to be928f9 on February 12, 2026 07:46
@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from be928f9 to ae76517 on February 12, 2026 07:47
@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from ae76517 to 4bcca50 on February 12, 2026 07:48
# Replace the pytorch submodule node (call_module) with the inlined subgraph output
gm_node.replace_all_uses_with(submodule_output)
# Special handling when submodule returns multiple outputs (tuple)
Collaborator Author

I am not too sure about this, but it addresses some test cases.
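A minimal sketch of the tuple-output remapping this handles, assuming FX's usual convention that consumers of a multi-output node are operator.getitem calls (remap_tuple_outputs and submodule_outputs are hypothetical names, not this PR's code):

import operator

import torch.fx

def remap_tuple_outputs(gm_node: torch.fx.Node, submodule_outputs: tuple) -> None:
    # When the inlined submodule returns a tuple, the users of the original
    # call_module node are operator.getitem calls; point each one at the
    # matching inlined output node instead.
    for user in list(gm_node.users):
        if user.op == "call_function" and user.target is operator.getitem:
            idx = user.args[1]
            user.replace_all_uses_with(submodule_outputs[idx])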

@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from 4bcca50 to fb44e84 on February 12, 2026 07:53
@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch 3 times, most recently from b2cb6d0 to c751ce7, on February 12, 2026 22:22

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_both_methods.py	2026-02-12 22:22:49.449978+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_both_methods.py	2026-02-12 22:23:26.587252+00:00
@@ -19,10 +19,11 @@
import tempfile

import torch
import torch.nn as nn
import torch_tensorrt
+

# %%
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
@@ -54,11 +55,14 @@
# Compile with TensorRT
trt_module_method1 = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=[
        torch_tensorrt.Input(
-            min_shape=(1, 10), opt_shape=(8, 10), max_shape=(32, 10), dtype=torch.float32
+            min_shape=(1, 10),
+            opt_shape=(8, 10),
+            max_shape=(32, 10),
+            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float32},
    min_block_size=1,
)
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_example.py	2026-02-12 22:22:49.449978+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/save_dynamic_shapes_example.py	2026-02-12 22:23:26.611905+00:00
@@ -20,10 +20,11 @@
import tempfile

import torch
import torch.nn as nn
import torch_tensorrt
+

# %%
# Define a simple model that we'll compile with dynamic batch size
class MyModel(nn.Module):
    def __init__(self):
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py	2026-02-12 22:22:49.468610+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py	2026-02-12 22:23:29.543799+00:00
@@ -94,13 +94,15 @@
                        try:
                            # Build substitution dict with concrete values
                            subs_dict = {
                                sym: symbol_to_concrete.get(
                                    sym,
-                                    symbol_to_symint.get(sym).node.hint
-                                    if sym in symbol_to_symint
-                                    else sym,
+                                    (
+                                        symbol_to_symint.get(sym).node.hint
+                                        if sym in symbol_to_symint
+                                        else sym
+                                    ),
                                )
                                for sym in expr.free_symbols
                            }
                            val = expr.subs(subs_dict)
                            concrete_dim = int(val)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py	2026-02-12 22:22:49.469136+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py	2026-02-12 22:23:30.407045+00:00
@@ -808,13 +808,13 @@
    Copy the metadata from anchor node to the replacement node. This should be used
    if the anchor node is replaced with only a single replacement node i.e one-one replacement.
    """
    for match_and_replacement in match_and_replacements:
        anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
-        assert len(match_and_replacement.replacements) == 1, (
-            "Found more than 1 replacements for the anchor node."
-        )
+        assert (
+            len(match_and_replacement.replacements) == 1
+        ), "Found more than 1 replacements for the anchor node."
        replacement_node = match_and_replacement.replacements[0]
        replacement_node.meta = anchor_node.meta


def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_reexport.py	2026-02-12 22:22:49.492649+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_reexport.py	2026-02-12 22:23:37.099324+00:00
@@ -1007,12 +1007,10 @@
    dyn_height = torch.export.Dim("height", min=64, max=512)
    dyn_width = torch.export.Dim("width", min=64, max=512)
    dynamic_shapes = {"x": {0: dyn_batch, 2: dyn_height, 3: dyn_width}}

    trt_module = torchtrt.compile(model, **compile_spec)
-
-

    # Save with automatic inference of all 3 dynamic dimensions
    # retrace=True now works correctly with dynamic shapes
    torchtrt.save(
        trt_module,
@@ -1167,11 +1165,15 @@
    batch = torch.export.Dim("batch", min=1, max=8)
    dynamic_shapes = {"x": {0: batch}, "mask": {0: batch}}

    # Step 1: Export with torch.export
    exp_program = torch.export.export(
-        model, (example_x,), {"mask": example_mask}, dynamic_shapes=dynamic_shapes, strict=False
+        model,
+        (example_x,),
+        {"mask": example_mask},
+        dynamic_shapes=dynamic_shapes,
+        strict=False,
    )

    # Step 2: Compile with TensorRT using torch_tensorrt.dynamo.compile
    compile_inputs = [
        torchtrt.Input(
@@ -1322,13 +1324,11 @@
            return self.linear(x)

    model = SimpleModel().eval().cuda()

    # Static Input (single shape, not min/opt/max)
-    compile_inputs = [
-        torchtrt.Input(shape=(4, 10), dtype=torch.float32, name="x")
-    ]
+    compile_inputs = [torchtrt.Input(shape=(4, 10), dtype=torch.float32, name="x")]

    compile_spec = {
        "inputs": compile_inputs,
        "ir": ir,
        "min_block_size": 1,

@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from c751ce7 to e65566f on February 12, 2026 23:41
@@ -235,7 +248,12 @@ def interpret_module_to_result(
        )
    else:
        serialized_interpreter_result = pull_cached_engine(
Collaborator Author

@zewenli98 I am just threading the symbolic expressions through the cache to satisfy the type checker. Let me know if you would prefer some other way to handle this.

The way I see it, the graph should have the same symbolic relationships if the FX graph matches a previous TRT engine, so I think this is fine?
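A minimal sketch of what threading the expressions through the cache could look like (the dataclass and function names here are hypothetical, not the actual engine-cache schema):

from dataclasses import dataclass
from typing import Dict, List, Optional

import sympy

@dataclass
class CachedEngineEntry:
    serialized_engine: bytes
    # Per-output lists of sympy expressions describing each output dim in
    # terms of the input symbols, stored alongside the engine on insertion.
    output_shape_exprs: List[List[sympy.Expr]]

def lookup_engine(cache: Dict[str, CachedEngineEntry], key: str) -> Optional[CachedEngineEntry]:
    # A cache hit means the FX graph hashed identically, so the stored
    # symbolic relationships are assumed to still describe the outputs.
    return cache.get(key)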

@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from e65566f to 55d7d16 on February 12, 2026 23:49
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

…the metadata to use in the case of reexport. Also removes the need to access the real tensorrt engine during reexport
@narendasan narendasan force-pushed the narendasan/push-knqwnzwpomoz branch from 55d7d16 to 28f54f9 on February 12, 2026 23:51
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

Comment on lines +40 to +51
if not use_legacy_exporter:
    # NB: PROBABLY THE MOST CONTROVERSIAL CHANGE, ARE WE AT THE POINT WHERE WE CAN JUST USE TORCH.EXPORT.EXPORT?
    args = ()
    if arg_inputs is not None:
        args = arg_inputs if isinstance(arg_inputs, tuple) else tuple(arg_inputs)

    return torch.export.export(
        gm,
        args=args,
        kwargs=kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
    )
Collaborator Author

@cehongwang @zewenli98 @lanluo-nvidia need review on this. I can't really remember the reasons we have our own exporter, other than that the upstream one didn't work at the time with custom objects. Do we still need our own exporter?

I think we still want inlining, however.

