Skip to content

Commit 9c09f21

Browse files
authored
Implementing For Loop Unrolling for PartialTuple. (#596)
In this PR I extend the current `unroll` rule in `scf` to unroll the for loop when the container is not fully constant. The implementation is very simple by just extending the logic to insert a `GetItem` for every iteration of the loop instead of inserting the constant value.
1 parent c4e5454 commit 9c09f21

2 files changed

Lines changed: 51 additions & 6 deletions

File tree

src/kirin/dialects/scf/unroll.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from kirin.dialects import func
44
from kirin.rewrite.abc import RewriteRule, RewriteResult
55
from kirin.dialects.py.constant import Constant
6+
from kirin.dialects.py.indexing import GetItem
67

78
from .stmts import For, Yield, IfElse
89

@@ -54,21 +55,35 @@ def insert_body(self, node: IfElse, body: ir.Region):
5455

5556
class ForLoop(RewriteRule):
5657

58+
def yield_item_results_const(self, node: For, hint: const.Value):
59+
for item in hint.data:
60+
item_stmt = Constant(item)
61+
item_stmt.insert_before(node)
62+
yield item_stmt.result
63+
64+
def yield_item_results_from_len(self, node: For, len_iterable: int):
65+
for i in range(len_iterable):
66+
(index_stmt := Constant(i)).insert_before(node)
67+
(item_stmt := GetItem(node.iterable, index_stmt.result)).insert_before(node)
68+
yield item_stmt.result
69+
5770
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
5871
if not isinstance(node, For):
5972
return RewriteResult()
6073

61-
# TODO: support for PartialTuple and IList with known length
62-
if not isinstance(hint := node.iterable.hints.get("const"), const.Value):
74+
if isinstance(hint := node.iterable.hints.get("const"), const.Value):
75+
item_results = self.yield_item_results_const(node, hint)
76+
elif isinstance(hint, const.PartialTuple):
77+
item_results = self.yield_item_results_from_len(node, len(hint.data))
78+
else:
6379
return RewriteResult()
6480

6581
loop_vars = node.initializers
66-
for item in hint.data:
82+
83+
for item_result in item_results:
6784
body = node.body.clone()
6885
block = body.blocks[0]
69-
item_stmt = Constant(item)
70-
item_stmt.insert_before(node)
71-
block.args[0].replace_by(item_stmt.result)
86+
block.args[0].replace_by(item_result)
7287
for var, input in zip(block.args[1:], loop_vars):
7388
var.replace_by(input)
7489

test/dialects/scf/test_unroll.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,33 @@ def simple_loop(x):
2525
assert isinstance(stmts.at(6), py.Add)
2626
assert isinstance(stmts.at(7), func.Return)
2727
assert simple_loop(1) == 4
28+
29+
30+
def test_partial_tuple_loop_unroll():
31+
@structural_no_opt
32+
def simple_loop(a: int, b: int, c: int):
33+
x = 0
34+
for i in (a, b, c):
35+
x = x + i
36+
return x
37+
38+
fold = Fold(structural_no_opt)
39+
fold(simple_loop)
40+
Walk(scf.unroll.ForLoop()).rewrite(simple_loop.code)
41+
fold(simple_loop)
42+
43+
# after fold the `getitem` should be eliminated as well
44+
# leaving just the block arguments being added directly
45+
# to `x`
46+
assert len(simple_loop.callable_region.blocks) == 1
47+
block = simple_loop.callable_region.blocks[0]
48+
args = block.args
49+
stmts = block.stmts
50+
assert isinstance(stmts.at(0), py.Constant)
51+
assert isinstance(stmt := stmts.at(1), py.Add)
52+
assert stmt.rhs is args[1]
53+
assert isinstance(stmt := stmts.at(2), py.Add)
54+
assert stmt.rhs is args[2]
55+
assert isinstance(stmt := stmts.at(3), py.Add)
56+
assert stmt.rhs is args[3]
57+
assert isinstance(stmts.at(4), func.Return)

0 commit comments

Comments
 (0)