Skip to content

Commit 257fae2

Browse files
committed
feat: Complex operations which are not supported will now fallback to pytorch rather than fail to build
1 parent 9656bec commit 257fae2

9 files changed

Lines changed: 370 additions & 7 deletions

File tree

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ jobs:
107107
set -euo pipefail
108108
pushd .
109109
cd tests/py/dynamo
110-
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/
110+
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/
111111
popd
112112
113113
L0-dynamo-core-tests:
@@ -141,6 +141,7 @@ jobs:
141141
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
142142
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
143143
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
144+
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
144145
popd
145146
146147
L0-py-core-tests:

.github/workflows/build-test-linux-x86_64_rtx.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ jobs:
107107
set -euo pipefail
108108
pushd .
109109
cd tests/py/dynamo
110-
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/
110+
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/
111111
popd
112112
113113
L0-dynamo-core-tests:
@@ -142,6 +142,7 @@ jobs:
142142
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
143143
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
144144
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
145+
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
145146
popd
146147
147148
L0-py-core-tests:

.github/workflows/build-test-windows.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ jobs:
140140
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
141141
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
142142
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
143+
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
143144
popd
144145
145146
L0-py-core-tests:

.github/workflows/build-test-windows_rtx.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ jobs:
144144
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
145145
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
146146
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
147+
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/
147148
popd
148149
149150
L0-py-core-tests:

py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import math
3+
import operator
34
from typing import Callable, List, Optional, Tuple
45

56
import torch
@@ -57,6 +58,16 @@
5758
torch.ops.aten.floor.default,
5859
torch.ops.aten.round.default,
5960
torch.ops.aten.trunc.default,
61+
# Structural list indexing — extracts one element from a split/chunk output.
62+
# The element is still in real [..., 2] complex layout; the flag is already
63+
# set by the pre-rewrite annotation loop. No view_as_complex wrapping needed.
64+
operator.getitem,
65+
# Shape queries — sym_size.int reads a tensor's dimension value, which is not
66+
# affected by the complex [..., 2] layout. Without this entry the fallback
67+
# wrapper inserts view_as_complex before the sym_size node, causing the shape
68+
# to be computed from a complex tensor in the PyTorch fallback and returning
69+
# a raw SymInt backing value (garbage) to TRT for reshape dims.
70+
torch.ops.aten.sym_size.int,
6071
}
6172
)
6273

@@ -467,7 +478,14 @@ def _inline_complex_sqrt(
467478

468479
@_complex_unpacker(torch.ops.aten.view_as_complex.default)
469480
def _rewrite_view_as_complex(self, node: Node) -> bool:
470-
node.replace_all_uses_with(node.args[0])
481+
inp = node.args[0]
482+
# The input to view_as_complex is a (..., 2) real-layout tensor that
483+
# represents a complex tensor. After erasing view_as_complex, downstream
484+
# consumers (e.g. mul.Tensor) need to know that this node is in complex
485+
# layout so the correct rewrite branch is chosen.
486+
if isinstance(inp, torch.fx.Node):
487+
inp.meta["is_complex_layout"] = True
488+
node.replace_all_uses_with(inp)
471489
self.gm.graph.erase_node(node)
472490
# Return True so the caller triggers propagate_metadata + gm.recompile().
473491
# Without recompile the compiled forward still calls the erased node.
@@ -1727,11 +1745,58 @@ def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None:
17271745
else:
17281746
logger.warning(
17291747
"Complex op '%s' has no explicit rewrite rule. "
1730-
"It will be passed through as-is on the real [..., 2] layout, "
1731-
"which may produce incorrect results or fail TRT compilation. "
1732-
"Consider adding a rewrite in complex_graph_rewrite.py.",
1748+
"Wrapping with view_as_complex/view_as_real so the op "
1749+
"receives genuine complex tensors and TRT graph-breaks "
1750+
"around it into a PyTorch fallback block.",
17331751
node.target,
17341752
)
1753+
# Generic fallback: for each arg that is a real-layout
1754+
# complex node, insert view_as_complex before the node so
1755+
# the op sees genuine complex-dtype tensors (correct
1756+
# semantics); then, if the node itself originally produced
1757+
# a complex-layout output, wrap it with view_as_real and
1758+
# redirect all users back onto the real [..., 2] path.
1759+
# TRT has no complex-dtype support so it will refuse to
1760+
# compile the view_as_complex/op/view_as_real cluster,
1761+
# causing the partitioner to create a PyTorch fallback
1762+
# block around it — exactly the graph break we want.
1763+
new_args = list(node.args)
1764+
any_complexified = False
1765+
for i, arg in enumerate(node.args):
1766+
if not isinstance(arg, torch.fx.Node):
1767+
continue
1768+
if not arg.meta.get("is_complex_layout", False):
1769+
continue
1770+
# Skip when val is a list/tuple (e.g. a residual split
1771+
# output that wasn't caught by the getitem pass-through).
1772+
# Allow None (newly created node without metadata yet).
1773+
arg_val = arg.meta.get("val")
1774+
if isinstance(arg_val, (list, tuple)):
1775+
continue
1776+
with self.gm.graph.inserting_before(node):
1777+
vc = self.gm.graph.call_function(
1778+
torch.ops.aten.view_as_complex.default,
1779+
(arg,),
1780+
)
1781+
# view_as_complex produces a genuine complex node —
1782+
# do NOT set is_complex_layout; it is not a
1783+
# real-layout stand-in.
1784+
new_args[i] = vc
1785+
any_complexified = True
1786+
if any_complexified:
1787+
node.args = tuple(new_args)
1788+
if any_complexified and node.meta.get("is_complex_layout", False):
1789+
with self.gm.graph.inserting_after(node):
1790+
vr = self.gm.graph.call_function(
1791+
torch.ops.aten.view_as_real.default,
1792+
(node,),
1793+
)
1794+
vr.meta["is_complex_layout"] = True
1795+
node.replace_all_uses_with(
1796+
vr,
1797+
delete_user_cb=lambda user: user is not vr,
1798+
)
1799+
modified = True
17351800
if modified:
17361801
# After rewriting complex ops, any view_as_real node that now receives a
17371802
# real tensor must be erased. The subgraph_rewriter replaces the original

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
2323
ConverterRegistry,
2424
)
25+
from torch_tensorrt.dynamo.partitioning._global_partitioner import (
26+
TorchTensorRTOperatorSupport,
27+
)
2528

2629
logger = logging.getLogger(__name__)
2730

@@ -42,6 +45,14 @@ def is_node_supported(
4245
) -> bool:
4346
node_name = ConverterRegistry.qualified_name_or_str(node.target)
4447

48+
if TorchTensorRTOperatorSupport._has_complex_dtype(node):
49+
# Complex-dtype tensors are not supported by TensorRT; force PyTorch fallback
50+
if not node.is_impure():
51+
self.unsupported_operators[node_name] = (
52+
self.unsupported_operators.get(node_name, 0) + 1
53+
)
54+
return False
55+
4556
if (
4657
(node in CONVERTERS or node.op == "get_attr")
4758
and node_name not in self.torch_executed_ops

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple
2+
from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple, Set
33

44
import torch
55
from torch.fx.graph_module import GraphModule
@@ -144,11 +144,41 @@ def __init__(
144144
self.unsupported_operators: Dict[str, int] = {}
145145
self.torch_executed_ops: Collection[Target] = torch_executed_ops
146146

147+
@staticmethod
148+
def _has_complex_dtype(node: torch.fx.Node) -> bool:
149+
"""Return True if the node output or any of its tensor inputs is complex-dtype.
150+
151+
TensorRT has no native complex-type support. Any node that produces or
152+
consumes a complex tensor must run in the PyTorch fallback so the graph
153+
breaks naturally around it.
154+
"""
155+
_COMPLEX = {torch.complex64, torch.complex128}
156+
157+
def _dtype(n: torch.fx.Node) -> Optional[torch.dtype]:
158+
val = n.meta.get("val")
159+
return getattr(val, "dtype", None) if val is not None else None
160+
161+
if _dtype(node) in _COMPLEX:
162+
return True
163+
for arg in node.all_input_nodes:
164+
if _dtype(arg) in _COMPLEX:
165+
return True
166+
return False
167+
147168
def is_node_supported(
148169
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
149170
) -> bool:
150171
node_name = ConverterRegistry.qualified_name_or_str(node.target)
151172

173+
if self._has_complex_dtype(node):
174+
# Complex-dtype tensors are not supported by TensorRT; force PyTorch fallback
175+
# so the graph breaks around the complex cluster inserted by complex_graph_detection.
176+
if not node.is_impure():
177+
self.unsupported_operators[node_name] = (
178+
self.unsupported_operators.get(node_name, 0) + 1
179+
)
180+
return False
181+
152182
if (
153183
(node in CONVERTERS or node.op == "get_attr")
154184
and node_name not in self.torch_executed_ops

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ lint = [
6060

6161
dev = [
6262
{include-group = "lint"},
63+
{include-group = "test"},
6364
"pre-commit>=2.20.0",
6465
"typos",
6566
"mypy",
@@ -77,6 +78,7 @@ debug = [
7778
test = [
7879
"pytest",
7980
"pytest-xdist",
81+
"pytest-forked>=1.6.0",
8082
"parameterized>=0.2.0",
8183
"expecttest==0.1.6",
8284
]
@@ -114,6 +116,7 @@ include-package-data = false
114116

115117
[tool.pytest.ini_options]
116118
testpaths = ["tests/py"]
119+
addopts = "-n auto --dist=loadfile"
117120
norecursedirs = [
118121
"bazel-*",
119122
".venv",

0 commit comments

Comments
 (0)