|
1 | 1 | import logging |
2 | 2 | import math |
| 3 | +import operator |
3 | 4 | from typing import Callable, List, Optional, Tuple |
4 | 5 |
|
5 | 6 | import torch |
|
57 | 58 | torch.ops.aten.floor.default, |
58 | 59 | torch.ops.aten.round.default, |
59 | 60 | torch.ops.aten.trunc.default, |
| 61 | + # Structural list indexing — extracts one element from a split/chunk output. |
| 62 | + # The element is still in real [..., 2] complex layout; the flag is already |
| 63 | + # set by the pre-rewrite annotation loop. No view_as_complex wrapping needed. |
| 64 | + operator.getitem, |
| 65 | + # Shape queries — sym_size.int reads a tensor's dimension value, which is not |
| 66 | + # affected by the complex [..., 2] layout. Without this entry the fallback |
| 67 | + # wrapper inserts view_as_complex before the sym_size node, causing the shape |
| 68 | + # to be computed from a complex tensor in the PyTorch fallback and returning |
| 69 | + # a raw SymInt backing value (garbage) to TRT for reshape dims. |
| 70 | + torch.ops.aten.sym_size.int, |
60 | 71 | } |
61 | 72 | ) |
62 | 73 |
|
@@ -467,7 +478,14 @@ def _inline_complex_sqrt( |
467 | 478 |
|
468 | 479 | @_complex_unpacker(torch.ops.aten.view_as_complex.default) |
469 | 480 | def _rewrite_view_as_complex(self, node: Node) -> bool: |
470 | | - node.replace_all_uses_with(node.args[0]) |
| 481 | + inp = node.args[0] |
| 482 | + # The input to view_as_complex is a (..., 2) real-layout tensor that |
| 483 | + # represents a complex tensor. After erasing view_as_complex, downstream |
| 484 | + # consumers (e.g. mul.Tensor) need to know that this node is in complex |
| 485 | + # layout so the correct rewrite branch is chosen. |
| 486 | + if isinstance(inp, torch.fx.Node): |
| 487 | + inp.meta["is_complex_layout"] = True |
| 488 | + node.replace_all_uses_with(inp) |
471 | 489 | self.gm.graph.erase_node(node) |
472 | 490 | # Return True so the caller triggers propagate_metadata + gm.recompile(). |
473 | 491 | # Without recompile the compiled forward still calls the erased node. |
@@ -1727,11 +1745,58 @@ def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: |
1727 | 1745 | else: |
1728 | 1746 | logger.warning( |
1729 | 1747 | "Complex op '%s' has no explicit rewrite rule. " |
1730 | | - "It will be passed through as-is on the real [..., 2] layout, " |
1731 | | - "which may produce incorrect results or fail TRT compilation. " |
1732 | | - "Consider adding a rewrite in complex_graph_rewrite.py.", |
| 1748 | + "Wrapping with view_as_complex/view_as_real so the op " |
| 1749 | + "receives genuine complex tensors and TRT graph-breaks " |
| 1750 | + "around it into a PyTorch fallback block.", |
1733 | 1751 | node.target, |
1734 | 1752 | ) |
| 1753 | + # Generic fallback: for each arg that is a real-layout |
| 1754 | + # complex node, insert view_as_complex before the node so |
| 1755 | + # the op sees genuine complex-dtype tensors (correct |
| 1756 | + # semantics); then, if the node itself originally produced |
| 1757 | + # a complex-layout output, wrap it with view_as_real and |
| 1758 | + # redirect all users back onto the real [..., 2] path. |
| 1759 | + # TRT has no complex-dtype support so it will refuse to |
| 1760 | + # compile the view_as_complex/op/view_as_real cluster, |
| 1761 | + # causing the partitioner to create a PyTorch fallback |
| 1762 | + # block around it — exactly the graph break we want. |
| 1763 | + new_args = list(node.args) |
| 1764 | + any_complexified = False |
| 1765 | + for i, arg in enumerate(node.args): |
| 1766 | + if not isinstance(arg, torch.fx.Node): |
| 1767 | + continue |
| 1768 | + if not arg.meta.get("is_complex_layout", False): |
| 1769 | + continue |
| 1770 | + # Skip when val is a list/tuple (e.g. a residual split |
| 1771 | + # output that wasn't caught by the getitem pass-through). |
| 1772 | + # Allow None (newly created node without metadata yet). |
| 1773 | + arg_val = arg.meta.get("val") |
| 1774 | + if isinstance(arg_val, (list, tuple)): |
| 1775 | + continue |
| 1776 | + with self.gm.graph.inserting_before(node): |
| 1777 | + vc = self.gm.graph.call_function( |
| 1778 | + torch.ops.aten.view_as_complex.default, |
| 1779 | + (arg,), |
| 1780 | + ) |
| 1781 | + # view_as_complex produces a genuine complex node — |
| 1782 | + # do NOT set is_complex_layout; it is not a |
| 1783 | + # real-layout stand-in. |
| 1784 | + new_args[i] = vc |
| 1785 | + any_complexified = True |
| 1786 | + if any_complexified: |
| 1787 | + node.args = tuple(new_args) |
| 1788 | + if any_complexified and node.meta.get("is_complex_layout", False): |
| 1789 | + with self.gm.graph.inserting_after(node): |
| 1790 | + vr = self.gm.graph.call_function( |
| 1791 | + torch.ops.aten.view_as_real.default, |
| 1792 | + (node,), |
| 1793 | + ) |
| 1794 | + vr.meta["is_complex_layout"] = True |
| 1795 | + node.replace_all_uses_with( |
| 1796 | + vr, |
| 1797 | + delete_user_cb=lambda user: user is not vr, |
| 1798 | + ) |
| 1799 | + modified = True |
1735 | 1800 | if modified: |
1736 | 1801 | # After rewriting complex ops, any view_as_real node that now receives a |
1737 | 1802 | # real tensor must be erased. The subgraph_rewriter replaces the original |
|
0 commit comments