diff --git a/src/kirin/dialects/scf/constprop.py b/src/kirin/dialects/scf/constprop.py index 25181769c..11303718c 100644 --- a/src/kirin/dialects/scf/constprop.py +++ b/src/kirin/dialects/scf/constprop.py @@ -108,6 +108,12 @@ def _prop_const_iterable_forloop( loop_vars = frame.get_values(stmt.initializers) + # Only safe to break early when the body doesn't use the iteration + # variable — otherwise later iterations may take different code paths. + iter_var = stmt.body.blocks[0].args[0] + can_early_terminate = not iter_var.uses + + prev_loop_vars = None for value in iterable.data: with interp_.new_frame(stmt, has_parent_access=True) as body_frame: loop_vars = interp_.frame_call_region( @@ -121,6 +127,14 @@ def _prop_const_iterable_forloop( elif isinstance(loop_vars, interp.ReturnValue): return loop_vars + if ( + can_early_terminate + and prev_loop_vars is not None + and loop_vars == prev_loop_vars + ): + break + prev_loop_vars = loop_vars + if not frame_is_not_pure: frame.should_be_pure.add(stmt) return loop_vars diff --git a/test/dialects/scf/test_constprop.py b/test/dialects/scf/test_constprop.py index 30d3535b3..875ee5556 100644 --- a/test/dialects/scf/test_constprop.py +++ b/test/dialects/scf/test_constprop.py @@ -1,11 +1,22 @@ from pytest import mark +from kirin import ir, lowering +from kirin.decl import statement from kirin.prelude import structural_no_opt from kirin.analysis import const from kirin.dialects import scf, func prop = const.Propagate(structural_no_opt) +# A statement with no Pure/MaybePure trait — acts as a side effect. +_impure_dialect = ir.Dialect("test_impure") + + +@statement(dialect=_impure_dialect) +class ImpureOp(ir.Statement): + name = "impure_op" + traits = frozenset({lowering.FromPythonCall()}) + def test_simple_loop(): @structural_no_opt @@ -103,3 +114,49 @@ def simple_ifelse(x: int): assert isinstance(terminator, func.Return) assert isinstance(value := frame.entries[terminator.value], const.Value) assert value.data == 0 + + +def test_early_termination_when_body_ignores_iter_var(): + """When the body doesn't reference the iteration variable and loop_vars + converge (x is Unknown, so Unknown + 1 = Unknown), early termination + should fire and produce the same result as running all iterations.""" + + @structural_no_opt + def converging_loop(x: int) -> int: + for _i in range(100): + x = x + 1 + return x + + constprop = const.Propagate(structural_no_opt) + frame, ret = constprop.run(converging_loop) + + assert isinstance(ret, const.Unknown) + [for_stmt] = [s for s in converging_loop.code.walk() if isinstance(s, scf.For)] + assert for_stmt in frame.should_be_pure + + +def test_no_early_termination_when_body_uses_iter_var(): + """Early termination must not fire when the body references the iteration + variable, because later iterations may follow different code paths that + affect purity. Here the impure ``ImpureOp`` is guarded by ``i == 50``, + so the loop body is impure only on iteration 50. If early termination + incorrectly broke after iteration 1 (where loop_vars converge), the + for-loop would be marked as pure when it is not.""" + + _group = structural_no_opt.add(_impure_dialect) + + @_group + def impure_on_later_iter(x: int) -> int: + for i in range(100): + if i == 50: + ImpureOp() + x = x + 1 + return x + + constprop = const.Propagate(_group) + frame, ret = constprop.run(impure_on_later_iter) + + [for_stmt] = [s for s in impure_on_later_iter.code.walk() if isinstance(s, scf.For)] + # The for-loop must NOT be in should_be_pure — it contains a + # conditionally-impure operation on a later iteration. + assert for_stmt not in frame.should_be_pure