From f5b847a9a03048304d51bfac471ab8d467aa758e Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Fri, 27 Mar 2026 18:42:32 +0800 Subject: [PATCH] Implement compact tile_buf assembly syntax --- lib/PTO/IR/PTOTypeDefs.cpp | 578 ++++++++++++------ test/basic/compact_left_blayout_parser_a3.pto | 11 + test/basic/compact_left_blayout_parser_a5.pto | 11 + test/basic/left_blayout_parser_a3.pto | 5 +- test/basic/left_blayout_parser_a5.pto | 5 +- test/python/compact_tile_buf_asm.py | 203 ++++++ 6 files changed, 635 insertions(+), 178 deletions(-) create mode 100644 test/basic/compact_left_blayout_parser_a3.pto create mode 100644 test/basic/compact_left_blayout_parser_a5.pto create mode 100644 test/python/compact_tile_buf_asm.py diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index 503774abc..ff4479158 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -102,23 +102,24 @@ int32_t TileBufType::getCompactModeI32() const { return 0; } -// ---- TileBufType custom asm ---- -// !pto.tile_buf<> -Type TileBufType::parse(AsmParser &parser) { - MLIRContext *ctx = parser.getContext(); - - if (failed(parser.parseLess())) - return Type(); +namespace { +struct ParsedTileBufFields { std::string locStr; Type dtype; - int64_t rows = 0, cols = 0; - int64_t vrow = -1, vcol = -1; - std::string blayoutStr, slayoutStr; + int64_t rows = 0; + int64_t cols = 0; + int64_t vrow = -1; + int64_t vcol = -1; + std::string blayoutStr; + std::string slayoutStr; int64_t fractal = 0; - uint32_t padInt; + uint32_t padInt = 0; uint32_t compactInt = 0; +}; +static LogicalResult parseLegacyTileBufFields(AsmParser &parser, + ParsedTileBufFields &fields) { auto parseKeyEq = [&](StringRef expectedKey) -> LogicalResult { if (failed(parser.parseKeyword(expectedKey))) return failure(); @@ -127,178 +128,374 @@ Type TileBufType::parse(AsmParser &parser) { return success(); }; - // loc=Vec - { - if (failed(parseKeyEq("loc"))) return Type(); - // Vec/Mat/Acc 不是类型/属性,直接当 keyword/string 读 - if (failed(parser.parseKeywordOrString(&locStr))) return Type(); - if (failed(parser.parseComma())) return Type(); + if (failed(parser.parseEqual())) + return failure(); + if (failed(parser.parseKeywordOrString(&fields.locStr))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("dtype"))) + return failure(); + if (failed(parser.parseType(fields.dtype))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("rows"))) + return failure(); + if (failed(parser.parseInteger(fields.rows))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("cols"))) + return failure(); + if (failed(parser.parseInteger(fields.cols))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("v_row"))) + return failure(); + if (succeeded(parser.parseOptionalQuestion())) { + fields.vrow = -1; + } else { + if (failed(parser.parseInteger(fields.vrow))) + return failure(); + if (fields.vrow < -1) { + parser.emitError(parser.getCurrentLocation(), + "v_row must be '?', -1, or a non-negative integer"); + return failure(); + } } + if (failed(parser.parseComma())) + return failure(); - // dtype=f16 - { - if (failed(parseKeyEq("dtype"))) return Type(); - if (failed(parser.parseType(dtype))) return Type(); - if (failed(parser.parseComma())) return Type(); + if (failed(parseKeyEq("v_col"))) + return failure(); + if (succeeded(parser.parseOptionalQuestion())) { + fields.vcol = -1; + } else { + if (failed(parser.parseInteger(fields.vcol))) + return failure(); + if (fields.vcol < -1) { + parser.emitError(parser.getCurrentLocation(), + "v_col must be '?', -1, or a non-negative integer"); + return failure(); + } } + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("blayout"))) + return failure(); + if (failed(parser.parseKeywordOrString(&fields.blayoutStr))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("slayout"))) + return failure(); + if (failed(parser.parseKeywordOrString(&fields.slayoutStr))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("fractal"))) + return failure(); + if (failed(parser.parseInteger(fields.fractal))) + return failure(); + if (failed(parser.parseComma())) + return failure(); + + if (failed(parseKeyEq("pad"))) + return failure(); + if (failed(parser.parseInteger(fields.padInt))) + return failure(); + + return success(); +} - // rows=16 - { - if (failed(parseKeyEq("rows"))) return Type(); - if (failed(parser.parseInteger(rows))) return Type(); - if (failed(parser.parseComma())) return Type(); +static LogicalResult parseCompactTileBufFields(AsmParser &parser, + StringRef firstToken, + ParsedTileBufFields &fields) { + fields.locStr = firstToken.str(); + + if (failed(parser.parseComma())) + return failure(); + + SmallVector shape; + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/false))) + return failure(); + if (failed(parser.parseType(fields.dtype))) + return failure(); + if (shape.size() != 2) { + parser.emitError(parser.getCurrentLocation(), + "tile_buf compact syntax expects exactly two shape dims"); + return failure(); } - // cols=16 - { - if (failed(parseKeyEq("cols"))) return Type(); - if (failed(parser.parseInteger(cols))) return Type(); - if (failed(parser.parseComma())) return Type(); + fields.rows = shape[0]; + fields.cols = shape[1]; + fields.vrow = fields.rows; + fields.vcol = fields.cols; + + auto defaultConfig = TileBufConfigAttr::getDefault(parser.getContext()); + auto defaultBLayout = llvm::dyn_cast(defaultConfig.getBLayout()); + auto defaultSLayout = llvm::dyn_cast(defaultConfig.getSLayout()); + auto defaultPad = llvm::dyn_cast(defaultConfig.getPad()); + auto defaultCompact = + llvm::dyn_cast(defaultConfig.getCompactMode()); + if (!defaultBLayout || !defaultSLayout || !defaultPad || !defaultCompact) { + parser.emitError(parser.getCurrentLocation(), + "failed to load default tile_buf config"); + return failure(); } - - { - // v_row=?/-1/16 , v_col=?/-1/8 (支持半动态) - if (failed(parseKeyEq("v_row"))) return Type(); - - // 解析 v_row:'?' -> -1,否则整数(允许 -1 兼容) - if (succeeded(parser.parseOptionalQuestion())) { - vrow = -1; - } else { - if (failed(parser.parseInteger(vrow))) return Type(); - if (vrow < -1) { - parser.emitError(parser.getCurrentLocation(), - "v_row must be '?', -1, or a non-negative integer"); - return Type(); - } - } - - if (failed(parser.parseComma())) return Type(); + fields.blayoutStr = stringifyBLayout(defaultBLayout.getValue()).str(); + fields.slayoutStr = stringifySLayout(defaultSLayout.getValue()).str(); + fields.fractal = defaultConfig.getSFractalSize().getInt(); + fields.padInt = static_cast(defaultPad.getValue()); + fields.compactInt = static_cast(defaultCompact.getValue()); + + bool seenValid = false; + bool seenBLayout = false; + bool seenSLayout = false; + bool seenFractal = false; + bool seenPad = false; + bool seenCompact = false; + + while (succeeded(parser.parseOptionalComma())) { + StringRef key; + if (failed(parser.parseKeyword(&key))) + return failure(); + if (failed(parser.parseEqual())) + return failure(); - if (failed(parseKeyEq("v_col"))) return Type(); + if (key == "valid") { + if (seenValid) { + parser.emitError(parser.getCurrentLocation(), + "duplicate valid in tile_buf compact syntax"); + return failure(); + } + seenValid = true; + + SmallVector validShape; + if (failed(parser.parseDimensionList(validShape, /*allowDynamic=*/true, + /*withTrailingX=*/false))) + return failure(); + if (validShape.size() != 2) { + parser.emitError(parser.getCurrentLocation(), + "tile_buf valid must have exactly two dims"); + return failure(); + } + fields.vrow = validShape[0]; + fields.vcol = validShape[1]; + continue; + } - // 解析 v_col:'?' -> -1,否则整数(允许 -1 兼容) - if (succeeded(parser.parseOptionalQuestion())) { - vcol = -1; - } else { - if (failed(parser.parseInteger(vcol))) return Type(); - if (vcol < -1) { - parser.emitError(parser.getCurrentLocation(), - "v_col must be '?', -1, or a non-negative integer"); - return Type(); - } + if (key == "blayout") { + if (seenBLayout) { + parser.emitError(parser.getCurrentLocation(), + "duplicate blayout in tile_buf compact syntax"); + return failure(); + } + seenBLayout = true; + if (failed(parser.parseKeywordOrString(&fields.blayoutStr))) + return failure(); + continue; } - if (failed(parser.parseComma())) return Type(); - } - // blayout=RowMajor - { - if (failed(parseKeyEq("blayout"))) return Type(); - if (failed(parser.parseKeywordOrString(&blayoutStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } + if (key == "slayout") { + if (seenSLayout) { + parser.emitError(parser.getCurrentLocation(), + "duplicate slayout in tile_buf compact syntax"); + return failure(); + } + seenSLayout = true; + if (failed(parser.parseKeywordOrString(&fields.slayoutStr))) + return failure(); + continue; + } + if (key == "fractal") { + if (seenFractal) { + parser.emitError(parser.getCurrentLocation(), + "duplicate fractal in tile_buf compact syntax"); + return failure(); + } + seenFractal = true; + if (failed(parser.parseInteger(fields.fractal))) + return failure(); + continue; + } - // slayout=NoneBox - { - if (failed(parseKeyEq("slayout"))) return Type(); - if (failed(parser.parseKeywordOrString(&slayoutStr))) return Type(); - if (failed(parser.parseComma())) return Type(); - } + if (key == "pad") { + if (seenPad) { + parser.emitError(parser.getCurrentLocation(), + "duplicate pad in tile_buf compact syntax"); + return failure(); + } + seenPad = true; + if (failed(parser.parseInteger(fields.padInt))) + return failure(); + continue; + } - // fractal=512 - { - if (failed(parseKeyEq("fractal"))) return Type(); - if (failed(parser.parseInteger(fractal))) return Type(); - if (failed(parser.parseComma())) return Type(); - } + if (key == "compact") { + if (seenCompact) { + parser.emitError(parser.getCurrentLocation(), + "duplicate compact in tile_buf compact syntax"); + return failure(); + } + seenCompact = true; + if (failed(parser.parseInteger(fields.compactInt))) + return failure(); + continue; + } - // pad=0 - { - if (failed(parseKeyEq("pad"))) return Type(); - if (failed(parser.parseInteger(padInt))) return Type(); + parser.emitError(parser.getCurrentLocation(), + "unknown key in tile_buf compact syntax: ") + << key; + return failure(); } - if (succeeded(parser.parseOptionalComma())) { - if (failed(parseKeyEq("compact"))) return Type(); - if (failed(parser.parseInteger(compactInt))) return Type(); - } + return success(); +} - if (failed(parser.parseGreater())) - return Type(); +static Type buildTileBufType(AsmParser &parser, + const ParsedTileBufFields &fields) { + MLIRContext *ctx = parser.getContext(); - // -------- 语义校验/构造 -------- - if (rows < 0 || cols < 0) { + if (fields.rows < 0 || fields.cols < 0) { parser.emitError(parser.getNameLoc(), "rows/cols must be non-negative"); return Type(); } - auto memorySpace = ::llvm::StringSwitch<::std::optional>(locStr) - .Case("mat", AddressSpace::MAT) - .Case("left", AddressSpace::LEFT) - .Case("right", AddressSpace::RIGHT) - .Case("acc", AddressSpace::ACC) - .Case("vec", AddressSpace::VEC) - .Case("bias", AddressSpace::BIAS) - .Case("scaling", AddressSpace::SCALING) - .Default(::std::nullopt); + auto memorySpace = ::llvm::StringSwitch<::std::optional>( + fields.locStr) + .Case("mat", AddressSpace::MAT) + .Case("left", AddressSpace::LEFT) + .Case("right", AddressSpace::RIGHT) + .Case("acc", AddressSpace::ACC) + .Case("vec", AddressSpace::VEC) + .Case("bias", AddressSpace::BIAS) + .Case("scaling", AddressSpace::SCALING) + .Default(::std::nullopt); if (!memorySpace.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown loc: ") << locStr; + parser.emitError(parser.getNameLoc(), "unknown loc: ") << fields.locStr; return Type(); } - auto bl = symbolizeBLayout(blayoutStr); - auto sl = symbolizeSLayout(slayoutStr); - auto pv = symbolizePadValue(padInt); - auto compact = symbolizeCompactMode(compactInt); + auto bl = symbolizeBLayout(fields.blayoutStr); + auto sl = symbolizeSLayout(fields.slayoutStr); + auto pv = symbolizePadValue(fields.padInt); + auto compact = symbolizeCompactMode(fields.compactInt); if (!bl.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown blayout: ") << blayoutStr; + parser.emitError(parser.getNameLoc(), "unknown blayout: ") + << fields.blayoutStr; return Type(); } if (!sl.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown slayout: ") << slayoutStr; + parser.emitError(parser.getNameLoc(), "unknown slayout: ") + << fields.slayoutStr; return Type(); } if (!pv.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown pad: ") << padInt; + parser.emitError(parser.getNameLoc(), "unknown pad: ") << fields.padInt; return Type(); } if (!compact.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown compact: ") << compactInt; + parser.emitError(parser.getNameLoc(), "unknown compact: ") + << fields.compactInt; return Type(); } - BLayout effectiveBLayout = bl.value(); - if (memorySpace.value() == AddressSpace::LEFT) { + auto normalizeParserBLayout = [&](AddressSpace memorySpace, + BLayout parsedBLayout) -> BLayout { + // LEFT tiles are parser-normalized from the scoped target arch rather than + // from the textual blayout spelling. This preserves the longstanding + // --pto-arch behavior for both legacy and compact tile_buf syntax. + if (memorySpace != AddressSpace::LEFT) + return parsedBLayout; + switch (getPTOParserTargetArch()) { case PTOParserTargetArch::A3: - effectiveBLayout = BLayout::RowMajor; - break; + return BLayout::RowMajor; case PTOParserTargetArch::A5: - effectiveBLayout = BLayout::ColMajor; - break; + return BLayout::ColMajor; case PTOParserTargetArch::Unspecified: - break; + return parsedBLayout; } - } + + return parsedBLayout; + }; + + BLayout effectiveBLayout = + normalizeParserBLayout(memorySpace.value(), bl.value()); auto blAttr = BLayoutAttr::get(ctx, effectiveBLayout); auto slAttr = SLayoutAttr::get(ctx, sl.value()); auto fractalAttr = - IntegerAttr::get(IntegerType::get(ctx, 32), fractal); + IntegerAttr::get(IntegerType::get(ctx, 32), fields.fractal); auto padAttr = PadValueAttr::get(ctx, pv.value()); auto compactAttr = CompactModeAttr::get(ctx, compact.value()); auto memorySpaceAttr = AddressSpaceAttr::get(ctx, memorySpace.value()); - auto cfg = - TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr, compactAttr); + auto cfg = TileBufConfigAttr::get(ctx, blAttr, slAttr, fractalAttr, padAttr, + compactAttr); - SmallVector shape{rows, cols}; - SmallVector validShape{vrow, vcol}; + SmallVector shape{fields.rows, fields.cols}; + SmallVector validShape{fields.vrow, fields.vcol}; auto canonicalValidShape = canonicalizeTileBufValidShape(validShape); - return TileBufType::get(ctx, shape, dtype, memorySpaceAttr, + return TileBufType::get(ctx, shape, fields.dtype, memorySpaceAttr, llvm::ArrayRef(canonicalValidShape), cfg); } +} // namespace + +// ---- TileBufType custom asm ---- +// !pto.tile_buf<> +Type TileBufType::parse(AsmParser &parser) { + if (failed(parser.parseLess())) + return Type(); + + std::string firstToken; + if (failed(parser.parseKeywordOrString(&firstToken))) + return Type(); + + ParsedTileBufFields fields; + const bool isLegacySyntax = firstToken == "loc"; + if (isLegacySyntax) { + if (failed(parseLegacyTileBufFields(parser, fields))) + return Type(); + } else { + if (failed(parseCompactTileBufFields(parser, firstToken, fields))) + return Type(); + } + + if (isLegacySyntax && succeeded(parser.parseOptionalComma())) { + auto parseKeyEq = [&](StringRef expectedKey) -> LogicalResult { + if (failed(parser.parseKeyword(expectedKey))) + return failure(); + if (failed(parser.parseEqual())) + return failure(); + return success(); + }; + + if (failed(parseKeyEq("compact"))) + return Type(); + if (failed(parser.parseInteger(fields.compactInt))) + return Type(); + } + + if (failed(parser.parseGreater())) + return Type(); + + return buildTileBufType(parser, fields); +} + static llvm::StringRef stringifyLocFromMemorySpace(mlir::Attribute memorySpace) { auto asAttr = llvm::dyn_cast_or_null(memorySpace); switch (asAttr.getAddressSpace()) { @@ -329,60 +526,93 @@ static llvm::StringRef stringifyLocFromPad(mlir::Attribute pad) { static llvm::StringRef stringifyCompactModeInt(mlir::Attribute compactMode) { auto compactAttr = llvm::dyn_cast_or_null(compactMode); - if (!compactAttr) return "9999"; + if (!compactAttr) + return "9999"; switch (compactAttr.getValue()) { - case CompactMode::Null: return "0"; - case CompactMode::Normal: return "1"; - case CompactMode::RowPlusOne: return "2"; - default: - return "9999"; + case CompactMode::Null: + return "0"; + case CompactMode::Normal: + return "1"; + case CompactMode::RowPlusOne: + return "2"; + default: + return "9999"; } } -void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { - auto shape = getShape(); - int64_t rows = shape.size() > 0 ? shape[0] : 0; - int64_t cols = shape.size() > 1 ? shape[1] : 0; - - auto cfg = getConfigAttr(); - if (!cfg) cfg = mlir::pto::TileBufConfigAttr::getDefault(getContext()); - - llvm::StringRef locStr = stringifyLocFromMemorySpace(getMemorySpace()); - - printer << "<" - << "loc=" << locStr - << ", dtype="; - printer.printType(getElementType()); - - auto blayout = llvm::dyn_cast(cfg.getBLayout()); - auto slayout = llvm::dyn_cast(cfg.getSLayout()); +static void printTileBufDim(AsmPrinter &printer, int64_t dim) { + if (dim == ShapedType::kDynamic) + printer << "?"; + else + printer << dim; +} - auto vs = getValidShape(); // ArrayRef - int64_t vrow = rows; - int64_t vcol = cols; +void mlir::pto::TileBufType::print(mlir::AsmPrinter &printer) const { + auto shape = getShape(); + int64_t rows = shape.size() > 0 ? shape[0] : ShapedType::kDynamic; + int64_t cols = shape.size() > 1 ? shape[1] : ShapedType::kDynamic; + + auto cfg = getConfigAttr(); + if (!cfg) + cfg = mlir::pto::TileBufConfigAttr::getDefault(getContext()); + auto defaultCfg = TileBufConfigAttr::getDefault(getContext()); + + llvm::StringRef locStr = stringifyLocFromMemorySpace(getMemorySpace()); + auto blayout = llvm::dyn_cast(cfg.getBLayout()); + auto slayout = llvm::dyn_cast(cfg.getSLayout()); + auto pad = llvm::dyn_cast(cfg.getPad()); + auto compact = llvm::dyn_cast(cfg.getCompactMode()); + auto defaultBLayout = llvm::dyn_cast(defaultCfg.getBLayout()); + auto defaultSLayout = llvm::dyn_cast(defaultCfg.getSLayout()); + auto defaultPad = llvm::dyn_cast(defaultCfg.getPad()); + auto defaultCompact = + llvm::dyn_cast(defaultCfg.getCompactMode()); + + auto vs = getValidShape(); + int64_t vrow = rows; + int64_t vcol = cols; + if (vs.size() >= 2) { + vrow = vs[0]; + vcol = vs[1]; + } - if (vs.size() >= 2) { - vrow = vs[0]; - vcol = vs[1]; - } - printer << ", rows=" << rows - << ", cols=" << cols; - printer << ", v_row="; - if (vrow < 0) printer << "?"; - else printer << vrow; - - printer << ", v_col="; - if (vcol < 0) printer << "?"; - else printer << vcol; - - printer << ", blayout=" << stringifyBLayout(blayout.getValue()) - << ", slayout=" << stringifySLayout(slayout.getValue()) - << ", fractal=" << cfg.getSFractalSize().getInt() - << ", pad=" << stringifyLocFromPad(cfg.getPad()); - if (auto compact = llvm::dyn_cast(cfg.getCompactMode())) { - if (compact.getValue() != CompactMode::Null) - printer << ", compact=" << stringifyCompactModeInt(compact); - } - printer << ">"; + const bool printValid = vrow != rows || vcol != cols; + const bool printBLayout = + blayout && defaultBLayout && blayout.getValue() != defaultBLayout.getValue(); + const bool printSLayout = + slayout && defaultSLayout && slayout.getValue() != defaultSLayout.getValue(); + const bool printFractal = + cfg.getSFractalSize().getInt() != defaultCfg.getSFractalSize().getInt(); + const bool printPad = + pad && defaultPad && pad.getValue() != defaultPad.getValue(); + const bool printCompact = + compact && defaultCompact && + compact.getValue() != defaultCompact.getValue(); + + printer << "<" << locStr << ", "; + printTileBufDim(printer, rows); + printer << "x"; + printTileBufDim(printer, cols); + printer << "x"; + printer.printType(getElementType()); + + if (printValid) { + printer << ", valid="; + printTileBufDim(printer, vrow); + printer << "x"; + printTileBufDim(printer, vcol); + } + if (printBLayout) + printer << ", blayout=" << stringifyBLayout(blayout.getValue()); + if (printSLayout) + printer << ", slayout=" << stringifySLayout(slayout.getValue()); + if (printFractal) + printer << ", fractal=" << cfg.getSFractalSize().getInt(); + if (printPad) + printer << ", pad=" << stringifyLocFromPad(cfg.getPad()); + if (printCompact) + printer << ", compact=" << stringifyCompactModeInt(cfg.getCompactMode()); + + printer << ">"; } diff --git a/test/basic/compact_left_blayout_parser_a3.pto b/test/basic/compact_left_blayout_parser_a3.pto new file mode 100644 index 000000000..dcd53fe1b --- /dev/null +++ b/test/basic/compact_left_blayout_parser_a3.pto @@ -0,0 +1,11 @@ +// RUN: ptoas --pto-arch a3 %s 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a3"} { + func.func @compact_left_blayout_parser_a3() { + %0 = pto.alloc_tile : !pto.tile_buf + return + } +} + +// CHECK-LABEL: func.func @compact_left_blayout_parser_a3() { +// CHECK: #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode> diff --git a/test/basic/compact_left_blayout_parser_a5.pto b/test/basic/compact_left_blayout_parser_a5.pto new file mode 100644 index 000000000..2cc498923 --- /dev/null +++ b/test/basic/compact_left_blayout_parser_a5.pto @@ -0,0 +1,11 @@ +// RUN: ptoas --pto-arch a5 %s 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @compact_left_blayout_parser_a5() { + %0 = pto.alloc_tile : !pto.tile_buf + return + } +} + +// CHECK-LABEL: func.func @compact_left_blayout_parser_a5() { +// CHECK: #pto.tile_buf_config, slayout=#pto.slayout, s_fractal_size=512, pad=#pto.pad_value, compact=#pto.compact_mode> diff --git a/test/basic/left_blayout_parser_a3.pto b/test/basic/left_blayout_parser_a3.pto index 52aeeff13..d6617f094 100644 --- a/test/basic/left_blayout_parser_a3.pto +++ b/test/basic/left_blayout_parser_a3.pto @@ -1,6 +1,6 @@ // RUN: ptoas --pto-arch a3 %s | FileCheck %s -module { +module attributes {"pto.target_arch" = "a3"} { func.func @left_blayout_parser_a3() { %c0 = arith.constant 0 : index %src = pto.alloc_tile : !pto.tile_buf @@ -10,4 +10,5 @@ module { } } -// CHECK: left_blayout_parser_a3 +// CHECK-LABEL: __global__ AICORE void left_blayout_parser_a3() { +// CHECK: Tile diff --git a/test/basic/left_blayout_parser_a5.pto b/test/basic/left_blayout_parser_a5.pto index 3311495ab..2d9b950d3 100644 --- a/test/basic/left_blayout_parser_a5.pto +++ b/test/basic/left_blayout_parser_a5.pto @@ -1,6 +1,6 @@ // RUN: ptoas --pto-arch a5 %s | FileCheck %s -module { +module attributes {"pto.target_arch" = "a5"} { func.func @left_blayout_parser_a5() { %c0 = arith.constant 0 : index %src = pto.alloc_tile : !pto.tile_buf @@ -10,4 +10,5 @@ module { } } -// CHECK: left_blayout_parser_a5 +// CHECK-LABEL: __global__ AICORE void left_blayout_parser_a5() { +// CHECK: Tile diff --git a/test/python/compact_tile_buf_asm.py b/test/python/compact_tile_buf_asm.py new file mode 100644 index 000000000..8c07e945c --- /dev/null +++ b/test/python/compact_tile_buf_asm.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 + +from mlir.ir import Context, F32Type, MLIRError, Module +from mlir.dialects import pto + + +def expect_equal(actual: str, expected: str, label: str) -> None: + if actual != expected: + raise AssertionError( + f"{label} mismatch\nexpected: {expected}\nactual: {actual}" + ) + + +def expect_contains(text: str, needle: str, label: str) -> None: + if needle not in text: + raise AssertionError( + f"{label} missing substring\nneedle: {needle}\ntext:\n{text}" + ) + + +def expect_not_contains(text: str, needle: str, label: str) -> None: + if needle in text: + raise AssertionError( + f"{label} unexpectedly contained substring\nneedle: {needle}\ntext:\n{text}" + ) + + +def expect_parse_error(ctx: Context, asm: str, needle: str, label: str) -> None: + try: + Module.parse(asm, ctx) + except MLIRError as err: + if needle not in str(err): + raise AssertionError( + f"{label} error mismatch\nexpected substring: {needle}\nactual: {err}" + ) from err + return + raise AssertionError(f"{label} unexpectedly parsed successfully") + + +def main() -> None: + with Context() as ctx: + pto.register_dialect(ctx) + + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + col_major = pto.BLayoutAttr.get(pto.BLayout.ColMajor, ctx) + row_major = pto.SLayoutAttr.get(pto.SLayout.RowMajor, ctx) + zero_pad = pto.PadValueAttr.get(pto.PadValue.Zero, ctx) + cfg = pto.TileBufConfigAttr.get(col_major, row_major, 1024, zero_pad, ctx) + + default_ty = pto.TileBufType.get([1, 16], F32Type.get(ctx), vec, context=ctx) + expect_equal( + str(default_ty), + "!pto.tile_buf", + "default compact print", + ) + + valid_ty = pto.TileBufType.get( + [16, 128], + F32Type.get(ctx), + vec, + valid_shape=[16, 1], + context=ctx, + ) + expect_equal( + str(valid_ty), + "!pto.tile_buf", + "valid suffix print", + ) + + non_default_cfg_ty = pto.TileBufType.get( + [8, 8], + F32Type.get(ctx), + vec, + config=cfg, + context=ctx, + ) + expect_equal( + str(non_default_cfg_ty), + "!pto.tile_buf", + "non-default config suffix print", + ) + + compact_cfg_ty = pto.TileBufType.get( + [8, 8], + F32Type.get(ctx), + vec, + config=pto.TileBufConfigAttr.get( + col_major, + row_major, + 1024, + zero_pad, + ctx, + compact_mode=pto.CompactMode.RowPlusOne, + ), + context=ctx, + ) + expect_equal( + str(compact_cfg_ty), + "!pto.tile_buf", + "non-default compact suffix print", + ) + + legacy_module = Module.parse( + """ +module { + func.func @legacy( + %arg0: !pto.tile_buf) { + return + } +} +""", + ctx, + ) + legacy_text = str(legacy_module) + expect_contains( + legacy_text, + "!pto.tile_buf", + "legacy parse reprint", + ) + expect_not_contains(legacy_text, "loc=", "legacy parse reprint") + expect_not_contains(legacy_text, "v_row=", "legacy parse reprint") + expect_not_contains(legacy_text, "v_col=", "legacy parse reprint") + + compact_module = Module.parse( + """ +module { + func.func @compact( + %arg0: !pto.tile_buf) { + return + } +} +""", + ctx, + ) + expect_contains( + str(compact_module), + "!pto.tile_buf", + "compact parse roundtrip", + ) + + compact_mode_module = Module.parse( + """ +module { + func.func @compact_mode( + %arg0: !pto.tile_buf) { + return + } +} +""", + ctx, + ) + expect_contains( + str(compact_mode_module), + "!pto.tile_buf", + "compact mode parse roundtrip", + ) + + expect_parse_error( + ctx, + """ +module { + func.func @dup_valid( + %arg0: !pto.tile_buf) { + return + } +} +""", + "duplicate valid in tile_buf compact syntax", + "duplicate valid rejection", + ) + expect_parse_error( + ctx, + """ +module { + func.func @dynamic_base(%arg0: !pto.tile_buf) { + return + } +} +""", + "expected static shape", + "dynamic base shape rejection", + ) + expect_parse_error( + ctx, + """ +module { + func.func @bad_valid_rank( + %arg0: !pto.tile_buf) { + return + } +} +""", + "tile_buf valid must have exactly two dims", + "valid rank rejection", + ) + + print("compact_tile_buf_asm: PASS") + + +if __name__ == "__main__": + main()