Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.node import Target
from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
Expand Down Expand Up @@ -53,9 +52,7 @@ def batch_norm(
# We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
# Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
# In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
# TODO: lanl: to remove this once we have solved the batchnorm constant folding issue in RTX
# https://github.com/pytorch/TensorRT/issues/3699
if ENABLED_FEATURES.tensorrt_rtx or any(
if any(
[
isinstance(weight, trt.ITensor),
isinstance(bias, trt.ITensor),
Expand Down Expand Up @@ -175,8 +172,6 @@ def batch_norm(
if running_var is None:
running_var = torch.ones((feature_num,))

power = torch.ones_like(weight)

adjusted_scale, adjusted_bias = batch_norm_constant_folding(
weight, bias, running_mean, running_var, eps
)
Expand All @@ -200,16 +195,6 @@ def batch_norm(
source_ir=source_ir,
)

power = to_trt_weights(
ctx,
power,
name,
layer_type_name="SCALE",
weight_type_name="POWER",
target=target,
source_ir=source_ir,
)

if len(input.shape) < 4:
new_shape = (
(input.shape[0], input.shape[1], 1, 1)
Expand All @@ -221,7 +206,7 @@ def batch_norm(
)

layer = ctx.net.add_scale_nd(
input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, None, 1
)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)
Expand Down
48 changes: 48 additions & 0 deletions tests/py/dynamo/conversion/test_batch_norm_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,53 @@ def forward(self, x):
)


class TestBatchNormConvFusion(DispatchTestCase):
    """Regression coverage for https://github.com/pytorch/TensorRT/issues/3699.

    A negative Conv activation feeding into BatchNorm must survive the
    TensorRT optimizer's Conv+Scale fusion without becoming NaN. The full
    torch_tensorrt.dynamo.compile() pipeline is driven here so that the
    TRT layer fusion under test actually runs.
    """

    def test_conv_batchnorm_conv_no_nan(self):
        import torch_tensorrt

        class SandwichNet(torch.nn.Module):
            # conv -> bn -> conv: the minimal graph that triggers the fusion.
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False)
                self.bn1 = torch.nn.BatchNorm2d(1)
                self.conv2 = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False)

            def forward(self, x):
                out = self.conv1(x)
                out = self.bn1(out)
                return self.conv2(out)

        module = SandwichNet().eval().to("cuda")
        with torch.no_grad():
            # All-ones kernels keep the negative input negative through conv1.
            module.conv1.weight.fill_(1.0)
            module.conv2.weight.fill_(1.0)

        sample = torch.full((1, 1, 1, 1), -1.0, device="cuda")
        exported = torch.export.export(module, (sample,), strict=False)
        compiled = torch_tensorrt.dynamo.compile(
            exported,
            inputs=[torch_tensorrt.Input(shape=(1, 1, 1, 1), dtype=torch.float32)],
            device=torch_tensorrt.Device("cuda:0"),
            enabled_precisions={torch.float32},
            min_block_size=1,
            pass_through_build_failures=True,
            cache_built_engines=False,
            reuse_cached_engines=False,
        )

        torch_result = module(sample)
        trt_result = compiled(sample)

        self.assertEqual(
            torch.isnan(trt_result).sum().item(), 0, "TRT output contains NaN"
        )
        torch.testing.assert_close(trt_result, torch_result, atol=1e-3, rtol=1e-3)


# Allow running this test file directly (e.g. `python test_batch_norm_aten.py`)
# in addition to pytest discovery.
if __name__ == "__main__":
    run_tests()
Loading