diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 6e2f85fab0f..c88b1fad56f 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2266,7 +2266,8 @@ def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...] class ReplacePowWithMulPass(RemoveOrReplacePassInterface): """ Replace the pow op with successive mul ops when the exponent is an - integer between 2 and 4 (inclusive). + integer between 2 and 4 (inclusive). Float exponents that are whole + numbers (e.g., 2.0, 3.0, 4.0) are also accepted. """ @property @@ -2274,11 +2275,16 @@ def targets(self) -> list[EdgeOpOverload]: return [exir_ops.edge.aten.pow.Tensor_Scalar] def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Check if we have at least 2 args and the exponent is an int - if len(node.args) < 2 or not isinstance(node.args[1], int): + # Check if we have at least 2 args and the exponent is an int or float + if len(node.args) < 2 or not isinstance(node.args[1], (int, float)): return False - exponent = cast(int, node.args[1]) + exponent_val = node.args[1] + if isinstance(exponent_val, float): + if not exponent_val.is_integer(): + return False + exponent_val = int(exponent_val) + exponent = cast(int, exponent_val) # Only replace if exponent is between 2 and 4 (inclusive) if exponent < 2 or exponent > 4: diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index ee13726a94b..56a17b73f88 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1926,8 +1926,8 @@ def test_replace_split_with_sizes_with_slice(self) -> None: 2, ) - @expand([[2], [3], [4]]) - def test_replace_pow_with_mul(self, exponent: int) -> None: + @expand([[2], [3], [4], [2.0], [3.0], [4.0]]) + def test_replace_pow_with_mul(self, exponent: int | float) -> None: x_input = torch.randn(2, 1, 64) x = x_input original_gm = single_op_builder( @@ -1956,13 +1956,15 @@ def test_replace_pow_with_mul(self, exponent: int) -> None: graph_after_passes, exir_ops.edge.aten.mul.Tensor, ), - exponent - 1, + int(exponent) - 1, ) @expand( [ [1], [1.5], + [5.0], + [0.5], ] ) def test_replace_pow_with_mul_not_applied(self, exponent: float) -> None: