Skip to content
This repository was archived by the owner on Mar 4, 2026. It is now read-only.

Commit ba2e6fe

Browse files
authored
small fixes (#172)
1 parent d020d8e commit ba2e6fe

3 files changed

Lines changed: 8 additions & 7 deletions

File tree

mlir/extras/dialects/ext/arith.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def index_cast(
132132
) -> Value:
133133
if loc is None:
134134
loc = get_user_code_loc()
135-
assert bool(to) != bool(out), "either `to` or `out` must be set"
135+
assert bool(to) != bool(out), "either `to` or `out` must be set but not both"
136136
res_type = out or to
137137
if res_type is None:
138138
res_type = IndexType.get()

mlir/extras/dialects/ext/scf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,10 @@ def __init__(self, operands, num_reductions, *, loc=None, ip=None):
288288
self.regions[i].blocks.append(operands[i].type, operands[i].type)
289289

290290

291-
def reduce_(*operands, num_reductions=1):
292-
loc = get_user_code_loc()
293-
return ReduceOp(operands, num_reductions, loc=loc)
291+
def reduce_(*operands, num_reductions=1, loc=None, ip=None):
292+
if loc is None:
293+
loc = get_user_code_loc()
294+
return ReduceOp(operands, num_reductions, loc=loc, ip=ip)
294295

295296

296297
reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))

tests/test_scf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,11 +2400,11 @@ def res1(lhs: T.index(), rhs: T.index()):
24002400
return lhs + rhs
24012401

24022402
@another_reduce(res1)
2403-
def res1(lhs: T.index(), rhs: T.index()):
2403+
def res2(lhs: T.index(), rhs: T.index()):
24042404
return lhs + rhs
24052405

2406-
@another_reduce(res1)
2407-
def res2(lhs: T.index(), rhs: T.index()):
2406+
@another_reduce(res2)
2407+
def res3(lhs: T.index(), rhs: T.index()):
24082408
return lhs + rhs
24092409

24102410
ctx.module.operation.verify()

0 commit comments

Comments
 (0)