
Commit f52e542

Rather than passing a power value of 1 to the scale op, send None.
1 parent: 0168301

2 files changed: 50 additions & 16 deletions
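Why None rather than an explicit power of 1: TensorRT's IScaleLayer computes output = (input * scale + shift) ** power per element, and passing power=None leaves the layer as a pure affine op. Below is a minimal sketch of one plausible failure mode, assuming the fused RTX kernel lowers pow(x, p) to exp(p * log(x)); that lowering is an illustrative assumption, not something this commit states:

import numpy as np

x = np.float32(-1.0)                     # negative Conv output, as in the regression test
direct = x ** np.float32(1.0)            # -1.0: an exponent of 1 is the identity
with np.errstate(invalid="ignore"):
    via_exp_log = np.exp(np.float32(1.0) * np.log(x))  # nan: log(-1) is undefined
print(direct, via_exp_log)               # -1.0 nan

Under that assumption, any negative pre-activation flowing through an explicit power term can surface as NaN, while omitting the power term sidesteps the pow evaluation entirely.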

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 2 additions & 16 deletions
@@ -53,9 +53,7 @@ def batch_norm(
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
-    # TODO: lanl: to remove this once we have solved the batchnorm constant folding issue in RTX
-    # https://github.com/pytorch/TensorRT/issues/3699
-    if ENABLED_FEATURES.tensorrt_rtx or any(
+    if any(
         [
             isinstance(weight, trt.ITensor),
             isinstance(bias, trt.ITensor),
@@ -175,8 +173,6 @@ def batch_norm(
     if running_var is None:
         running_var = torch.ones((feature_num,))

-    power = torch.ones_like(weight)
-
     adjusted_scale, adjusted_bias = batch_norm_constant_folding(
         weight, bias, running_mean, running_var, eps
     )
@@ -200,16 +196,6 @@ def batch_norm(
         source_ir=source_ir,
     )

-    power = to_trt_weights(
-        ctx,
-        power,
-        name,
-        layer_type_name="SCALE",
-        weight_type_name="POWER",
-        target=target,
-        source_ir=source_ir,
-    )
-
     if len(input.shape) < 4:
         new_shape = (
             (input.shape[0], input.shape[1], 1, 1)
@@ -221,7 +207,7 @@ def batch_norm(
     )

     layer = ctx.net.add_scale_nd(
-        input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, None, 1
     )
     set_layer_name(layer, target, name, source_ir)
     output = layer.get_output(0)
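For context, batch_norm_constant_folding collapses the normalization into the affine pair consumed by add_scale_nd above. A minimal sketch of the standard folding identity (the real helper lives elsewhere in torch_tensorrt; this version is illustrative only):

import torch

def fold_batch_norm(weight, bias, running_mean, running_var, eps):
    # y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
    #   = x * adjusted_scale + adjusted_bias
    adjusted_scale = weight / torch.sqrt(running_var + eps)
    adjusted_bias = bias - running_mean * adjusted_scale
    return adjusted_scale, adjusted_bias

With the power term gone, the scale layer applies exactly this affine transform and nothing else, which is what lets TensorRT fuse it with the neighboring convolutions safely.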

tests/py/dynamo/conversion/test_batch_norm_aten.py

Lines changed: 48 additions & 0 deletions
@@ -274,5 +274,53 @@ def forward(self, x):
         )


+class TestBatchNormConvFusion(DispatchTestCase):
+    """Regression test for https://github.com/pytorch/TensorRT/issues/3699.
+
+    When Conv output is negative and feeds into BatchNorm, the fused
+    Conv+Scale kernel must not produce NaN. This test uses the full
+    torch_tensorrt.dynamo.compile() pipeline to exercise the TRT optimizer's
+    layer fusion.
+    """
+
+    def test_conv_batchnorm_conv_no_nan(self):
+        import torch_tensorrt
+
+        class ConvBNConv(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False)
+                self.bn1 = torch.nn.BatchNorm2d(1)
+                self.conv2 = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False)
+
+            def forward(self, x):
+                return self.conv2(self.bn1(self.conv1(x)))
+
+        model = ConvBNConv().eval().to("cuda")
+        with torch.no_grad():
+            model.conv1.weight.fill_(1.0)
+            model.conv2.weight.fill_(1.0)
+
+        inp = torch.full((1, 1, 1, 1), -1.0, device="cuda")
+        exp_program = torch.export.export(model, (inp,), strict=False)
+        trt_mod = torch_tensorrt.dynamo.compile(
+            exp_program,
+            inputs=[torch_tensorrt.Input(shape=(1, 1, 1, 1), dtype=torch.float32)],
+            device=torch_tensorrt.Device("cuda:0"),
+            enabled_precisions={torch.float32},
+            min_block_size=1,
+            pass_through_build_failures=True,
+            cache_built_engines=False,
+            reuse_cached_engines=False,
+        )
+
+        pyt_out = model(inp)
+        trt_out = trt_mod(inp)
+
+        nan_count = torch.isnan(trt_out).sum().item()
+        self.assertEqual(nan_count, 0, "TRT output contains NaN")
+        torch.testing.assert_close(trt_out, pyt_out, atol=1e-3, rtol=1e-3)
+
+
 if __name__ == "__main__":
     run_tests()
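Working through the test's numbers (assuming BatchNorm2d's default eval-mode statistics of running_mean=0, running_var=1, weight=1, bias=0, eps=1e-5, which the test never overrides): conv1's all-ones 3x3 kernel with padding 1 overlaps a 1x1 input only at its center tap, so each convolution acts as the identity here. A quick reference computation of the expected value:

import torch

x = torch.full((1, 1, 1, 1), -1.0)   # the test input
# conv1 and conv2 are identity ops on a 1x1 input (single-tap, all-ones kernel),
# so only the eval-mode BN with default statistics changes the value
expected = x / torch.sqrt(torch.tensor(1.0 + 1e-5))
print(expected.item())               # ~ -0.999995, finite; a NaN here isolates the power-term bug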
