Skip to content

Commit cdb3bb0

Browse files
Merge pull request #330 from zhangstevenunity/codex/kernel-kind-late-emitc
fix: preserve kernel_kind until EmitC lowering
2 parents ff117ec + 2b38303 commit cdb3bb0

4 files changed

Lines changed: 122 additions & 5 deletions

File tree

lib/PTO/Transforms/PTOToEmitC.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2441,6 +2441,22 @@ static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) {
24412441
}
24422442
}
24432443

2444+
static std::optional<StringRef> getKernelKindMacro(func::FuncOp funcOp) {
2445+
auto kernelKindAttr =
2446+
funcOp->getAttrOfType<FunctionKernelKindAttr>(FunctionKernelKindAttr::name);
2447+
if (!kernelKindAttr)
2448+
return std::nullopt;
2449+
2450+
switch (kernelKindAttr.getKernelKind()) {
2451+
case FunctionKernelKind::Cube:
2452+
return StringRef("__DAV_CUBE__");
2453+
case FunctionKernelKind::Vector:
2454+
return StringRef("__DAV_VEC__");
2455+
}
2456+
2457+
llvm_unreachable("unexpected kernel kind");
2458+
}
2459+
24442460
struct FuncToEmitC : public OpConversionPattern<func::FuncOp> {
24452461
using OpConversionPattern<func::FuncOp>::OpConversionPattern;
24462462

@@ -2500,11 +2516,31 @@ struct FuncToEmitC : public OpConversionPattern<func::FuncOp> {
25002516
*getTypeConverter(), &entryConv)))
25012517
return failure();
25022518

2503-
// [Compatibility patch] Preserve existing snippets that rely on `T`.
2519+
std::optional<StringRef> kernelKindMacro = getKernelKindMacro(op);
2520+
2521+
// Preserve the existing function prologue shape. `kernel_kind` functions are
2522+
// emitted with the same macro guard/reset sequence that used to come from
2523+
// early pto.section wrapping, but only after SCF pre-lowering has finished.
25042524
{
25052525
Block &entryBlock = emitcFunc.getBody().front();
25062526
rewriter.setInsertionPointToStart(&entryBlock);
25072527
rewriter.create<emitc::VerbatimOp>(op.getLoc(), "using T = float;");
2528+
if (kernelKindMacro) {
2529+
std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")";
2530+
rewriter.create<emitc::VerbatimOp>(op.getLoc(), startMacro);
2531+
if (*kernelKindMacro == "__DAV_VEC__") {
2532+
rewriter.create<emitc::VerbatimOp>(op.getLoc(), "set_mask_norm();");
2533+
rewriter.create<emitc::VerbatimOp>(op.getLoc(),
2534+
"set_vector_mask(-1, -1);");
2535+
}
2536+
}
2537+
}
2538+
2539+
if (kernelKindMacro) {
2540+
Block &lastBlock = emitcFunc.getBody().back();
2541+
rewriter.setInsertionPoint(lastBlock.getTerminator());
2542+
std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n";
2543+
rewriter.create<emitc::VerbatimOp>(op.getLoc(), endMacro);
25082544
}
25092545

25102546
rewriter.eraseOp(op);

lib/PTO/Transforms/PTOVerifyTFreePass.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ static LogicalResult verifyNoTileUsesAfterTFree(TPopOp tpopOp,
9696
return success();
9797
}
9898

