Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/kirin/dialects/scf/constprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def _prop_const_iterable_forloop(

loop_vars = frame.get_values(stmt.initializers)

# Only safe to break early when the body doesn't use the iteration
# variable — otherwise later iterations may take different code paths.
iter_var = stmt.body.blocks[0].args[0]
can_early_terminate = not iter_var.uses

prev_loop_vars = None
for value in iterable.data:
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
loop_vars = interp_.frame_call_region(
Expand All @@ -121,6 +127,14 @@ def _prop_const_iterable_forloop(
elif isinstance(loop_vars, interp.ReturnValue):
return loop_vars

if (
can_early_terminate
and prev_loop_vars is not None
and loop_vars == prev_loop_vars
):
break
prev_loop_vars = loop_vars

if not frame_is_not_pure:
frame.should_be_pure.add(stmt)
return loop_vars
57 changes: 57 additions & 0 deletions test/dialects/scf/test_constprop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from pytest import mark

from kirin import ir, lowering
from kirin.decl import statement
from kirin.prelude import structural_no_opt
from kirin.analysis import const
from kirin.dialects import scf, func

prop = const.Propagate(structural_no_opt)

# A statement with no Pure/MaybePure trait — acts as a side effect.
_impure_dialect = ir.Dialect("test_impure")


@statement(dialect=_impure_dialect)
class ImpureOp(ir.Statement):
name = "impure_op"
traits = frozenset({lowering.FromPythonCall()})


def test_simple_loop():
@structural_no_opt
Expand Down Expand Up @@ -103,3 +114,49 @@ def simple_ifelse(x: int):
assert isinstance(terminator, func.Return)
assert isinstance(value := frame.entries[terminator.value], const.Value)
assert value.data == 0


def test_early_termination_when_body_ignores_iter_var():
"""When the body doesn't reference the iteration variable and loop_vars
converge (x is Unknown, so Unknown + 1 = Unknown), early termination
should fire and produce the same result as running all iterations."""

@structural_no_opt
def converging_loop(x: int) -> int:
for _i in range(100):
x = x + 1
return x

constprop = const.Propagate(structural_no_opt)
frame, ret = constprop.run(converging_loop)

assert isinstance(ret, const.Unknown)
[for_stmt] = [s for s in converging_loop.code.walk() if isinstance(s, scf.For)]
assert for_stmt in frame.should_be_pure


def test_no_early_termination_when_body_uses_iter_var():
"""Early termination must not fire when the body references the iteration
variable, because later iterations may follow different code paths that
affect purity. Here the impure ``ImpureOp`` is guarded by ``i == 50``,
so the loop body is impure only on iteration 50. If early termination
incorrectly broke after iteration 1 (where loop_vars converge), the
for-loop would be marked as pure when it is not."""

_group = structural_no_opt.add(_impure_dialect)

@_group
def impure_on_later_iter(x: int) -> int:
for i in range(100):
if i == 50:
ImpureOp()
x = x + 1
return x

constprop = const.Propagate(_group)
frame, ret = constprop.run(impure_on_later_iter)

[for_stmt] = [s for s in impure_on_later_iter.code.walk() if isinstance(s, scf.For)]
# The for-loop must NOT be in should_be_pure — it contains a
# conditionally-impure operation on a later iteration.
assert for_stmt not in frame.should_be_pure
Loading