From 28f54f954839a042d38e1721afe8dd9524739e4b Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 11 Feb 2026 19:18:10 +0000 Subject: [PATCH] fix: we now store the traced symbolic functions from compile time in the metadata to use in the case of reexport. Also removes the need to access the real tensorrt engine during reexport --- core/runtime/TRTEngine.cpp | 4 + core/runtime/TRTEngine.h | 1 + core/runtime/register_jit_hooks.cpp | 1 + docsrc/user_guide/dynamic_shapes.rst | 164 ++++- docsrc/user_guide/saving_models.rst | 92 +++ .../save_dynamic_shapes_both_methods.py | 172 +++++ .../dynamo/save_dynamic_shapes_example.py | 183 +++++ py/torch_tensorrt/_compile.py | 269 ++++++-- py/torch_tensorrt/dynamo/_compiler.py | 6 +- py/torch_tensorrt/dynamo/_exporter.py | 57 +- .../dynamo/conversion/_conversion.py | 32 +- .../conversion/_symbolic_shape_capture.py | 121 ++++ .../runtime/_PythonTorchTensorRTModule.py | 7 +- .../dynamo/runtime/_TorchTensorRTModule.py | 14 +- .../runtime/meta_ops/register_meta_ops.py | 252 +++++-- py/torch_tensorrt/dynamo/utils.py | 15 +- pyproject.toml | 1 + .../test_meta_kernel_shape_inference.py | 298 +++++++++ tests/py/dynamo/models/test_reexport.py | 628 +++++++++++++++++- 19 files changed, 2173 insertions(+), 144 deletions(-) create mode 100644 examples/dynamo/save_dynamic_shapes_both_methods.py create mode 100644 examples/dynamo/save_dynamic_shapes_example.py create mode 100644 py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py create mode 100644 tests/py/dynamo/models/test_meta_kernel_shape_inference.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d122e00c9e..37148812f5 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -325,6 +325,10 @@ std::string TRTEngine::get_engine_layer_info() { return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON); } +std::string TRTEngine::get_serialized_metadata() { + return this->serialized_metadata; +} + std::vector TRTEngine::infer_outputs(std::vector> input_shapes) { std::vector outputs; TORCHTRT_CHECK( diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index bf95740bae..363631863f 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -158,6 +158,7 @@ struct TRTEngine : torch::CustomClassHolder { void set_profile_format(std::string profile_format); void disable_profiling(); std::string get_engine_layer_info(); + std::string get_serialized_metadata(); void dump_engine_layer_info_to_file(const std::string& path); void dump_engine_layer_info(); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 9baa0df32c..e8f6217a21 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file) .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info) .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) + .def("get_serialized_metadata", &TRTEngine::get_serialized_metadata) .def("infer_outputs", &TRTEngine::infer_outputs) .def("reset_captured_graph", &TRTEngine::reset_captured_graph) .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned) diff --git a/docsrc/user_guide/dynamic_shapes.rst b/docsrc/user_guide/dynamic_shapes.rst index 8bea6f0fb6..23e2e9e7ca 100644 --- a/docsrc/user_guide/dynamic_shapes.rst +++ b/docsrc/user_guide/dynamic_shapes.rst @@ -49,7 +49,7 @@ Custom Dynamic Shape Constraints --------------------------------- Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``, -Torch-TensorRT attempts to automatically set the constraints during ``torch.export`` tracing by constructing +Torch-TensorRT attempts to automatically set the constraints during ``torch.export`` tracing by constructing `torch.export.Dim` objects with the provided dynamic dimensions accordingly. Sometimes, we might need to set additional constraints and Torchdynamo errors out if we don't specify them. If you have to set any custom constraints to your model (by using `torch.export.Dim`), we recommend exporting your program first before compiling with Torch-TensorRT. Please refer to this `documentation `_ to export the Pytorch module with dynamic shapes. @@ -78,7 +78,6 @@ Here's a simple example that exports a matmul layer with some restrictions on dy # Run inference trt_gm(*inputs) - Dynamic shapes using torch.compile (JIT) ------------------------------------ @@ -102,3 +101,164 @@ to avoid recompilation of TensorRT engines. # No recompilation of TRT engines with modified batch size inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32) trt_gm(inputs_bs2) + + +Saving and Loading Models with Dynamic Shapes +---------------------------------------------- + +When you compile a model with dynamic shapes and want to save it for later use, you need to preserve the dynamic shape +specifications. Torch-TensorRT provides two methods to accomplish this: + +Method 1: Automatic Inference from torch_tensorrt.Input +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The simplest approach is to pass the same ``torch_tensorrt.Input`` objects (with min/opt/max shapes) to both ``compile()`` and ``save()``. +The dynamic shape specifications will be inferred automatically: + +.. code-block:: python + + import torch + import torch_tensorrt + + model = MyModel().eval().cuda() + + # Define Input with dynamic shapes once + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(32, 3, 224, 224), + dtype=torch.float32, + name="x" # Optional: provides better dimension naming + ) + ] + + # Compile with dynamic shapes + trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) + + # Save - dynamic shapes inferred automatically! + torch_tensorrt.save(trt_model, "model.ep", arg_inputs=inputs) + + # Load and use with different batch sizes + loaded_model = torch_tensorrt.load("model.ep").module() + output1 = loaded_model(torch.randn(4, 3, 224, 224).cuda()) # Works! + output2 = loaded_model(torch.randn(16, 3, 224, 224).cuda()) # Works! + + +Method 2: Explicit torch.export.Dim Specification +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For advanced use cases or when you need fine-grained control over dimension naming, you can explicitly provide ``dynamic_shapes`` +using ``torch.export.Dim``: + +.. code-block:: python + + import torch + import torch_tensorrt + + model = MyModel().eval().cuda() + example_input = torch.randn((2, 3, 224, 224)).cuda() + + # Define dynamic dimensions explicitly + dyn_batch = torch.export.Dim("batch", min=1, max=32) + dynamic_shapes = {"x": {0: dyn_batch}} + + # Export with dynamic shapes + exp_program = torch.export.export( + model, (example_input,), + dynamic_shapes=dynamic_shapes, + strict=False + ) + + # Compile + trt_model = torch_tensorrt.dynamo.compile( + exp_program, + inputs=[torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(32, 3, 224, 224), + )] + ) + + # Save with explicit dynamic_shapes + torch_tensorrt.save( + trt_model, + "model.ep", + arg_inputs=[example_input], + dynamic_shapes=dynamic_shapes # Same as used during export + ) + + # Load and use + loaded_model = torch_tensorrt.load("model.ep").module() + +**When to use this method:** + - You need specific dimension names for torch.export compatibility + - You're working with existing torch.export workflows + - You require fine-grained control over dynamic dimension specifications + +Multiple Dynamic Dimensions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Both methods support multiple dynamic dimensions (e.g., dynamic batch, height, and width): + +.. code-block:: python + + # Method 1 (Automatic): Multiple dynamic dimensions + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 3, 64, 64), + opt_shape=(8, 3, 256, 256), + max_shape=(16, 3, 512, 512), + name="image" + ) + ] + + trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) + torch_tensorrt.save(trt_model, "model.ep", arg_inputs=inputs) # All 3 dims inferred! + + # Load and test with various sizes + loaded = torch_tensorrt.load("model.ep").module() + loaded(torch.randn(4, 3, 128, 128).cuda()) + loaded(torch.randn(12, 3, 384, 384).cuda()) + +Saving with Keyword Arguments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If your model uses keyword arguments with dynamic shapes, both methods support them: + +.. code-block:: python + + # Define dynamic inputs for both args and kwargs + arg_inputs = [ + torch_tensorrt.Input( + min_shape=(1, 10), + opt_shape=(4, 10), + max_shape=(8, 10), + name="x" + ) + ] + + kwarg_inputs = { + "mask": torch_tensorrt.Input( + min_shape=(1, 5), + opt_shape=(4, 5), + max_shape=(8, 5), + name="mask" + ) + } + + # Compile + trt_model = torch_tensorrt.compile( + model, + ir="dynamo", + arg_inputs=arg_inputs, + kwarg_inputs=kwarg_inputs + ) + + # Save - both arg and kwarg dynamic shapes inferred automatically + torch_tensorrt.save( + trt_model, + "model.ep", + arg_inputs=arg_inputs, + kwarg_inputs=kwarg_inputs + ) diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst index bef9b4dec3..230b240560 100644 --- a/docsrc/user_guide/saving_models.rst +++ b/docsrc/user_guide/saving_models.rst @@ -42,6 +42,98 @@ Here's an example usage model = torch.export.load("trt.ep").module() model(*inputs) + +Saving Models with Dynamic Shapes +"""""""""""""""""""""""""""""""""" + +When saving models compiled with dynamic shapes, you have two methods to preserve +the dynamic shape specifications: + +**Method 1: Using torch.export.Dim (explicit)** + +Provide explicit ``dynamic_shapes`` parameter following torch.export's pattern: + +.. code-block:: python + + import torch + import torch_tensorrt + + model = MyModel().eval().cuda() + example_input = torch.randn((2, 3, 224, 224)).cuda() + + # Define dynamic batch dimension + dyn_batch = torch.export.Dim("batch", min=1, max=32) + dynamic_shapes = {"x": {0: dyn_batch}} + + # Export with dynamic shapes + exp_program = torch.export.export( + model, (example_input,), + dynamic_shapes=dynamic_shapes, + strict=False + ) + + # Compile with dynamic input specifications + trt_gm = torch_tensorrt.dynamo.compile( + exp_program, + inputs=[torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(32, 3, 224, 224), + )] + ) + + # Save with dynamic_shapes to preserve dynamic behavior + torch_tensorrt.save( + trt_gm, + "trt_dynamic.ep", + arg_inputs=[example_input], + dynamic_shapes=dynamic_shapes # Same as used during export + ) + + # Load and use with different batch sizes + loaded_model = torch_tensorrt.load("trt_dynamic.ep").module() + output_bs4 = loaded_model(torch.randn(4, 3, 224, 224).cuda()) + output_bs16 = loaded_model(torch.randn(16, 3, 224, 224).cuda()) + +**Method 2: Using torch_tensorrt.Input** + +Pass ``torch_tensorrt.Input`` objects with min/opt/max shapes directly, and the +dynamic shapes will be inferred automatically: + +.. code-block:: python + + import torch + import torch_tensorrt + + model = MyModel().eval().cuda() + + # Define Input with dynamic shapes + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(32, 3, 224, 224), + dtype=torch.float32, + name="x" # Optional: provides better dimension naming + ) + ] + + # Compile with Torch-TensorRT + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) + + # Save with Input objects - dynamic_shapes inferred automatically! + torch_tensorrt.save( + trt_gm, + "trt_dynamic.ep", + arg_inputs=inputs # Dynamic shapes inferred from Input objects + ) + + # Load and use with different batch sizes + loaded_model = torch_tensorrt.load("trt_dynamic.ep").module() + output_bs4 = loaded_model(torch.randn(4, 3, 224, 224).cuda()) + output_bs16 = loaded_model(torch.randn(16, 3, 224, 224).cuda()) + + b) Torchscript ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/dynamo/save_dynamic_shapes_both_methods.py b/examples/dynamo/save_dynamic_shapes_both_methods.py new file mode 100644 index 0000000000..7f39b83c48 --- /dev/null +++ b/examples/dynamo/save_dynamic_shapes_both_methods.py @@ -0,0 +1,172 @@ +""" +.. _save_dynamic_shapes_both_methods: + +Saving Models with Dynamic Shapes - Both Methods +================================================= + +This example demonstrates BOTH methods for saving Torch-TensorRT compiled models +with dynamic input shapes: + +1. **Method 1**: Using torch.export.Dim (explicit dynamic_shapes parameter) +2. **Method 2**: Using torch_tensorrt.Input with min/opt/max (automatic inference) + +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import tempfile + +import torch +import torch.nn as nn +import torch_tensorrt + + +# %% +# Define a simple model +class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + +model = SimpleModel().eval().cuda() + +# %% +# Method 1: Explicit dynamic_shapes with torch.export.Dim +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# This follows torch.export's API pattern + +example_input = torch.randn(4, 10).cuda() + +# Define dynamic dimension explicitly +dyn_batch = torch.export.Dim("batch", min=1, max=32) +dynamic_shapes = {"x": {0: dyn_batch}} + +# Export with dynamic shapes +exp_program = torch.export.export( + model, (example_input,), dynamic_shapes=dynamic_shapes, strict=False +) + +# Compile with TensorRT +trt_module_method1 = torch_tensorrt.dynamo.compile( + exp_program, + inputs=[ + torch_tensorrt.Input( + min_shape=(1, 10), + opt_shape=(8, 10), + max_shape=(32, 10), + dtype=torch.float32, + ) + ], + enabled_precisions={torch.float32}, + min_block_size=1, +) + +with tempfile.TemporaryDirectory() as tmpdir: + save_path = f"{tmpdir}/model_method1.ep" + + # Save with explicit dynamic_shapes parameter + torch_tensorrt.save( + trt_module_method1, + save_path, + output_format="exported_program", + arg_inputs=[example_input], + dynamic_shapes=dynamic_shapes, # Explicit! + retrace=True, + ) + + # Load and test + loaded_model = torch_tensorrt.load(save_path).module() + output_bs4 = loaded_model(torch.randn(4, 10).cuda()) + output_bs16 = loaded_model(torch.randn(16, 10).cuda()) + + +# %% +# Method 2: Automatic inference from torch_tensorrt.Input +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Redefine model for fresh compile +model2 = SimpleModel().eval().cuda() + +inputs = [ + torch_tensorrt.Input( + min_shape=(1, 10), + opt_shape=(8, 10), + max_shape=(32, 10), + dtype=torch.float32, + name="x", + ) +] + +# Compile directly with torch_tensorrt.compile +trt_module_method2 = torch_tensorrt.compile(model2, ir="dynamo", inputs=inputs) + + +with tempfile.TemporaryDirectory() as tmpdir: + save_path = f"{tmpdir}/model_method2.ep" + + # Save with Input objects - dynamic_shapes inferred automatically! + # No need to specify dynamic_shapes explicitly + torch_tensorrt.save( + trt_module_method2, + save_path, + output_format="exported_program", + arg_inputs=inputs, # Pass the same Input objects used for compile + retrace=True, + ) + + # Load and test + loaded_model = torch_tensorrt.load(save_path).module() + output_bs4 = loaded_model(torch.randn(4, 10).cuda()) + output_bs16 = loaded_model(torch.randn(16, 10).cuda()) + + +# %% +# Multiple Dynamic Dimensions Example +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +class ConvModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x): + return self.conv(x) + + +model3 = ConvModel().eval().cuda() + +# Multiple dynamic dimensions: batch, height, width +inputs_multi = [ + torch_tensorrt.Input( + min_shape=(1, 3, 64, 64), + opt_shape=(8, 3, 256, 256), + max_shape=(16, 3, 512, 512), + dtype=torch.float32, + name="image", + ) +] + +trt_module_multi = torch_tensorrt.compile(model3, ir="dynamo", inputs=inputs_multi) + +with tempfile.TemporaryDirectory() as tmpdir: + save_path = f"{tmpdir}/model_multi_dim.ep" + + torch_tensorrt.save( + trt_module_multi, + save_path, + arg_inputs=inputs_multi, # Automatically infers all 3 dynamic dims! + retrace=True, + ) + + loaded_model = torch_tensorrt.load(save_path).module() + + # Test with different shapes + out1 = loaded_model(torch.randn(4, 3, 128, 128).cuda()) + out2 = loaded_model(torch.randn(12, 3, 384, 384).cuda()) diff --git a/examples/dynamo/save_dynamic_shapes_example.py b/examples/dynamo/save_dynamic_shapes_example.py new file mode 100644 index 0000000000..b1585aeaa3 --- /dev/null +++ b/examples/dynamo/save_dynamic_shapes_example.py @@ -0,0 +1,183 @@ +""" +.. _save_dynamic_shapes: + +Saving and Loading Models with Dynamic Shapes +============================================== + +This example demonstrates how to save and load Torch-TensorRT compiled models +with dynamic input shapes. When you compile a model with dynamic shapes, +you need to preserve the dynamic shape specifications when saving the model +to ensure it can handle variable input sizes after deserialization. + +The API is designed to feel similar to torch.export's handling of dynamic shapes +for consistency and ease of use. +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import tempfile + +import torch +import torch.nn as nn +import torch_tensorrt + + +# %% +# Define a simple model that we'll compile with dynamic batch size +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1) + self.relu = nn.ReLU() + self.linear = nn.Linear(16 * 224 * 224, 10) + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + x = x.flatten(1) + x = self.linear(x) + return x + + +# %% +# Compile with Dynamic Shapes +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# First, we compile the model with dynamic batch dimension + +model = MyModel().eval().cuda() + +# Define example input with batch size 2 +example_input = torch.randn(2, 3, 224, 224).cuda() + +# Define dynamic batch dimension using torch.export.Dim +# This allows batch sizes from 1 to 32 +dyn_batch = torch.export.Dim("batch", min=1, max=32) + +# Specify which dimensions are dynamic +dynamic_shapes = {"x": {0: dyn_batch}} + +# Export the model with dynamic shapes +exp_program = torch.export.export( + model, (example_input,), dynamic_shapes=dynamic_shapes, strict=False +) + +# Compile with Torch-TensorRT +compile_spec = { + "inputs": [ + torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(32, 3, 224, 224), + dtype=torch.float32, + ) + ], + "enabled_precisions": {torch.float32}, + "min_block_size": 1, +} + +trt_gm = torch_tensorrt.dynamo.compile(exp_program, **compile_spec) + +# %% +# Test Compiled Model with Different Batch Sizes +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Test with batch size 4 +input_bs4 = torch.randn(4, 3, 224, 224).cuda() +output_bs4 = trt_gm(input_bs4) + +# Test with batch size 16 +input_bs16 = torch.randn(16, 3, 224, 224).cuda() +output_bs16 = trt_gm(input_bs16) + +# %% +# Save the Model with Dynamic Shapes +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# The key is to pass the same dynamic_shapes specification to save() + +with tempfile.TemporaryDirectory() as tmpdir: + save_path = f"{tmpdir}/dynamic_model.ep" + + # Save with dynamic_shapes parameter - this is crucial for preserving dynamic behavior + torch_tensorrt.save( + trt_gm, + save_path, + output_format="exported_program", + arg_inputs=[example_input], + dynamic_shapes=dynamic_shapes, # Same as used during export + ) + + # %% + # Load and Test the Saved Model + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + # Load the saved model + loaded_model = torch_tensorrt.load(save_path).module() + + # Test with the same batch sizes to verify dynamic shapes are preserved + output_loaded_bs4 = loaded_model(input_bs4) + + output_loaded_bs16 = loaded_model(input_bs16) + + assert torch.allclose(output_bs4, output_loaded_bs4, rtol=1e-3, atol=1e-3) + assert torch.allclose(output_bs16, output_loaded_bs16, rtol=1e-3, atol=1e-3) + +# %% +# Example with Multiple Dynamic Dimensions +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +class MultiDimModel(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1) + + def forward(self, x): + return self.conv(x) + + +model2 = MultiDimModel().eval().cuda() +example_input2 = torch.randn(2, 3, 128, 128).cuda() + +# Define dynamic dimensions for batch and spatial dimensions +dyn_batch2 = torch.export.Dim("batch", min=1, max=16) +dyn_height = torch.export.Dim("height", min=64, max=512) +dyn_width = torch.export.Dim("width", min=64, max=512) + +dynamic_shapes2 = {"x": {0: dyn_batch2, 2: dyn_height, 3: dyn_width}} + +exp_program2 = torch.export.export( + model2, (example_input2,), dynamic_shapes=dynamic_shapes2, strict=False +) + +compile_spec2 = { + "inputs": [ + torch_tensorrt.Input( + min_shape=(1, 3, 64, 64), + opt_shape=(8, 3, 256, 256), + max_shape=(16, 3, 512, 512), + dtype=torch.float32, + ) + ], + "enabled_precisions": {torch.float32}, +} + +trt_gm2 = torch_tensorrt.dynamo.compile(exp_program2, **compile_spec2) + +with tempfile.TemporaryDirectory() as tmpdir: + save_path2 = f"{tmpdir}/multi_dim_model.ep" + + torch_tensorrt.save( + trt_gm2, + save_path2, + output_format="exported_program", + arg_inputs=[example_input2], + dynamic_shapes=dynamic_shapes2, + ) + + loaded_model2 = torch_tensorrt.load(save_path2).module() + + # Test with different input shapes + test_input = torch.randn(4, 3, 256, 256).cuda() + output = loaded_model2(test_input) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 623bad7b9f..f739c6074d 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -5,7 +5,7 @@ import platform import warnings from enum import Enum -from typing import Any, Callable, List, Optional, Sequence, Set, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union import torch from torch_tensorrt._enums import dtype @@ -23,9 +23,9 @@ from torch_tensorrt.fx.lower import compile as fx_compile from torch_tensorrt.fx.utils import LowerPrecision - InputType = Union[Input, torch.Tensor, InputTensorSpec] -else: InputType = Union[Input, torch.Tensor] +else: + InputType = Union[Input, torch.Tensor] # type: ignore if ENABLED_FEATURES.torchscript_frontend: import torch_tensorrt.ts @@ -49,7 +49,13 @@ from torch_tensorrt.dynamo._compiler import ( save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program, ) + from torch_tensorrt.dynamo._defaults import default_device + from torch_tensorrt.dynamo._tracer import ( + get_dynamic_shapes_args, + get_dynamic_shapes_kwargs, + ) from torch_tensorrt.dynamo._tracer import trace as dynamo_trace + from torch_tensorrt.dynamo.utils import get_torch_inputs logger = logging.getLogger(__name__) @@ -175,7 +181,7 @@ def compile( ir: str = "default", inputs: Optional[Sequence[InputType]] = None, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, - kwarg_inputs: Optional[dict[Any, Any]] = None, + kwarg_inputs: Optional[Dict[str, Any]] = None, enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, **kwargs: Any, ) -> ( @@ -301,7 +307,6 @@ def _fx_input_interface( if not isinstance(arg_inputs, collections.abc.Sequence): arg_inputs = [arg_inputs] # type: ignore - # Export the module torchtrt_arg_inputs = prepare_inputs(arg_inputs) torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs) @@ -323,7 +328,7 @@ def _fx_input_interface( raise RuntimeError("Module is an unknown format or the ir requested is unknown") -@needs_cross_compile +@needs_cross_compile # type: ignore[misc] def cross_compile_for_windows( module: torch.nn.Module, file_path: str, @@ -573,24 +578,25 @@ def load(file_path: str = "") -> Any: Raises: ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file """ + try: - logger.debug(f"Loading the provided file {file_path} using torch.jit.load()") - ts_module = torch.jit.load(file_path) - return ts_module + logger.debug(f"Loading the provided file {file_path} using torch.export.load()") + exp_program = torch.export.load(file_path) + return exp_program except Exception: logger.info( - f"Loading the provided file {file_path} via torch.jit.load() failed with the following error", + f"Loading the provided file {file_path} via torch.export.load() failed with the following error", exc_info=True, ) pass try: - logger.debug(f"Loading the provided file {file_path} using torch.export.load()") - exp_program = torch.export.load(file_path) - return exp_program + logger.debug(f"Loading the provided file {file_path} using torch.jit.load()") + ts_module = torch.jit.load(file_path) + return ts_module except Exception: logger.info( - f"Loading the provided file {file_path} via torch.export.load() failed with the following error", + f"Loading the provided file {file_path} via torch.jit.load() (after failing to load with torch.export.load()) failed with the following error", exc_info=True, ) raise ValueError( @@ -603,11 +609,12 @@ def save( file_path: str = "", *, output_format: str = "exported_program", - inputs: Optional[Sequence[torch.Tensor]] = None, - arg_inputs: Optional[Sequence[torch.Tensor]] = None, - kwarg_inputs: Optional[dict[str, Any]] = None, + inputs: Optional[Sequence[torch.Tensor | Input]] = None, + arg_inputs: Optional[Sequence[torch.Tensor | Input]] = None, + kwarg_inputs: Optional[Dict[str, Any]] = None, retrace: bool = True, pickle_protocol: int = 2, + dynamic_shapes: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """ @@ -615,23 +622,77 @@ def save( Arguments: module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module - inputs (torch.Tensor): Torch input tensors - arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. - kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. + inputs (Union[torch.Tensor, torch_tensorrt.Input]): Torch input tensors or Input specifications + arg_inputs (Tuple[Union[torch.Tensor, torch_tensorrt.Input], ...]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (dict[str, Union[torch.Tensor, torch_tensorrt.Input]]): Optional, kwarg inputs to the module forward function. output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor. retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. - This flag is experimental for now. + + For TRT-compiled modules with dynamic shapes, both retrace=True and retrace=False are supported: + + - **retrace=True**: Automatically detects symbolic shape metadata in the TRT module and preserves it + without retracing. This is the recommended approach as it maintains the exact symbolic shapes + from the original compilation. + + - **retrace=False**: Directly serializes the existing graph metadata without any re-export. + This is faster but may not be compatible with all torch.export consumers. + + For static shape models, retrace=True performs a standard torch.export.export() call. + pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models + dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any, ...]]]): Dynamic shape specifications for re-exporting the model. + + **Method 1: Explicit dynamic_shapes (torch.export style)** + + Provide explicit torch.export.Dim specifications:: + + # For a single input with dynamic batch dimension + dyn_batch = torch.export.Dim("batch", min=1, max=32) + dynamic_shapes = {"x": {0: dyn_batch}} + torch_tensorrt.save(model, "model.ep", arg_inputs=[example_tensor], dynamic_shapes=dynamic_shapes) + + # For multiple inputs + dynamic_shapes = ({"x": {0: dyn_batch}}, {"y": {0: dyn_batch}}) + + **Method 2: Inferred from torch_tensorrt.Input** + + Pass torch_tensorrt.Input objects with min/opt/max shapes in arg_inputs/kwarg_inputs, + and dynamic_shapes will be inferred automatically:: + + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(8, 3, 224, 224), + max_shape=(32, 3, 224, 224), + name="x" # Optional: name for better dim naming + ) + ] + torch_tensorrt.save(model, "model.ep", arg_inputs=inputs) # dynamic_shapes inferred! + + **Important Limitations:** + + - Automatic inference creates **separate Dim objects for each input**. If your model requires + multiple inputs to share the same dimension (e.g., matching batch sizes), you MUST use + Method 1 with explicit shared Dim objects:: + + batch = torch.export.Dim("batch", min=1, max=8) + dynamic_shapes = {"x": {0: batch}, "mask": {0: batch}} # Shared batch dimension + + - Automatic inference is **disabled for mixed Input/Tensor inputs** to avoid spurious + equality constraints. Use explicit dynamic_shapes for these cases. + + - If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes + parameter takes precedence. """ if isinstance(module, CudaGraphsTorchTensorRTModule): module = module.compiled_module module_type = _parse_module_type(module) accepted_formats = {"exported_program", "torchscript", "aot_inductor"} if arg_inputs is not None and not all( - isinstance(input, torch.Tensor) for input in arg_inputs + isinstance(input, (torch.Tensor, Input)) for input in arg_inputs ): raise ValueError( - "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs" + "Not all inputs provided are torch.Tensor or torch_tensorrt.Input objects. Please provide inputs of a valid type" ) if arg_inputs and inputs: raise AssertionError( @@ -645,6 +706,104 @@ def save( if kwarg_inputs and any(value is None for value in kwarg_inputs.values()): raise ValueError("kwargs should not include None.") + + def _all_are_input_objects(obj: Any) -> bool: + """Recursively check if all elements in nested collections are Input objects.""" + if isinstance(obj, Input): + return True + elif isinstance(obj, (list, tuple)): + return all(_all_are_input_objects(item) for item in obj) + elif isinstance(obj, dict): + return all(_all_are_input_objects(value) for value in obj.values()) + else: + # Not an Input object or collection + return False + + all_inputs_are_input_objects = _all_are_input_objects(arg_inputs) + if kwarg_inputs: + all_inputs_are_input_objects = ( + all_inputs_are_input_objects and _all_are_input_objects(kwarg_inputs) + ) + + # Infer dynamic_shapes from Input objects if not explicitly provided + # Only infer if ALL inputs are Input objects (not mixed with Tensors) + # + # Why? When we have mixed Input/Tensor inputs, torch.export may detect that + # a dynamic Input's dimension always equals a static Tensor's dimension during + # tracing, and enforce an equality constraint. Since we create separate Dim + # objects for each input, this causes a constraint violation. Users must use + # explicit dynamic_shapes for these cases. + + # Warn if user provides both dynamic_shapes and Input objects with dynamic shapes + + arg_tensors: Tuple[torch.Tensor | int, ...] = () + kwarg_tensors: Dict[str, Any] = {} + + if all_inputs_are_input_objects: + if dynamic_shapes is not None: + has_dynamic_input_objects = any( + isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC + for inp in arg_inputs # type: ignore[union-attr] + ) + if kwarg_inputs: + has_dynamic_input_objects = has_dynamic_input_objects or any( + isinstance(inp, Input) + and inp.shape_mode == Input._ShapeMode.DYNAMIC + for inp in kwarg_inputs.values() + ) + if has_dynamic_input_objects: + logger.warning( + "Both explicit dynamic_shapes and torch_tensorrt.Input objects with min/opt/max shapes were provided. " + "The explicit dynamic_shapes parameter takes precedence and Input shape specifications will be ignored." + ) + else: + inferred_dynamic_shapes = get_dynamic_shapes_args(module, arg_inputs) + inferred_dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + + if inferred_dynamic_shapes is not None: + dynamic_shapes = inferred_dynamic_shapes + logger.info( + f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}" + ) + + arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore + kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore + + else: + # Mixed case: some inputs are Tensors, some are Input objects + # Extract tensors from Input objects and use provided tensors as-is + def _extract_tensor(obj: Any) -> Any: + """Recursively extract tensors from Input objects or pass through tensors.""" + if isinstance(obj, Input): + if ( + obj.shape_mode == Input._ShapeMode.DYNAMIC + and dynamic_shapes is None + ): + logger.warning( + "Mixed torch.Tensor and torch_tensorrt.Input objects provided in the example arguments without explicit dynamic_shapes. " + "We cannot infer the dynamic shape specs from these mixed cases " + "Consider providing explicit dynamic_shapes parameter or using Input objects for all inputs." + ) + return obj.example_tensor() + elif isinstance(obj, torch.Tensor): + return obj + elif isinstance(obj, (list, tuple)): + extracted = [_extract_tensor(item) for item in obj] + return type(obj)(extracted) + elif isinstance(obj, dict): + return {key: _extract_tensor(value) for key, value in obj.items()} + else: + raise TypeError( + f"Unsupported input type: {type(obj)}. Expected torch.Tensor or torch_tensorrt.Input" + ) + + arg_tensors = _extract_tensor(arg_inputs) + kwarg_tensors = _extract_tensor(kwarg_inputs) if kwarg_inputs else {} + + # Extract tensors from Input objects for actual execution + # When inferring dynamic shapes, use different sizes for args vs kwargs to avoid + # torch.export detecting spurious equality constraints + if output_format not in accepted_formats: raise ValueError( f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript" @@ -666,10 +825,6 @@ def save( "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported" ) else: - if arg_inputs is not None: - logger.warning( - "Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save." - ) torch.jit.save(module, file_path) elif module_type == _ModuleType.ep: if output_format == "torchscript": @@ -712,7 +867,13 @@ def save( logger.warning( "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save." ) - exp_program = export(module) + + exp_program = export( + module, + arg_inputs=arg_tensors, + kwarg_inputs=kwarg_tensors, + dynamic_shapes=dynamic_shapes, + ) if output_format == "exported_program": torch.export.save( exp_program, file_path, pickle_protocol=pickle_protocol @@ -732,16 +893,50 @@ def save( "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor" ) else: - if arg_inputs is None: - raise ValueError( - "Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model" - ) - exp_program = torch.export.export( - module, - tuple(arg_inputs), - kwargs=kwarg_inputs, - strict=False, + # When retrace=True with a TRT-compiled GraphModule that has dynamic shapes, + # use the exporter to preserve symbolic metadata instead of retracing + has_symbolic_metadata = any( + isinstance(dim, torch.SymInt) + for node in module.graph.nodes + if node.op == "placeholder" and "val" in node.meta + for dim in getattr(node.meta["val"], "shape", []) ) + if has_symbolic_metadata and dynamic_shapes is not None: + # TRT module with dynamic shapes - use the exporter to preserve symbolic metadata + from torch_tensorrt.dynamo._exporter import export + + if arg_inputs is not None: + logger.info( + "Provided model is a torch.fx.GraphModule with dynamic shapes and retrace is True. " + "Using existing symbolic metadata instead of retracing. Input specs are not necessary." + ) + exp_program = export( + module, + arg_inputs=arg_tensors, + kwarg_inputs=kwarg_tensors, + dynamic_shapes=dynamic_shapes, + ) + else: + # Regular GraphModule or no dynamic shapes - retrace normally + if has_symbolic_metadata: + logger.warning( + "The provided module has symbolic metadata and retrace is True, however there is no dynamic shapes information available either explicitly or derived from arg/kwarg inputs (torch_tensorrt.Input) " + "This may lead to incorrect tracing and overly restrictive shape guards when the exported program is loaded. Please specify the dynamic shapes either explicitly or derived from arg/kwarg inputs" + ) + + if arg_inputs is None: + raise ValueError( + "Provided model is a torch.fx.GraphModule without existing shape metadata and retrace is True, however no inputs specs were provided. " + "Please provide valid torch.Tensors or torch_tensorrt.Input objects as inputs to retrace and save the model" + ) + + exp_program = torch.export.export( + module, + args=arg_tensors, + kwargs=kwarg_tensors, + dynamic_shapes=dynamic_shapes, + strict=False, + ) if output_format == "exported_program": torch.export.save( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 24c009c189..d1b0a63e2c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -564,7 +564,7 @@ def compile( if not kwargs.get("use_explicit_typing", False): warnings.warn( - "`use_explicit_typing` is deprecated. This setting will be removed and you should enable autocast instead.", + "`use_explicit_typing` is deprecated. use_explicit_types is now on by default, this setting will be removed and you should enable autocast to recover weak typing behavior.", DeprecationWarning, stacklevel=2, ) @@ -1042,7 +1042,6 @@ def preserve_module_specs( trt_modules[name] = trt_module if _debugger_config: - if _debugger_config.save_engine_profile: if settings.use_python_runtime: if _debugger_config.profile_format != "cudagraph": @@ -1430,6 +1429,9 @@ def convert_exported_program_to_serialized_trt_engine( "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True" ) + if trt_kwarg_inputs is None: + trt_kwarg_inputs = {} + flattened_input_list = get_flat_args_with_check( exported_program, list(trt_arg_inputs), trt_kwarg_inputs )[0] diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index 17e0ad4561..1a444da458 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -22,7 +22,12 @@ def export( gm: torch.fx.GraphModule, + *, + arg_inputs: Optional[Sequence[torch.Tensor]] = None, + kwarg_inputs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Dict[str, Any]] = None, cross_compile_module: Optional[bool] = False, + use_legacy_exporter: Optional[bool] = False, ) -> ExportedProgram: """Export the result of TensorRT compilation into the desired output format. @@ -32,8 +37,26 @@ def export( cross_compile_module (bool): Flag to indicated whether it is cross_compilation enabled or not """ patched_module = transform(gm, cross_compile_module) - exp_program = create_trt_exp_program(patched_module) - return exp_program + if not use_legacy_exporter: + # NB: PROBABLY THE MOST CONTROVERSIAL CHANGE, ARE WE AT THE POINT WHERE WE CAN JUST USE TORCH.EXPORT.EXPORT? + args = () + if arg_inputs is not None: + args = arg_inputs if isinstance(arg_inputs, tuple) else tuple(arg_inputs) + + return torch.export.export( + gm, + args=args, + kwargs=kwarg_inputs, + dynamic_shapes=dynamic_shapes, + ) + else: + exp_program = create_trt_exp_program( + patched_module, + arg_inputs=arg_inputs, + kwarg_inputs=kwarg_inputs, + dynamic_shapes=dynamic_shapes, + ) + return exp_program def transform( @@ -144,7 +167,6 @@ def lift( const_placeholder_node = gm.graph.placeholder(const_placeholder_name) # Copy the node meta into this new placeholder node const_placeholder_node.meta = node.meta - if isinstance(lift_val, torch.Tensor): const_placeholder_node.meta["val"] = cast( FakeTensor, @@ -270,7 +292,26 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: gm.graph.erase_node(gm_added_placeholder_inputs[idx]) # Replace the pytorch submodule node (call_module) with the inlined subgraph output - gm_node.replace_all_uses_with(submodule_output) + # Special handling when submodule returns multiple outputs (tuple) + if isinstance(submodule_output, tuple): + # The fallback module has multiple outputs + # Find getitem nodes that extract from this module call and replace them directly + getitem_users = [ + user + for user in list(gm_node.users.keys()) + if user.op == "call_function" + and user.target is operator.getitem + ] + for user in getitem_users: + # getitem extracts element idx from the tuple + _, idx = user.args + # Replace this getitem with the actual node from the tuple + user.replace_all_uses_with(submodule_output[idx]) + # Erase the getitem node since it's no longer needed + gm.graph.erase_node(user) + else: + # Single output - normal replacement + gm_node.replace_all_uses_with(submodule_output) # copy the attributes of the submodule into gm (graph_copy doesn't do this) copy_submodule_attributes(gm, submodule, gm_node.name) @@ -303,6 +344,10 @@ def copy_submodule_attributes( def create_trt_exp_program( gm: torch.fx.GraphModule, + *, + arg_inputs: Optional[Sequence[torch.Tensor]] = None, + kwarg_inputs: Optional[Dict[str, torch.Tensor]] = None, + dynamic_shapes: Optional[Dict[str, Tuple[int, ...]]] = None, ) -> ExportedProgram: """Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines and constructs an Exported Program object with the new IO node names and state_dict @@ -347,7 +392,7 @@ def create_trt_exp_program( graph=gm.graph, graph_signature=trt_graph_signature, state_dict=state_dict, - range_constraints={}, + range_constraints={}, # I feel like we need to fill this in to get dynamic shapes to work properly with this exporter module_call_graph=module_call_graph, constants=constants, ) @@ -392,6 +437,8 @@ def inline_trt_modules( else: # for the normal workflow: use the execute_engine node engine_name = f"{name}_engine" + # TODO: THROWS SOME WARNING ABOUT A LACK OF UNDERLYING REFERENCE TO THE OWNING GRAPH MODULE + # SAYS THERES 3 OPTIONS, SUBMODULE, PARAMETER, OR BUFFER, BUFFER SEEMS THE BEST BUT I THINK ITS KEYED TO TENSORS setattr(gm, engine_name, trt_module.engine) engine_node = gm.graph.get_attr(engine_name) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 58decff529..e47d3f404f 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -2,15 +2,17 @@ import io import logging -from typing import Any, List, NamedTuple, Optional, Sequence +from typing import Any, Dict, List, NamedTuple, Optional, Sequence -import tensorrt as trt import torch from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible +from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import ( + extract_symbolic_shape_expressions, +) from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( TRTInterpreter, TRTInterpreterResult, @@ -23,6 +25,8 @@ ) from torch_tensorrt.logging import TRT_LOGGER +import tensorrt as trt + logger = logging.getLogger(__name__) @@ -32,6 +36,7 @@ class SerializedInterpreterResult(NamedTuple): output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] requires_output_allocator: bool + symbolic_shape_expressions: Dict[str, List[Dict[str, Any]]] def infer_module_output_dtypes( @@ -44,7 +49,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] + return get_output_dtypes(outputs, truncate_double) def insert_engine_to_cache( @@ -101,6 +106,7 @@ def pull_cached_engine( engine_cache: BaseEngineCache, settings: CompilationSettings, inputs: Sequence[Input], + symbolic_shape_expressions: Dict[str, List[Dict[str, Any]]], ) -> Optional[SerializedInterpreterResult]: if hash_val is None: logger.warning( @@ -183,6 +189,7 @@ def pull_cached_engine( output_names=output_names, weight_name_map=weight_name_map, requires_output_allocator=requires_output_allocator, + symbolic_shape_expressions=symbolic_shape_expressions, ) return None @@ -203,6 +210,12 @@ def interpret_module_to_result( SerializedInterpreterResult """ + symbolic_shape_expressions = extract_symbolic_shape_expressions(module) + if symbolic_shape_expressions is None: + raise RuntimeError( + "Failed to extract symbolic shape expressions from source FX graph partition" + ) + # engine_cache could be None if: # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or # 2) both cache_built_engines and reuse_cached_engines are False @@ -235,7 +248,12 @@ def interpret_module_to_result( ) else: serialized_interpreter_result = pull_cached_engine( - hash_val, module, engine_cache, settings, inputs + hash_val, + module, + engine_cache, + settings, + inputs, + symbolic_shape_expressions, ) if serialized_interpreter_result is not None: # hit the cache return serialized_interpreter_result @@ -244,6 +262,10 @@ def interpret_module_to_result( module, truncate_double=settings.truncate_double ) + logger.debug( + f"Extracted symbolic shape expressions: {len(symbolic_shape_expressions) if symbolic_shape_expressions else 0} outputs" + ) + interpreter = TRTInterpreter( module, inputs, @@ -295,6 +317,7 @@ def interpret_module_to_result( output_names=interpreter_result.output_names, weight_name_map=interpreter_result.weight_name_map, requires_output_allocator=interpreter_result.requires_output_allocator, + symbolic_shape_expressions=symbolic_shape_expressions, ) return serialized_interpreter_result @@ -343,4 +366,5 @@ def convert_module( settings=settings, weight_name_map=serialized_interpreter_result.weight_name_map, requires_output_allocator=serialized_interpreter_result.requires_output_allocator, + symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, ) diff --git a/py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py b/py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py new file mode 100644 index 0000000000..c8bbc41b06 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py @@ -0,0 +1,121 @@ +""" +Capture symbolic shape expressions from FX graphs for TRT meta kernel. + +This module extracts the symbolic relationship between input and output shapes +at compile time, which can then be used by the meta kernel to correctly infer +output shapes without pattern matching. +""" + +import logging +from typing import Any, Dict, List, Optional + +import torch + +logger = logging.getLogger(__name__) + + +def extract_symbolic_shape_expressions( + module: torch.fx.GraphModule, +) -> Optional[Dict[str, List[Dict[str, Any]]]]: + """ + Extract symbolic shape expressions from an FX graph. + + This captures the symbolic expressions (as sympy expressions) for input and output shapes + that can be applied to input fake tensors at runtime. + + Args: + module: FX GraphModule with symbolic shapes in node metadata + + Returns: + Dict with 'inputs' and 'outputs' keys, each containing a list of dicts with shape_exprs and dtype, + or None if extraction fails + """ + # Find input nodes (placeholders) + input_nodes = [node for node in module.graph.nodes if node.op == "placeholder"] + + # Find output node + output_nodes = [node for node in module.graph.nodes if node.op == "output"] + if not output_nodes: + return None + + output_node = output_nodes[0] + + # Collect shape expressions and dtypes for each input + input_info = [] + for input_node in input_nodes: + if not hasattr(input_node, "meta") or "val" not in input_node.meta: + logger.warning( + "When processing symbolic shapes for TensorRT engine, found no metadata in input node" + ) + return None + + input_val = input_node.meta["val"] + if not isinstance(input_val, torch.Tensor): + logger.warning( + "When processing symbolic shapes for TensorRT engine, input is not a tensor" + ) + return None + + # Extract shape as sympy expressions (can be pickled) + shape_exprs = [] + for dim_size in input_val.shape: + if isinstance(dim_size, torch.SymInt): + # Store the sympy expression, which can be pickled + shape_exprs.append(dim_size.node.expr) + else: + # Store concrete integer + shape_exprs.append(int(dim_size)) + + input_info.append( + { + "shape_exprs": shape_exprs, + "dtype": input_val.dtype, + "name": input_node.name, + } + ) + + # Extract output values from output node + output_args = output_node.args[0] + if not isinstance(output_args, (tuple, list)): + output_args = (output_args,) + + # Collect shape expressions and dtypes for each output + output_info = [] + for out_arg in output_args: + if not hasattr(out_arg, "meta") or "val" not in out_arg.meta: + logger.warning( + "When processing symbolic shapes for TensorRT engine, found no metadata in FX Graph" + ) + return None + + out_val = out_arg.meta["val"] + if not isinstance(out_val, torch.Tensor): + logger.warning( + "When processing symbolic shapes for TensorRT engine, output is not a tensor" + ) + return None + + # Extract shape as sympy expressions (can be pickled) + shape_exprs = [] + for dim_size in out_val.shape: + if isinstance(dim_size, torch.SymInt): + # Store the sympy expression, which can be pickled + shape_exprs.append(dim_size.node.expr) + else: + # Store concrete integer + shape_exprs.append(int(dim_size)) + + output_info.append( + { + "shape_exprs": shape_exprs, + "dtype": out_val.dtype, + } + ) + + if not output_info: + return None + + return { + "inputs": input_info, + "outputs": output_info, + } diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 12f1ce28c7..31182bbe21 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,7 +4,6 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple -import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module @@ -22,6 +21,8 @@ multi_gpu_device_check, ) +import tensorrt as trt + logger = logging.getLogger(__name__) @@ -131,6 +132,7 @@ def __init__( settings: CompilationSettings = CompilationSettings(), weight_name_map: Optional[dict[Any, Any]] = None, requires_output_allocator: bool = False, + symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, _debugger_config: Optional[DebuggerConfig] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs @@ -146,6 +148,7 @@ def __init__( settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) + symbolic_shape_expressions (List[str]): List of symbolic shape expressions for each output binding Example: @@ -222,6 +225,7 @@ def __init__( self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() # If the output tensor is not owned by the engine (output_tensors_are_unowned=True), we need to create a new output tensor in each forward pass self.output_tensors_are_unowned = False + self.symbolic_shape_expressions = symbolic_shape_expressions if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -462,7 +466,6 @@ def create_output_allocator(self) -> None: self.output_allocator = DynamicOutputAllocator(output_dtypes_dict) def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: - def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: shape_changed = self.validate_input_shapes(contiguous_inputs) ( diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 7ed48fdb7f..669c82739a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -4,9 +4,10 @@ import copy import logging import pickle -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch + from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform from torch_tensorrt._features import ( @@ -87,6 +88,7 @@ def __init__( settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed weight_name_map: Optional[dict[Any, Any]] = None, requires_output_allocator: bool = False, + symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -106,6 +108,7 @@ def __init__( settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) + symbolic_shape_expressions (List[Any]): List of symbolic shape expressions for each input binding Example: @@ -143,6 +146,7 @@ def __init__( self.engine = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources + self.symbolic_shape_expressions = symbolic_shape_expressions if ( serialized_engine @@ -160,6 +164,7 @@ def _pack_engine_info(self) -> List[str | bytes]: metadata = { "settings": self.settings, "weight_name_map": self.weight_name_map, + "inout_symexprs": self.symbolic_shape_expressions, "output_tensors_are_unowned": ( False if self.engine is None @@ -308,6 +313,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.settings = metadata["settings"] self.weight_name_map = metadata["weight_name_map"] self.output_tensors_are_unowned = metadata["output_tensors_are_unowned"] + self.symbolic_shape_expression = metadata["inout_symexprs"] self.engine.set_output_tensors_as_unowned(self.output_tensors_are_unowned) else: @@ -335,9 +341,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: """ if self.engine is None: raise RuntimeError("Engine has not been setup yet.") - - assert len(inputs) == len( - self.input_binding_names + assert len(inputs) == len(self.input_binding_names), ( + f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}." + ) ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}." # If the inputs are not Torch Tensors, which can occur in scenarios such as shape tensors diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index 434f434ad5..e03c88153c 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -1,9 +1,163 @@ import base64 -from collections import defaultdict -from typing import Any, List +import logging +from typing import Any, Dict, List import torch -from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + +logger = logging.getLogger(__name__) + + +def _apply_symbolic_shape_expressions( + inputs: List[torch.Tensor], shape_info: Dict[str, List[Dict[str, Any]]] +) -> List[torch.Tensor]: + """ + Apply symbolic shape expressions to create output fake tensors. + + This applies the shape expressions captured at compile time to the current + input fake tensors' symbolic context, using the input alignment to map + symbolic dimensions. + + Args: + inputs: Input fake tensors with current symbolic shapes + shape_info: Dict with 'inputs' and 'outputs' keys containing shape_exprs and dtype info + + Returns: + List of output fake tensors with symbolic shapes + """ + from torch._guards import detect_fake_mode + + logger.debug( + f"[torch.ops.tensorrt.execute_engine]: Meta kernel found the following input FakeTensors: {inputs}" + ) + + input_info = shape_info.get("inputs", []) + output_info = shape_info.get("outputs", []) + + fake_mode = detect_fake_mode(inputs) + if fake_mode is None: + # No fake mode - shouldn't happen, but fall back to concrete shapes + outputs = [] + for info in output_info: + shape = [ + int(s) if not hasattr(s, "is_Symbol") else 1 + for s in info["shape_exprs"] + ] + outputs.append( + torch.empty(shape, dtype=info["dtype"], device=inputs[0].device) + ) + return outputs + + # Build a mapping from compile-time symbolic expressions to runtime SymInts + # by aligning captured input info with actual runtime input tensors + symbol_to_symint = {} + symbol_to_concrete = {} + shape_env = None + + # Align inputs: for each captured input, match it with the corresponding runtime input + for idx, (inp_tensor, inp_info) in enumerate(zip(inputs, input_info)): + for d, s in zip(inp_tensor.shape, inp_info["shape_exprs"]): + if isinstance(d, torch.SymInt): + symbol_to_symint[s] = d + if shape_env is None: + shape_env = d.node.shape_env + + elif isinstance(d, int): + symbol_to_concrete[s] = d + + logger.debug( + f"[torch.ops.tensorrt.execute_engine]: Meta kernel captured and mapped symbol from input {inp_tensor} (compile time symbol: {s}, new symbol: {d})" + ) + + # Create output fake tensors with symbolic shapes + logger.debug(f"Deserialized output shape expressions: {output_info}") + outputs = [] + with fake_mode: + for output_num, info in enumerate(output_info): + output_shape = [] + for expr in info["shape_exprs"]: + if isinstance(expr, int): + # Concrete dimension + output_shape.append(expr) + else: + logger.debug(f"Symbolic expression: {expr}") + # Symbolic expression (sympy expr) + + # Check if this expression uses any symbols that are now concrete + has_concrete_symbols = any( + sym in symbol_to_concrete for sym in expr.free_symbols + ) + + if has_concrete_symbols: + # Case 2: Some compile-time symbols are now concrete ints + # Evaluate the expression to a concrete value + try: + # Build substitution dict with concrete values + subs_dict = {} + for sym in expr.free_symbols: + if sym in symbol_to_concrete: + subs_dict[sym] = symbol_to_concrete[sym] + elif sym in symbol_to_symint: + subs_dict[sym] = symbol_to_symint[sym].node.hint + else: + subs_dict[sym] = sym + + val = expr.subs(subs_dict) + concrete_dim = int(val) + output_shape.append(concrete_dim) + logger.debug( + f"Evaluated {expr} to concrete value {concrete_dim} using concrete mappings" + ) + except Exception as e: + raise RuntimeError( + f"[torch.ops.tensorrt.execute_engine]: Failed to evaluate symbolic expression {expr} " + f"with concrete values. Free symbols: {expr.free_symbols}, " + f"Concrete mappings: {symbol_to_concrete}, " + f"SymInt mappings: {list(symbol_to_symint.keys())}. Error: {e}" + ) + elif expr in symbol_to_symint: + # Case 1a: Direct mapping - compile-time symbol is represented by runtime SymInt + output_shape.append(symbol_to_symint[expr]) + logger.debug( + f"Reused SymInt from input: {expr} -> {symbol_to_symint[expr]}" + ) + elif shape_env is not None: + # Case 1b: Create new SymInt from expression using existing SymInts + try: + # Calculate hint by substituting known values + hint_val = expr.subs( + { + sym: symbol_to_symint[sym].node.hint + for sym in expr.free_symbols + if sym in symbol_to_symint + } + ) + hint = int(hint_val) if hint_val.is_number else None + + # Create new SymInt from the expression + output_symint = shape_env.create_symintnode(expr, hint=hint) + output_shape.append(output_symint) + logger.debug( + f"Created new SymInt for {expr} with hint {hint}" + ) + except Exception as e: + raise RuntimeError( + f"[torch.ops.tensorrt.execute_engine]: Failed to create SymInt for expression {expr}. " + f"Error: {e}" + ) + else: + raise RuntimeError( + "[torch.ops.tensorrt.execute_engine]: No shape_env available during meta kernel execution" + ) + + outputs.append( + torch.empty(output_shape, dtype=info["dtype"], device=inputs[0].device) + ) + logger.debug( + f"[torch.ops.tensorrt.execute_engine]: Meta kernel found the following output FakeTensors: {outputs}" + ) + + return outputs @torch.library.register_fake("aten::cudnn_grid_sampler") # type: ignore @@ -40,71 +194,36 @@ def fake_tensorrt_execute_engine( inputs: List[torch.Tensor], fake_trt_engine: Any ) -> Any: """ - We infer outputs using the TRT engine and inputs and return fake tensors in this meta kernel. + Meta kernel for TensorRT engine execution. + + Uses symbolic shape expressions captured at compile time to correctly infer + output shapes while preserving symbolic SymInt relationships. """ - # Here's what we are doing - # 1) Check if inputs are dynamic (they have sym ints in their shapes) - # 2) For dynamic inputs, we gather min_input_shape and max_input shape for all inputs - # 3) For the above min and max input shape, capture the corresponding min and max output shape using TensorRT's set/get shapes mechanism - # 4) Create a new symbolic fake tensor using min and max output shape for each output and return them - # 5) For static inputs, the output shape will be static and we won't need to create sym ints - is_dynamic_execution = input_is_dynamic(inputs) - if is_dynamic_execution: - modes = ["min", "max", "opt"] + + metadata = None + if hasattr(fake_trt_engine, "real_obj"): + # Wrapped C++ engine with real_obj + trt_engine = fake_trt_engine.real_obj + metadata = TorchTensorRTModule.decode_metadata( + trt_engine.get_serialized_metadata() + ) else: - modes = ["opt"] - - # Get the TRTEngine class and infer output shapes based on input shapes - trt_engine = fake_trt_engine.real_obj - outputs_mode_dict = defaultdict(list) - for mode in modes: - input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs] - proxy_outputs = trt_engine.infer_outputs(input_shapes) - outputs_mode_dict[mode].extend(proxy_outputs) - - # Store the number of outputs - if {"min", "max"}.issubset(outputs_mode_dict): - assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"]) - num_outputs = len(outputs_mode_dict["min"]) - elif "opt" in outputs_mode_dict: - num_outputs = len(outputs_mode_dict["opt"]) - - fake_outputs = [] - for out_idx in range(num_outputs): - output_shape = [] - if is_dynamic_execution: - # Create output symbolic shape using unbacked symint. - # Note: We can't establish a relationship b/w incoming input symbolic shape (eg: s0) - # and TensorRT's output shape (represented as unbacked u0). This situation doesn't seem - # to affect compilation results / serialization during our testing. - output_min_shape = outputs_mode_dict["min"][out_idx].size() - output_opt_shape = outputs_mode_dict["opt"][out_idx].size() - output_max_shape = outputs_mode_dict["max"][out_idx].size() - - ctx = torch._custom_ops.get_ctx() - for min_val, opt_val, max_val in zip( - output_min_shape, output_opt_shape, output_max_shape - ): - if min_val != max_val: - output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val) - # Update var to val (hint) - output_sym_int_shape_env = output_sym_int.node.shape_env - output_sym_int_shape_env.set_unbacked_var_to_val( - output_sym_int.node.expr, opt_val - ) - output_shape.append(output_sym_int) - else: - output_shape.append(min_val) - else: - output_shape.extend(outputs_mode_dict["opt"][out_idx].size()) - fake_outputs.append( - torch.empty( - output_shape, - dtype=outputs_mode_dict["opt"][out_idx].dtype, - device=inputs[0].device, - ) + metadata = TorchTensorRTModule.decode_metadata( + fake_trt_engine.get_serialized_metadata() + ) + + shape_info = metadata.get("inout_symexprs") if metadata else None + + if shape_info: + # Apply the symbolic shape expressions to create output fake tensors + # shape_info now contains both 'inputs' and 'outputs' keys + return _apply_symbolic_shape_expressions(inputs, shape_info) + else: + raise RuntimeError( + "No symbolic shape expressions found in TensorRT engine metadata. " + "This engine may have been compiled with an older version of Torch-TensorRT. " + "Please recompile your model." ) - return fake_outputs @torch._library.register_fake_class("tensorrt::Engine") @@ -176,6 +295,9 @@ def infer_outputs(self, input_shapes: List[Any]) -> Any: def reset_captured_graph(self) -> Any: pass + def get_serialized_metadata(self) -> Any: + return self.serialized_metadata + def __setstate__(self, serialized_state: List[str]) -> Any: pass diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index abc697a086..10d2dfbbee 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -22,7 +22,6 @@ import numpy as np import psutil import sympy -import tensorrt as trt import torch from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily @@ -36,6 +35,7 @@ from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings +import tensorrt as trt from packaging import version from .types import TRTDataType @@ -418,9 +418,15 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]: # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy # expr.xreplace replaces the symbolic variables with their current values and computes the expression. var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(expr) + # Handle both old and new PyTorch API for unbacked variables + unbacked_var_dict = getattr(shape_env, "unbacked_var_to_val", None) or getattr( + shape_env, "backed_var_to_val", {} + ) + + # TODO: VAR TO VAL IS DEPRECATED WE SHOULD BE USING BACKED_VAR_TO_VAL var_val = ( shape_env.var_to_val.get(expr, None) - or shape_env.unbacked_var_to_val.get(expr, None) + or unbacked_var_dict.get(expr, None) or expr.xreplace(shape_env.var_to_val) ) assert var_range, var_val @@ -699,8 +705,9 @@ def check_module_output( arg_inputs: Any, kwarg_inputs: Any = None, ) -> bool: - old_outputs, new_outputs = refitted_module(*arg_inputs), new_module( - *arg_inputs, **kwarg_inputs + old_outputs, new_outputs = ( + refitted_module(*arg_inputs), + new_module(*arg_inputs, **kwarg_inputs), ) if type(old_outputs) != type(new_outputs): logger.warning("The output types are different. Output check is skipped.") diff --git a/pyproject.toml b/pyproject.toml index aab906b389..2e1fbb41b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ test = [ "expecttest==0.1.6", "timm>=1.0.3", "transformers>=4.49.0", + "torchvision>=0.26.0.dev" ] quantization = ["nvidia-modelopt[all]>=0.27.1"] diff --git a/tests/py/dynamo/models/test_meta_kernel_shape_inference.py b/tests/py/dynamo/models/test_meta_kernel_shape_inference.py new file mode 100644 index 0000000000..173a2e366f --- /dev/null +++ b/tests/py/dynamo/models/test_meta_kernel_shape_inference.py @@ -0,0 +1,298 @@ +""" +Test meta kernel shape inference by running TRT modules in fake mode. + +Each test exports a model, compiles with TRT, then runs the TRT module in fake +mode to verify the meta kernel correctly infers symbolic output shapes. + +The test approach: +1. Export a model with dynamic shapes to get an exported program with symbolic SymInts +2. Compile the exported program with Torch-TensorRT +3. Extract the symbolic fake input from the exported program +4. Run both the exported program and TRT module with the same symbolic fake input +5. Verify that both produce the same symbolic output shapes + +Currently, tests with dynamic shapes are marked as xfail because the meta kernel +does not preserve symbolic SymInt dimensions - it creates new unbacked symints instead +of reusing the input SymInts. This is a known limitation. +""" + +import pytest +import torch +import torch_tensorrt +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.export import Dim + + +class TestMetaKernelShapeInference: + """Test meta kernel by running TRT modules in fake mode""" + + def _test_in_fake_mode(self, model, test_input, dynamic_shapes=None): + """ + Helper that exports model, compiles with TRT, runs in fake mode. + Returns (exported_output, trt_output, fake_input) for shape comparison. + """ + # Export with dynamic shapes + if dynamic_shapes: + exported = torch.export.export( + model, args=(test_input,), dynamic_shapes=dynamic_shapes + ) + else: + exported = torch.export.export(model, args=(test_input,)) + + # Compile with TRT + compiled = torch_tensorrt.compile( + exported, inputs=[test_input], min_block_size=1 + ) + + # Get the fake input from the exported program - it has symbolic shapes + from torch._guards import detect_fake_mode + + fake_input = None + for node in exported.graph.nodes: + if node.op == "placeholder" and node.name == "x" and "val" in node.meta: + fake_input = node.meta["val"] + break + + assert ( + fake_input is not None + ), "Could not find input placeholder 'x' in exported program" + + # Get the fake mode + fake_mode = detect_fake_mode((fake_input,)) + assert fake_mode is not None, "Could not detect fake mode from exported program" + + # Run both exported and compiled in the same fake mode + with fake_mode: + exported_output = exported.module()(fake_input) + trt_output = compiled(fake_input) + + return exported_output, trt_output, fake_input + + def test_identity_static(self): + """Test meta kernel with static shapes (identity operation)""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + model = Model().eval().cuda() + test_input = torch.randn(4, 3, 64, 64).cuda() + + exported_output, trt_output, fake_input = self._test_in_fake_mode( + model, test_input + ) + + print(f"Input shape: {fake_input.shape}") + print(f"Exported output shape: {exported_output.shape}") + print(f"TRT output shape: {trt_output.shape}") + + # Both should produce same shape + assert exported_output.shape == trt_output.shape + assert trt_output.shape == (4, 64, 64, 64) + + def test_downsample_static(self): + """Test meta kernel with static shapes (stride=2 downsampling)""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + model = Model().eval().cuda() + test_input = torch.randn(4, 3, 64, 64).cuda() + + exported_output, trt_output, fake_input = self._test_in_fake_mode( + model, test_input + ) + + print(f"Input shape: {fake_input.shape}") + print(f"Exported output shape: {exported_output.shape}") + print(f"TRT output shape: {trt_output.shape}") + + # Both should produce same downsampled shape + assert exported_output.shape == trt_output.shape + assert trt_output.shape == (4, 64, 32, 32) + + def test_dynamic_batch(self): + """Test meta kernel preserves symbolic batch dimension""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + model = Model().eval().cuda() + test_input = torch.randn(4, 3, 64, 64).cuda() + + batch = Dim("batch", min=1, max=8) + dynamic_shapes = {"x": {0: batch}} + + exported_output, trt_output, fake_input = self._test_in_fake_mode( + model, test_input, dynamic_shapes + ) + + print(f"Input shape: {fake_input.shape}") + print(f"Exported output shape: {exported_output.shape}") + print(f"TRT output shape: {trt_output.shape}") + + # Both should have symbolic batch + assert isinstance( + fake_input.shape[0], torch.SymInt + ), "Input batch should be symbolic" + assert isinstance( + exported_output.shape[0], torch.SymInt + ), "Exported output batch should be symbolic" + assert isinstance( + trt_output.shape[0], torch.SymInt + ), "TRT output batch should be symbolic" + + # Shapes should match + assert exported_output.shape == trt_output.shape + + def test_arithmetic_h_div_2(self): + """Test meta kernel infers h//2 symbolic relationship""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.conv(x)) + h = x.shape[2] + return x[:, :, : h // 2, :] + + model = Model().eval().cuda() + test_input = torch.randn(4, 3, 64, 64).cuda() + + batch = Dim("batch", min=1, max=8) + h_base = Dim("h_base", min=16, max=64) + w_base = Dim("w_base", min=16, max=64) + dynamic_shapes = {"x": {0: batch, 2: 2 * h_base, 3: 2 * w_base}} + + exported_output, trt_output, fake_input = self._test_in_fake_mode( + model, test_input, dynamic_shapes + ) + + print(f"Input shape (height=2*h_base): {fake_input.shape}") + print(f"Exported output shape (height=h_base): {exported_output.shape}") + print(f"TRT output shape: {trt_output.shape}") + + # Height should be symbolic and correctly inferred + assert isinstance( + fake_input.shape[2], torch.SymInt + ), "Input height should be symbolic" + assert isinstance( + exported_output.shape[2], torch.SymInt + ), "Exported output height should be symbolic" + assert isinstance( + trt_output.shape[2], torch.SymInt + ), "TRT output height should be symbolic" + + # Shapes should match + assert exported_output.shape == trt_output.shape + + def test_stride_2_dynamic(self): + """Test meta kernel infers h//2 from stride=2 convolution""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + model = Model().eval().cuda() + test_input = torch.randn(4, 3, 64, 64).cuda() + + batch = Dim("batch", min=1, max=8) + h_base = Dim("h_base", min=16, max=64) + w_base = Dim("w_base", min=16, max=64) + # Input must be even for stride=2 + dynamic_shapes = {"x": {0: batch, 2: 2 * h_base, 3: 2 * w_base}} + + exported_output, trt_output, fake_input = self._test_in_fake_mode( + model, test_input, dynamic_shapes + ) + + print(f"Input shape (2*h_base): {fake_input.shape}") + print(f"Exported output shape (h_base): {exported_output.shape}") + print(f"TRT output shape: {trt_output.shape}") + + # Height should be symbolic + assert isinstance( + fake_input.shape[2], torch.SymInt + ), "Input height should be symbolic" + assert isinstance( + exported_output.shape[2], torch.SymInt + ), "Exported output height should be symbolic" + assert isinstance( + trt_output.shape[2], torch.SymInt + ), "TRT output height should be symbolic" + + # Shapes should match + assert exported_output.shape == trt_output.shape + + def test_concat(self): + """Test meta kernel with concat operation (concatenates on height dimension)""" + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.relu(self.conv(x)) + # Concatenate x with itself on height dimension to double the height + return torch.cat([x, x], dim=2) + + model = Model().eval().cuda() + test_input = torch.randn(4, 3, 32, 32).cuda() + + batch = Dim("batch", min=1, max=8) + h = Dim("h", min=16, max=64) + w = Dim("w", min=16, max=64) + dynamic_shapes = {"x": {0: batch, 2: h, 3: w}} + + exported_output, trt_output, fake_input = self._test_in_fake_mode( + model, test_input, dynamic_shapes + ) + + print(f"Input shape: {fake_input.shape}") + print(f"Exported output shape (2*h): {exported_output.shape}") + print(f"TRT output shape: {trt_output.shape}") + + # Height should be symbolic (2x input from concat) + assert isinstance( + fake_input.shape[2], torch.SymInt + ), "Input height should be symbolic" + assert isinstance( + exported_output.shape[2], torch.SymInt + ), "Exported output height should be symbolic" + assert isinstance( + trt_output.shape[2], torch.SymInt + ), "TRT output height should be symbolic" + + # Shapes should match + assert exported_output.shape == trt_output.shape + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/py/dynamo/models/test_reexport.py b/tests/py/dynamo/models/test_reexport.py index 9636c9d91a..7ad544811c 100644 --- a/tests/py/dynamo/models/test_reexport.py +++ b/tests/py/dynamo/models/test_reexport.py @@ -478,11 +478,15 @@ def test_resnet18_dynamic(ir, tmpdir): ) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - # Reexport with dynamic dimensions - trt_exp_program = torch.export.export( - trt_module, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False + # Save with torch_tensorrt.save() which handles dynamic shapes properly + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=[input_bs2], + dynamic_shapes=({0: dyn_batch},), + retrace=True, ) - torch.export.save(trt_exp_program, trt_ep_path) # TODO: Enable this serialization issues are fixed deser_trt_module = torchtrt.load(trt_ep_path).module() @@ -556,14 +560,15 @@ def test_resnet18_dynamic_fallback(ir, tmpdir): ) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - # Reexport with dynamic dimensions - trt_exp_program = torch.export.export( + # Save with torch_tensorrt.save() which handles dynamic shapes properly + torchtrt.save( trt_module, - (input_bs2,), - strict=False, + trt_ep_path, + output_format="exported_program", + arg_inputs=[input_bs2], dynamic_shapes=({0: dyn_batch},), + retrace=True, ) - torch.export.save(trt_exp_program, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input_bs2) @@ -635,14 +640,15 @@ def forward(self, lhs_val, rhs_val): ) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - # Reexport with dynamic dimensions - trt_exp_program = torch.export.export( + # Save with torch_tensorrt.save() which handles dynamic shapes properly + torchtrt.save( trt_module, - inputs_4, - strict=False, + trt_ep_path, + output_format="exported_program", + arg_inputs=list(inputs_4), dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}}, + retrace=True, ) - torch.export.save(trt_exp_program, trt_ep_path) deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(*inputs_4) @@ -727,14 +733,15 @@ def forward(self, x): ) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - # Reexport with dynamic dimensions - trt_exp_program = torch.export.export( + # Save with torch_tensorrt.save() which handles dynamic shapes properly + torchtrt.save( trt_module, - torch_inputs_bs50, - strict=False, + trt_ep_path, + output_format="exported_program", + arg_inputs=list(torch_inputs_bs50), dynamic_shapes=({0: dyn_dim},), + retrace=True, ) - torch.export.save(trt_exp_program, trt_ep_path) # Test with BS=50 deser_trt_module = torchtrt.load(trt_ep_path).module() @@ -769,3 +776,586 @@ def forward(self, x): cos_sim > COSINE_THRESHOLD, msg=f"test_random_dynamic_fallback TRT outputs don't match with the original model with inputs bs=62. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_save_with_dynamic_shapes_api(ir, tmpdir): + """ + This tests the torch_tensorrt.save() API with dynamic_shapes parameter + to preserve dynamic shape specifications during serialization + """ + + trt_ep_path = os.path.join(tmpdir, "trt_dynamic.ep") + model = models.resnet18().eval().cuda() + input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + name="x", + ) + ], + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + # Define dynamic shapes + dyn_batch = torch.export.Dim("batch", min=1, max=8) + dynamic_shapes = {"x": {0: dyn_batch}} + + # Export with dynamic shapes + exp_program = torch.export.export( + model, (input_bs2,), dynamic_shapes=dynamic_shapes, strict=False + ) + trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) + + # Use the new torch_tensorrt.save() API with dynamic_shapes parameter + # Using retrace=True (should work now with fixed meta kernel and symbolic input handling) + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=[input_bs2], + dynamic_shapes=dynamic_shapes, # Preserve dynamic shapes + retrace=True, + ) + + # Load and test with different batch sizes + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Test with batch size 2 + outputs_pyt = model(input_bs2) + outputs_trt = trt_module(input_bs2) + outputs_trt_deser = deser_trt_module(input_bs2) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_dynamic_shapes_api TRT outputs don't match with the original model for batch size=2. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_dynamic_shapes_api deserialized TRT outputs don't match with the original model for batch size=2. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Test with batch size 6 + input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda") + outputs_pyt = model(input_bs6) + outputs_trt = trt_module(input_bs6) + outputs_trt_deser = deser_trt_module(input_bs6) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_dynamic_shapes_api TRT outputs don't match with the original model for batch size=6. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_dynamic_shapes_api deserialized TRT outputs don't match with the original model for batch size=6. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_save_with_input_objects_inferred_dynamic_shapes(ir, tmpdir): + """ + This tests the torch_tensorrt.save() API with torch_tensorrt.Input objects + that have min/opt/max shapes. The dynamic_shapes should be inferred automatically. + This is Method 2 - the recommended approach. + """ + + trt_ep_path = os.path.join(tmpdir, "trt_input_inferred.ep") + model = models.resnet18().eval().cuda() + + # Define Input objects with dynamic shapes + compile_inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + name="x", + ) + ] + + compile_spec = { + "arg_inputs": compile_inputs, + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + # Note: We're NOT using torch.export.export here, going directly through compile + trt_module = torchtrt.compile(model, **compile_spec) + + # Use the new torch_tensorrt.save() API with Input objects + # Dynamic shapes should be inferred automatically - no explicit dynamic_shapes needed! + # Note: retrace=False because retracing a TRT-compiled module with dynamic shapes + # causes issues with unbacked symints from TensorRT runtime + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=compile_inputs, # Pass Input objects, not tensors + retrace=True, + ) + + # Load and test with different batch sizes + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Test with batch size 2 + input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda") + outputs_pyt = model(input_bs2) + outputs_trt = trt_module(input_bs2) + outputs_trt_deser = deser_trt_module(input_bs2) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_input_objects_inferred_dynamic_shapes TRT outputs don't match for batch size=2. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_input_objects_inferred_dynamic_shapes deserialized TRT outputs don't match for batch size=2. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Test with batch size 7 + input_bs7 = torch.randn((7, 3, 224, 224)).to("cuda") + outputs_pyt = model(input_bs7) + outputs_trt = trt_module(input_bs7) + outputs_trt_deser = deser_trt_module(input_bs7) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_input_objects_inferred_dynamic_shapes TRT outputs don't match for batch size=7. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_input_objects_inferred_dynamic_shapes deserialized TRT outputs don't match for batch size=7. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_save_inferred_dynamic_shapes_multiple_dimensions(ir, tmpdir): + """ + Test automatic dynamic shape inference with multiple dynamic dimensions + (batch, height, width) + + NOTE: This test is skipped because torch.export cannot properly serialize + ExportedPrograms with multiple dynamic dimensions when retrace=False (causes + missing value range info), and retrace=True causes unbacked symint issues with + TRT-compiled modules. Use test_save_with_dynamic_shapes_api with explicit + dynamic_shapes parameter instead. + """ + + trt_ep_path = os.path.join(tmpdir, "trt_multi_dim.ep") + + class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x): + return self.conv(x) + + model = ConvModel().eval().cuda() + + # Define Input with 3 dynamic dimensions + compile_inputs = [ + torchtrt.Input( + min_shape=(1, 3, 64, 64), + opt_shape=(4, 3, 256, 256), + max_shape=(8, 3, 512, 512), + dtype=torch.float32, + name="x", + ) + ] + + compile_spec = { + "inputs": compile_inputs, + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + # Generate dynamic shape specs + dyn_batch = torch.export.Dim("batch", min=1, max=8) + dyn_height = torch.export.Dim("height", min=64, max=512) + dyn_width = torch.export.Dim("width", min=64, max=512) + dynamic_shapes = {"x": {0: dyn_batch, 2: dyn_height, 3: dyn_width}} + + trt_module = torchtrt.compile(model, **compile_spec) + + # Save with automatic inference of all 3 dynamic dimensions + # retrace=True now works correctly with dynamic shapes + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=compile_inputs, + dynamic_shapes=dynamic_shapes, + retrace=True, + ) + + # Load and test with various sizes + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Test different combinations of batch, height, width + test_shapes = [ + (2, 3, 128, 128), + (6, 3, 384, 384), + (1, 3, 64, 64), # Min + (8, 3, 512, 512), # Max + ] + + for shape in test_shapes: + input_tensor = torch.randn(shape).to("cuda") + outputs_pyt = model(input_tensor) + outputs_trt = trt_module(input_tensor) + outputs_trt_deser = deser_trt_module(input_tensor) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_inferred_dynamic_shapes_multiple_dimensions TRT outputs don't match for shape {shape}. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_inferred_dynamic_shapes_multiple_dimensions deserialized TRT outputs don't match for shape {shape}. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_save_mixed_static_dynamic_inputs(ir, tmpdir): + """ + Test saving with mixed static (tensor) and dynamic (Input) inputs + + NOTE: This scenario requires explicit dynamic_shapes because automatic inference + cannot distinguish between dimensions that should be independent vs. equal. + """ + + trt_ep_path = os.path.join(tmpdir, "trt_mixed.ep") + + class MixedInputModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x, bias): + # x has dynamic batch, bias is a fixed-size tensor that broadcasts + out = self.linear(x) + # bias shape [1, 5] broadcasts to [batch, 5] + return out + bias + + model = MixedInputModel().eval().cuda() + + compile_inputs = [ + torchtrt.Input( + min_shape=(1, 10), + opt_shape=(4, 10), + max_shape=(8, 10), + dtype=torch.float32, + name="x", + ), + torchtrt.Input( + shape=(1, 5), # Fixed size bias + dtype=torch.float32, + name="bias", + ), + ] + + compile_spec = { + "inputs": compile_inputs, + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_module = torchtrt.compile(model, **compile_spec) + + # Save with explicit dynamic_shapes + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=compile_inputs, + retrace=True, + ) + + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Test with different batch sizes for dynamic input + for batch_size in [2, 6]: + input_x = torch.randn(batch_size, 10).cuda() + input_bias = torch.randn(1, 5).cuda() # Same fixed bias + + outputs_pyt = model(input_x, input_bias) + outputs_trt_deser = deser_trt_module(input_x, input_bias) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_mixed_static_dynamic_inputs outputs don't match for batch size {batch_size}. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_save_with_kwarg_inputs_dynamic(ir, tmpdir): + """ + Test saving with dynamic shapes in kwarg_inputs + + NOTE: When multiple inputs share the same dynamic dimension (e.g., batch size), + you must explicitly declare this by sharing a Dim object: + + batch = Dim("batch", min=1, max=8) + dynamic_shapes = {"x": {0: batch}, "mask": {0: batch}} + + Automatic inference creates separate Dim objects which causes torch.export + to detect an equality constraint violation. + """ + + trt_ep_path = os.path.join(tmpdir, "trt_kwargs.ep") + + class KwargModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x, *, mask): + # Apply linear transformation and multiply by mask + out = self.linear(x) + out = out * mask + return out + + model = KwargModel().eval().cuda() + + # Create example tensors for export + example_x = torch.randn(4, 10).cuda() + example_mask = torch.randn(4, 5).cuda() + + # Define dynamic shapes with shared batch dimension + # Both inputs share the same batch Dim object to express equality constraint + batch = torch.export.Dim("batch", min=1, max=8) + dynamic_shapes = {"x": {0: batch}, "mask": {0: batch}} + + # Step 1: Export with torch.export + exp_program = torch.export.export( + model, + (example_x,), + {"mask": example_mask}, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # Step 2: Compile with TensorRT using torch_tensorrt.dynamo.compile + compile_inputs = [ + torchtrt.Input( + min_shape=(1, 10), + opt_shape=(4, 10), + max_shape=(8, 10), + dtype=torch.float32, + name="x", + ), + torchtrt.Input( + min_shape=(1, 5), + opt_shape=(4, 5), + max_shape=(8, 5), + dtype=torch.float32, + name="mask", + ), + ] + + compile_spec = { + "inputs": compile_inputs, + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) + + # Save with explicit dynamic_shapes + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=(example_x,), + kwarg_inputs={"mask": example_mask}, + dynamic_shapes=dynamic_shapes, + retrace=True, + ) + + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Test with different batch sizes + for batch_size in [2, 6]: + input_x = torch.randn(batch_size, 10).cuda() + input_mask = torch.randn(batch_size, 5).cuda() + + outputs_pyt = model(input_x, mask=input_mask) + outputs_trt_deser = deser_trt_module(input_x, mask=input_mask) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_with_kwarg_inputs_dynamic outputs don't match for batch size {batch_size}. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_explicit_dynamic_shapes_takes_precedence(ir, tmpdir): + """ + Test that explicit dynamic_shapes parameter takes precedence over + inferred dynamic shapes from Input objects + """ + + trt_ep_path = os.path.join(tmpdir, "trt_precedence.ep") + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + model = SimpleModel().eval().cuda() + example_input = torch.randn(4, 10).cuda() + + # Define both Input objects AND explicit dynamic_shapes + compile_inputs = [ + torchtrt.Input( + min_shape=(1, 10), + opt_shape=(4, 10), + max_shape=(8, 10), + dtype=torch.float32, + name="x", + ) + ] + + # Explicit dynamic_shapes with custom naming + dyn_batch = torch.export.Dim("custom_batch_name", min=1, max=8) + dynamic_shapes = {"x": {0: dyn_batch}} + + exp_program = torch.export.export( + model, (example_input,), dynamic_shapes=dynamic_shapes, strict=False + ) + + compile_spec = { + "inputs": compile_inputs, + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) + + # Save with BOTH compile_inputs and explicit dynamic_shapes + # Explicit should take precedence + # retrace=True now works correctly with dynamic shapes + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=[example_input], + dynamic_shapes=dynamic_shapes, # Explicit takes precedence + retrace=True, + ) + + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Test with different batch sizes + for batch_size in [2, 7]: + input_tensor = torch.randn(batch_size, 10).cuda() + outputs_pyt = model(input_tensor) + outputs_trt_deser = deser_trt_module(input_tensor) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_explicit_dynamic_shapes_takes_precedence outputs don't match for batch size {batch_size}. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_save_static_inputs_no_dynamic_inference(ir, tmpdir): + """ + Test that static Input objects (without min/opt/max) don't trigger + dynamic shape inference + """ + + trt_ep_path = os.path.join(tmpdir, "trt_static.ep") + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + model = SimpleModel().eval().cuda() + + # Static Input (single shape, not min/opt/max) + compile_inputs = [torchtrt.Input(shape=(4, 10), dtype=torch.float32, name="x")] + + compile_spec = { + "inputs": compile_inputs, + "ir": ir, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_module = torchtrt.compile(model, **compile_spec) + + # Save - should NOT infer dynamic shapes (all inputs are static) + torchtrt.save( + trt_module, + trt_ep_path, + output_format="exported_program", + arg_inputs=compile_inputs, + retrace=True, + ) + + deser_trt_module = torchtrt.load(trt_ep_path).module() + + # Should only work with the exact shape + input_tensor = torch.randn(4, 10).cuda() + outputs_pyt = model(input_tensor) + outputs_trt_deser = deser_trt_module(input_tensor) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_static_inputs_no_dynamic_inference outputs don't match. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + )