4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -325,6 +325,10 @@ std::string TRTEngine::get_engine_layer_info() {
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
}

std::string TRTEngine::get_serialized_metadata() {
return this->serialized_metadata;
}

std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
std::vector<at::Tensor> outputs;
TORCHTRT_CHECK(
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.h
@@ -158,6 +158,7 @@ struct TRTEngine : torch::CustomClassHolder {
void set_profile_format(std::string profile_format);
void disable_profiling();
std::string get_engine_layer_info();
std::string get_serialized_metadata();

void dump_engine_layer_info_to_file(const std::string& path);
void dump_engine_layer_info();
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("get_serialized_metadata", &TRTEngine::get_serialized_metadata)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)
164 changes: 162 additions & 2 deletions docsrc/user_guide/dynamic_shapes.rst
@@ -49,7 +49,7 @@ Custom Dynamic Shape Constraints
---------------------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT attempts to automatically set the constraints during ``torch.export`` tracing by constructing
``torch.export.Dim`` objects for the provided dynamic dimensions. Sometimes additional constraints are needed, and Torchdynamo errors out if they are not specified.
If you have to set any custom constraints on your model (using ``torch.export.Dim``), we recommend exporting your program first before compiling with Torch-TensorRT.
Please refer to this `documentation <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_ to export the PyTorch module with dynamic shapes.
@@ -78,7 +78,6 @@ Here's a simple example that exports a matmul layer with some restrictions on dy
# Run inference
trt_gm(*inputs)


Dynamic shapes using torch.compile (JIT)
----------------------------------------

@@ -102,3 +101,164 @@ to avoid recompilation of TensorRT engines.
# No recompilation of TRT engines with modified batch size
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32)
trt_gm(inputs_bs2)


Saving and Loading Models with Dynamic Shapes
----------------------------------------------

When you compile a model with dynamic shapes and want to save it for later use, you need to preserve the dynamic shape
specifications. Torch-TensorRT provides two methods to accomplish this:

Method 1: Automatic Inference from torch_tensorrt.Input
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The simplest approach is to pass the same ``torch_tensorrt.Input`` objects (with min/opt/max shapes) to both ``compile()`` and ``save()``.
The dynamic shape specifications will be inferred automatically:

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()

# Define Input with dynamic shapes once
inputs = [
torch_tensorrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(8, 3, 224, 224),
max_shape=(32, 3, 224, 224),
dtype=torch.float32,
name="x" # Optional: provides better dimension naming
)
]

# Compile with dynamic shapes
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Save - dynamic shapes inferred automatically!
torch_tensorrt.save(trt_model, "model.ep", arg_inputs=inputs)

# Load and use with different batch sizes
loaded_model = torch_tensorrt.load("model.ep").module()
output1 = loaded_model(torch.randn(4, 3, 224, 224).cuda()) # Works!
output2 = loaded_model(torch.randn(16, 3, 224, 224).cuda()) # Works!
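
Under the hood, the automatic inference boils down to comparing the min and max
shapes: any axis whose extent differs between them is treated as dynamic. A
hypothetical pure-Python sketch of this idea (``infer_dynamic_dims`` is an
illustrative helper, not part of the Torch-TensorRT API):

.. code-block:: python

    def infer_dynamic_dims(min_shape, max_shape):
        """Return {dim_index: (min, max)} for every axis whose extent varies."""
        return {
            i: (lo, hi)
            for i, (lo, hi) in enumerate(zip(min_shape, max_shape))
            if lo != hi
        }

    # Only the batch axis (0) varies between (1, 3, 224, 224) and (32, 3, 224, 224)
    infer_dynamic_dims((1, 3, 224, 224), (32, 3, 224, 224))  # {0: (1, 32)}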


Method 2: Explicit torch.export.Dim Specification
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For advanced use cases or when you need fine-grained control over dimension naming, you can explicitly provide ``dynamic_shapes``
using ``torch.export.Dim``:

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
example_input = torch.randn((2, 3, 224, 224)).cuda()

# Define dynamic dimensions explicitly
dyn_batch = torch.export.Dim("batch", min=1, max=32)
dynamic_shapes = {"x": {0: dyn_batch}}

# Export with dynamic shapes
exp_program = torch.export.export(
model, (example_input,),
dynamic_shapes=dynamic_shapes,
strict=False
)

# Compile
trt_model = torch_tensorrt.dynamo.compile(
exp_program,
inputs=[torch_tensorrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(8, 3, 224, 224),
max_shape=(32, 3, 224, 224),
)]
)

# Save with explicit dynamic_shapes
torch_tensorrt.save(
trt_model,
"model.ep",
arg_inputs=[example_input],
dynamic_shapes=dynamic_shapes # Same as used during export
)

# Load and use
loaded_model = torch_tensorrt.load("model.ep").module()

**When to use this method:**

- You need specific dimension names for torch.export compatibility
- You're working with existing torch.export workflows
- You require fine-grained control over dynamic dimension specifications

Multiple Dynamic Dimensions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Both methods support multiple dynamic dimensions (e.g., dynamic batch, height, and width):

.. code-block:: python

# Method 1 (Automatic): Multiple dynamic dimensions
inputs = [
torch_tensorrt.Input(
min_shape=(1, 3, 64, 64),
opt_shape=(8, 3, 256, 256),
max_shape=(16, 3, 512, 512),
name="image"
)
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_model, "model.ep", arg_inputs=inputs) # All 3 dims inferred!

# Load and test with various sizes
loaded = torch_tensorrt.load("model.ep").module()
loaded(torch.randn(4, 3, 128, 128).cuda())
loaded(torch.randn(12, 3, 384, 384).cuda())

Saving with Keyword Arguments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your model uses keyword arguments with dynamic shapes, both methods support them:

.. code-block:: python

# Define dynamic inputs for both args and kwargs
arg_inputs = [
torch_tensorrt.Input(
min_shape=(1, 10),
opt_shape=(4, 10),
max_shape=(8, 10),
name="x"
)
]

kwarg_inputs = {
"mask": torch_tensorrt.Input(
min_shape=(1, 5),
opt_shape=(4, 5),
max_shape=(8, 5),
name="mask"
)
}

# Compile
trt_model = torch_tensorrt.compile(
model,
ir="dynamo",
arg_inputs=arg_inputs,
kwarg_inputs=kwarg_inputs
)

# Save - both arg and kwarg dynamic shapes inferred automatically
torch_tensorrt.save(
trt_model,
"model.ep",
arg_inputs=arg_inputs,
kwarg_inputs=kwarg_inputs
)
92 changes: 92 additions & 0 deletions docsrc/user_guide/saving_models.rst
@@ -42,6 +42,98 @@ Here's an example usage
model = torch.export.load("trt.ep").module()
model(*inputs)


Saving Models with Dynamic Shapes
""""""""""""""""""""""""""""""""""

When saving models compiled with dynamic shapes, you have two methods to preserve
the dynamic shape specifications:

**Method 1: Using torch.export.Dim (explicit)**

Provide explicit ``dynamic_shapes`` parameter following torch.export's pattern:

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
example_input = torch.randn((2, 3, 224, 224)).cuda()

# Define dynamic batch dimension
dyn_batch = torch.export.Dim("batch", min=1, max=32)
dynamic_shapes = {"x": {0: dyn_batch}}

# Export with dynamic shapes
exp_program = torch.export.export(
model, (example_input,),
dynamic_shapes=dynamic_shapes,
strict=False
)

# Compile with dynamic input specifications
trt_gm = torch_tensorrt.dynamo.compile(
exp_program,
inputs=[torch_tensorrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(8, 3, 224, 224),
max_shape=(32, 3, 224, 224),
)]
)

# Save with dynamic_shapes to preserve dynamic behavior
torch_tensorrt.save(
trt_gm,
"trt_dynamic.ep",
arg_inputs=[example_input],
dynamic_shapes=dynamic_shapes # Same as used during export
)

# Load and use with different batch sizes
loaded_model = torch_tensorrt.load("trt_dynamic.ep").module()
output_bs4 = loaded_model(torch.randn(4, 3, 224, 224).cuda())
output_bs16 = loaded_model(torch.randn(16, 3, 224, 224).cuda())
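
The loaded module accepts any shape that falls inside the declared min/max
range; shapes outside it are rejected at runtime. That range check amounts to a
per-axis comparison, sketched here with a hypothetical helper
(``shape_in_range`` is illustrative, not a Torch-TensorRT API):

.. code-block:: python

    def shape_in_range(shape, min_shape, max_shape):
        """True if every axis of `shape` lies within the declared [min, max] bounds."""
        return all(lo <= s <= hi for s, lo, hi in zip(shape, min_shape, max_shape))

    shape_in_range((4, 3, 224, 224), (1, 3, 224, 224), (32, 3, 224, 224))   # True
    shape_in_range((64, 3, 224, 224), (1, 3, 224, 224), (32, 3, 224, 224))  # False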

**Method 2: Using torch_tensorrt.Input**

Pass ``torch_tensorrt.Input`` objects with min/opt/max shapes directly, and the
dynamic shapes will be inferred automatically:

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()

# Define Input with dynamic shapes
inputs = [
torch_tensorrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(8, 3, 224, 224),
max_shape=(32, 3, 224, 224),
dtype=torch.float32,
name="x" # Optional: provides better dimension naming
)
]

# Compile with Torch-TensorRT
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Save with Input objects - dynamic_shapes inferred automatically!
torch_tensorrt.save(
trt_gm,
"trt_dynamic.ep",
arg_inputs=inputs # Dynamic shapes inferred from Input objects
)

# Load and use with different batch sizes
loaded_model = torch_tensorrt.load("trt_dynamic.ep").module()
output_bs4 = loaded_model(torch.randn(4, 3, 224, 224).cuda())
output_bs16 = loaded_model(torch.randn(16, 3, 224, 224).cuda())


b) Torchscript
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
