From 55e1ed1b001d8f47738c7369b69def4671f016bd Mon Sep 17 00:00:00 2001 From: Hamidreza Khazaei Date: Tue, 14 Apr 2026 23:20:38 -0700 Subject: [PATCH] Support whole-number float exponents in ReplacePowWithMulPass (#18851) Summary: Extend `ReplacePowWithMulPass` to accept float exponents that are whole numbers (e.g., 2.0, 3.0, 4.0) in addition to integer exponents. Previously the pass only matched `int` typed exponents, causing it to miss valid optimization opportunities when the exponent was a float with no fractional part. Changes: - Broaden the type check from `isinstance(_, int)` to `isinstance(_, (int, float))` - Add a guard to reject non-whole-number floats (e.g., 1.5, 0.5) - Convert validated float exponents to `int` before proceeding - Update docstring to document the new behavior Reviewed By: hsharma35, mcremon-meta Differential Revision: D100695654 --- backends/cadence/aot/replace_ops.py | 14 ++++++++++---- .../cadence/aot/tests/test_replace_ops_passes.py | 8 +++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 6e2f85fab0f..c88b1fad56f 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2266,7 +2266,8 @@ def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...] class ReplacePowWithMulPass(RemoveOrReplacePassInterface): """ Replace the pow op with successive mul ops when the exponent is an - integer between 2 and 4 (inclusive). + integer between 2 and 4 (inclusive). Float exponents that are whole + numbers (e.g., 2.0, 3.0, 4.0) are also accepted. """ @property @@ -2274,11 +2275,16 @@ def targets(self) -> list[EdgeOpOverload]: return [exir_ops.edge.aten.pow.Tensor_Scalar] def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Check if we have at least 2 args and the exponent is an int - if len(node.args) < 2 or not isinstance(node.args[1], int): + # Check if we have at least 2 args and the exponent is an int or float + if len(node.args) < 2 or not isinstance(node.args[1], (int, float)): return False - exponent = cast(int, node.args[1]) + exponent_val = node.args[1] + if isinstance(exponent_val, float): + if not exponent_val.is_integer(): + return False + exponent_val = int(exponent_val) + exponent = cast(int, exponent_val) # Only replace if exponent is between 2 and 4 (inclusive) if exponent < 2 or exponent > 4: diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index ee13726a94b..56a17b73f88 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1926,8 +1926,8 @@ def test_replace_split_with_sizes_with_slice(self) -> None: 2, ) - @expand([[2], [3], [4]]) - def test_replace_pow_with_mul(self, exponent: int) -> None: + @expand([[2], [3], [4], [2.0], [3.0], [4.0]]) + def test_replace_pow_with_mul(self, exponent: int | float) -> None: x_input = torch.randn(2, 1, 64) x = x_input original_gm = single_op_builder( @@ -1956,13 +1956,15 @@ def test_replace_pow_with_mul(self, exponent: int) -> None: graph_after_passes, exir_ops.edge.aten.mul.Tensor, ), - exponent - 1, + int(exponent) - 1, ) @expand( [ [1], [1.5], + [5.0], + [0.5], ] ) def test_replace_pow_with_mul_not_applied(self, exponent: float) -> None: