Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,19 +2266,25 @@ def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...]
class ReplacePowWithMulPass(RemoveOrReplacePassInterface):
"""
Replace the pow op with successive mul ops when the exponent is an
integer between 2 and 4 (inclusive).
integer between 2 and 4 (inclusive). Float exponents that are whole
numbers (e.g., 2.0, 3.0, 4.0) are also accepted.
"""

@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.pow.Tensor_Scalar]

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# Check if we have at least 2 args and the exponent is an int
if len(node.args) < 2 or not isinstance(node.args[1], int):
# Check if we have at least 2 args and the exponent is an int or float
if len(node.args) < 2 or not isinstance(node.args[1], (int, float)):
return False

exponent = cast(int, node.args[1])
exponent_val = node.args[1]
if isinstance(exponent_val, float):
if not exponent_val.is_integer():
return False
exponent_val = int(exponent_val)
exponent = cast(int, exponent_val)

# Only replace if exponent is between 2 and 4 (inclusive)
if exponent < 2 or exponent > 4:
Expand Down
8 changes: 5 additions & 3 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,8 +1926,8 @@ def test_replace_split_with_sizes_with_slice(self) -> None:
2,
)

@expand([[2], [3], [4]])
def test_replace_pow_with_mul(self, exponent: int) -> None:
@expand([[2], [3], [4], [2.0], [3.0], [4.0]])
def test_replace_pow_with_mul(self, exponent: int | float) -> None:
x_input = torch.randn(2, 1, 64)
x = x_input
original_gm = single_op_builder(
Expand Down Expand Up @@ -1956,13 +1956,15 @@ def test_replace_pow_with_mul(self, exponent: int) -> None:
graph_after_passes,
exir_ops.edge.aten.mul.Tensor,
),
exponent - 1,
int(exponent) - 1,
)

@expand(
[
[1],
[1.5],
[5.0],
[0.5],
]
)
def test_replace_pow_with_mul_not_applied(self, exponent: float) -> None:
Expand Down
Loading