99+
static bool isInsideSectionOrAttributedKernel(TPopOp tpopOp, func::FuncOp funcOp) {
100+
if (tpopOp->getParentOfType<SectionCubeOp>() ||
101+
tpopOp->getParentOfType<SectionVectorOp>())
102+
return true;
103+
return funcOp &&
104+
funcOp->hasAttr(FunctionKernelKindAttr::name);
105+
}
106+
99107
struct PTOVerifyTFreePass
100108
: public mlir::pto::impl::PTOVerifyTFreeBase<PTOVerifyTFreePass> {
101109
void runOnOperation() override {
@@ -105,8 +113,7 @@ struct PTOVerifyTFreePass
105113
funcOp.walk([&](TPopOp op) { tpops.push_back(op); });
106114

107115
for (TPopOp tpopOp : tpops) {
108-
if (!tpopOp->getParentOfType<SectionCubeOp>() &&
109-
!tpopOp->getParentOfType<SectionVectorOp>())
116+
if (!isInsideSectionOrAttributedKernel(tpopOp, funcOp))
110117
continue;
111118

112119
TFreeOp existingTFree = findMatchingTFree(tpopOp);
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// RUN: ptoas %s | FileCheck %s
2+
3+
module {
4+
func.func @vector_while_kernel() attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
5+
%c0 = arith.constant 0 : index
6+
%c1 = arith.constant 1 : index
7+
%c2 = arith.constant 2 : index
8+
%c4 = arith.constant 4 : index
9+
%c32 = arith.constant 32 : index
10+
%true = arith.constant true
11+
%false = arith.constant false
12+
%one = arith.constant 1.0 : f32
13+
%ten = arith.constant 10.0 : f32
14+
15+
%tile = pto.alloc_tile
16+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
17+
blayout=row_major, slayout=none_box, fractal=512, pad=0>
18+
19+
%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) {
20+
%lt = arith.cmpi slt, %i, %c4 : index
21+
%go = arith.andi %lt, %alive : i1
22+
scf.condition(%go) %i, %alive : index, i1
23+
} do {
24+
^bb0(%i2: index, %alive2: i1):
25+
pto.tadds ins(%tile, %one
26+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
27+
blayout=row_major, slayout=none_box, fractal=512, pad=0>,
28+
f32)
29+
outs(%tile
30+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
31+
blayout=row_major, slayout=none_box, fractal=512, pad=0>)
32+
%break_now = arith.cmpi eq, %i2, %c2 : index
33+
%next_i = arith.addi %i2, %c1 : index
34+
%yield:2 = scf.if %break_now -> (index, i1) {
35+
scf.yield %next_i, %false : index, i1
36+
} else {
37+
pto.tadds ins(%tile, %ten
38+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
39+
blayout=row_major, slayout=none_box, fractal=512, pad=0>,
40+
f32)
41+
outs(%tile
42+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
43+
blayout=row_major, slayout=none_box, fractal=512, pad=0>)
44+
scf.yield %next_i, %true : index, i1
45+
}
46+
scf.yield %yield#0, %yield#1 : index, i1
47+
}
48+
49+
%stopped = arith.xori %final#1, %true : i1
50+
scf.if %stopped {
51+
pto.tadds ins(%tile, %one
52+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
53+
blayout=row_major, slayout=none_box, fractal=512, pad=0>,
54+
f32)
55+
outs(%tile
56+
: !pto.tile_buf<loc=vec, dtype=f32, rows=32, cols=32, v_row=32, v_col=32,
57+
blayout=row_major, slayout=none_box, fractal=512, pad=0>)
58+
}
59+
return
60+
}
61+
}
62+
63+
// CHECK: AICORE void vector_while_kernel()
64+
// CHECK: using T = float;
65+
// CHECK: #if defined(__DAV_VEC__)
66+
// CHECK: set_mask_norm();
67+
// CHECK: set_vector_mask(-1, -1);
68+
// CHECK: goto [[HEADER:label[0-9]+]];
69+
// CHECK: [[HEADER]]:
70+
// CHECK: if ({{.*}}) {
71+
// CHECK: goto [[BODY:label[0-9]+]];
72+
// CHECK: goto [[EXIT:label[0-9]+]];
73+
// CHECK: [[BODY]]:
74+
// CHECK: goto [[HEADER]];
75+
// CHECK: [[EXIT]]:
76+
// CHECK: #endif // __DAV_VEC__

tools/ptoas/ptoas.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,8 +873,6 @@ int main(int argc, char **argv) {
873873
// pm.addNestedPass<mlir::func::FuncOp>(pto::createPTOInsertLoadStoreForMixCVPass());
874874
pm.addNestedPass<mlir::func::FuncOp>(
875875
pto::createPTOLowerFrontendPipeOpsPass());
876-
pm.addNestedPass<mlir::func::FuncOp>(
877-
pto::createPTOWrapFunctionsInSectionsPass());
878876
pm.addNestedPass<mlir::func::FuncOp>(pto::createPTOVerifyTFreePass());
879877
pm.addNestedPass<mlir::func::FuncOp>(pto::createLoweringSyncToPipePass());
880878

0 commit comments

Comments
 (0)