diff --git a/.gitignore b/.gitignore index eb685574..af04a51b 100644 --- a/.gitignore +++ b/.gitignore @@ -101,4 +101,6 @@ launcher_cxx11abi* # package backend/triton-shared-opt-v3* backend/dicp_opt -third_party/triton-shared-opt \ No newline at end of file +third_party/triton-shared-opt + +FA_FWD_PROF \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 693c96bb..2babb868 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,16 +50,24 @@ if (TRITON_BUILD_PYTHON_MODULE) MLIRTransforms MLIRSupport MLIRBytecodeWriter + MLIRTransformDialect + ${extension_libs} + MLIRRegisterAllPasses + MLIRLinalgTransforms + MLIRTensorTransforms + MLIRFuncAllExtensions TritonToLinalgNPUCoversion LinalgExtTransforms TritonExtTransforms + LinalgExtAnalysis LinalgToLinked LinkedToHIVM DiscreteMaskAccessConversion TritonToUnstructure + DICPTransformOps ) target_link_libraries(tritonDicpTriton PRIVATE Python3::Module pybind11::headers) endif() \ No newline at end of file diff --git a/backend/compiler.py b/backend/compiler.py index 804bdfdf..624f21db 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -131,6 +131,7 @@ def __init__(self, target: str) -> None: self.binary_ext = "mcfatbin" elif self.driver.target == "ascend": self.binary_ext = "npubin" + self.capability = target.arch else: raise RuntimeError(f"Target '{self.target_type}' is not supported.") @@ -249,7 +250,7 @@ def add_stages(self, stages, options, language=None): ) stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile( - src, metadata, options + src, metadata, options, self.capability ) ) else: @@ -264,17 +265,20 @@ def add_stages(self, stages, options, language=None): ) stages["npubin"] = ( lambda src, metadata: linalg_to_bin_enable_npu_compile( - src, metadata, options + src, metadata, options, self.capability ) ) else: raise RuntimeError("backend not supported") def load_dialects(self, ctx): + # TODO Warning If additional backends are integrated into the common IR 
with customized passes, their respective Dialect interfaces must be registered here. A decoupled registration mechanism for each backend is preferred to maintain modularity. if self.driver.target == "mlu": from triton._C.libtriton import mlu - mlu.load_dialects(ctx) + else: + from triton._C.libtriton import dicp_triton + dicp_triton.load_dialects(ctx) return @functools.lru_cache() diff --git a/backend/npu.py b/backend/npu.py index fc39baac..53b2e3ba 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -491,10 +491,20 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) + # dicp_triton.passes.linked_npu.add_npu_vector_tile_tagging(pm,2) + # dicp_triton.passes.linked_npu.add_npu_vector_tile_transform(pm) + # dicp_triton.passes.linked_npu.add_fuse_loop(pm) + # dicp_triton.passes.linked_npu.add_de_linalgize(pm) + # dicp_triton.passes.linked_npu.add_loop_unroll_stage(pm) + # dicp_triton.passes.linked_npu.add_npu_unroll_pipeline(pm) pm.run(mod) - - # TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。 content = str(mod) + dicp_utils._dump_stage_ir( + content, metadata["hash"], "kernel.linkedir.mlir.2" + ) + if replace_linked_ir is not None: + content= Path(replace_linked_ir).read_text() + # TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。 # 将"*xfxxx"替换成"?xfxxx" content = content.replace("*xf", "?xf") content = content.replace("*xi", "?xi") @@ -520,9 +530,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): content, metadata["hash"], "kernel.linkedir.mlir", cmd_list ) - if replace_linked_ir is not None: - print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}") - return Path(replace_linked_ir).read_text() + return content @@ -683,7 +691,7 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): return linalg, 
metadata -def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): +def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt, capability): linalg, metadata = _parse_linalg_metadata(linalg, metadata) with tempfile.TemporaryDirectory() as tmpdir: ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") @@ -706,6 +714,8 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): _compile_option_list += [ f"--enable-auto-multi-buffer={multibuffer}", ] + if capability: + _compile_option_list += [f"--target={capability}"] if _is_ascend_sanitizer_enabled(): _compile_option_list += ["--enable-sanitizer=true"] diff --git a/compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h b/compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h new file mode 100644 index 00000000..cb16f301 --- /dev/null +++ b/compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h @@ -0,0 +1,166 @@ +#ifndef DICP_DIALECT_LINALGEXT_TRANSFORMS_DIMANALYZER_H +#define DICP_DIALECT_LINALGEXT_TRANSFORMS_DIMANALYZER_H + +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" + +#include +#include +#include + +namespace mlir { +namespace dicp { + +/// Classification of a dimension's role in the computation graph. +/// This helps determine if a dimension is safe to tile or parallelize. +enum class DimKind { + Unknown, // No specific property inferred yet. + Parallel, // Dimension implies independent iterations (safe to tile). + Reduction, // Dimension is collapsed/reduced (requires accumulation). + Broadcast, // Dimension is replicated (data invariant along this axis). + Complex // Dimension undergoes complex transformation (e.g., non-affine + // reshape). 
+}; + +std::string toString(DimKind k); + +/// Disjoint Set Union (DSU) for tracking dimension equivalence and properties. +/// +/// This class implements a Disjoint Set data structure (Union-Find) +/// specifically designed for Tensor/MemRef dimensions. It serves two main +/// purposes: +/// 1. **Equivalence Tracking**: Determines which dimensions across different +/// values +/// represent the same logical axis (e.g., the 'N' dimension in a Matmul +/// propagating through element-wise adds). +/// 2. **Property Propagation**: Merges semantic properties (DimKind) when +/// dimensions +/// are unified. For example, if a dimension is used as a Reduction iterator +/// in one operation, that property propagates to all equivalent dimensions +/// in the set. +class DimensionDisjointSet { +public: + explicit DimensionDisjointSet(size_t size = 0) { resize(size); } + + /// Allocates `n` new dimension IDs in the set. + /// \return The ID of the first allocated dimension. + int64_t allocate(size_t n = 1); + + /// Finds the representative (root) ID for the set containing dimension `i`. + /// Implements path compression for amortized constant time lookups. + int64_t find(int64_t i); + + /// Merges the sets containing dimensions `i` and `j`. + /// This also merges the `DimKind` properties of both roots using + /// `mergeKinds`. + void unionSets(int64_t i, int64_t j); + + /// Updates the DimKind property for the set containing dimension `i`. + /// The new kind is merged with the existing kind to ensure safety (e.g., + /// Reduction is sticky). + void setKind(int64_t i, DimKind k); + + /// Retrieves the DimKind property of the set containing dimension `i`. + DimKind getKind(int64_t i); + +private: + /// Resizes the internal storage to accommodate `n` dimensions. + void resize(size_t n); + + /// Defines the logic for combining two dimension kinds. + /// Hierarchy of "stickiness": Complex > Reduction > Broadcast/Parallel. 
+ DimKind mergeKinds(DimKind a, DimKind b); + + std::vector parent; // Parent pointers for DSU. + std::vector kind; // Properties associated with each root. +}; + +/// DimAnalyzer: +/// Analyzes a specific execution stage (StageInfo) to determine tiling +/// strategies. +/// +/// The analyzer constructs a constraint graph where nodes are tensor dimensions +/// and edges represent data flow relationships. It uses a Breadth-First Search +/// (BFS) approach to traverse operations and propagate dimension IDs. +/// +/// Algorithm Overview: +/// 1. **Initialization**: Seeds the analysis with stage inputs (operands +/// defined outside the stage). +/// 2. **BFS Propagation**: Traverses the def-use chains. For each operation, it +/// uses specific handlers (e.g., processMatmulOp) to bind input dimensions to +/// output dimensions. +/// 3. **Anchor Heuristic**: Identifies the "Anchor" operation (typically the +/// final LinalgOp) to interpret the resulting loops. +/// 4. **Tiling Selection**: Checks the properties of the Anchor's loops in the +/// DSU to recommend outermost parallel loops for tiling. +class DimAnalyzer { +public: + explicit DimAnalyzer(const StageInfo &stage); + + /// Analyzes the stage operations and returns indices of loops recommended for + /// tiling. The indices correspond to the loop nest of the "Anchor" operation. + SmallVector analyzeAndGetTilingDims(); + +private: + const StageInfo &stage_; + // Quick lookup for ops belonging to this stage. + DenseSet stageOps_; + DimensionDisjointSet dsu_; + // Maps SSA Value -> [Dim IDs] + DenseMap> valueDims_; + + // BFS State passed to handlers to allow them to enqueue new values. + using BFSQueue = std::queue; + using VisitedSet = DenseSet; + + /// Drives the traversal of the data flow graph. + void processBFS(); + + /// Dispatches the operation to the appropriate handler. + /// \return true if the operation was handled, false otherwise. 
+ bool processOperation(Operation *op, Value current, BFSQueue &q, + VisitedSet &v); + + /// Lazily retrieves or allocates unique IDs for the dimensions of a Value. + std::vector getOrAllocateDims(Value v); + + /// Helper to strictly bind all dimensions of v1 to v2 (1-to-1 mapping). + /// Used for Elementwise, Copy, etc. + void bindDimensions(Value v1, Value v2); + + // --- Op Handlers --- + // Each handler interprets the semantics of the op to union input/output + // dimensions correctly. + + void processElementwise(Operation *op, Value current); + void processMatmulOp(linalg::MatmulOp op); + void processReduceOp(linalg::ReduceOp op); + void processTransposeOp(linalg::TransposeOp op); + void processBroadcastOp(linalg::BroadcastOp op); + void processLinalgOpGeneric(linalg::LinalgOp op); + void processReshapeOp(Operation *op); + void processConcatOp(tensor::ConcatOp op); + void processPadOp(tensor::PadOp op); + void processExtractSliceOp(tensor::ExtractSliceOp op); + void processInsertSliceOp(tensor::InsertSliceOp op); + + // Handlers that may need to continue BFS propagation explicitly + void processMemrefCopyOp(memref::CopyOp op, Value current, BFSQueue &q, + VisitedSet &v); + void processMemrefCastOp(Operation *op); + void processBufferizationToTensor(bufferization::ToTensorOp op); + void processMaterializeOp(bufferization::MaterializeInDestinationOp op, + Value current, BFSQueue &q, VisitedSet &v); +}; + +} // namespace dicp +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h b/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h new file mode 100644 index 00000000..955e59bc --- /dev/null +++ b/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h @@ -0,0 +1,105 @@ +#ifndef DICP_DIALECT_LINALGEXT_TRANSFORMS_STAGEDEPENDENCYANALYZER_H +#define DICP_DIALECT_LINALGEXT_TRANSFORMS_STAGEDEPENDENCYANALYZER_H + +#include 
"mlir/Analysis/AliasAnalysis.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/SetVector.h" + +#include +#include + +namespace mlir { +namespace dicp { + +/// Defines the execution unit type for the stage. +enum class StageType { + Vector, // Default: General vector or scalar operations + Cube // Matrix operations (e.g., Matmul) +}; + +/// Represents a single pipeline stage. +/// A stage is a sequence of operations that execute together. +/// Synchronization operations (SyncBlockWaitOp) typically delimit stage +/// boundaries. +struct StageInfo { + int id = -1; + SmallVector ops; + // IDs of stages that this stage depends on + std::set preds; + // IDs of stages that depend on this stage + std::set succs; + bool hasSync = false; + // Stage execution type + StageType type = StageType::Vector; +}; + +// StageDependencyAnalyzer: +// 1. Partitioning a block into "stages" based on synchronization primitives +// (hivm::SyncBlockWaitOp). +// 2. Building a dependency graph between these stages considering both: +// - SSA Data Flow (Producer-Consumer relationships). +// - Memory Dependencies (Read-After-Write via AliasAnalysis). +// 3. Computing a topological ordering (levels) to detect cycles and determine +// a valid execution schedule. +// 4. Physically reordering the IR operations to match the valid schedule. +// +class StageDependencyAnalyzer { +public: + StageDependencyAnalyzer(Block *block, AliasAnalysis &aliasAnalysis) + : block(block), aliasAnalysis(aliasAnalysis) {} + + /// Runs the analysis, computes the topological sort, and physically reorders + /// the operations in the block. + /// Returns the ordered list of StageInfo on success, or failure if a cycle is + /// detected. + FailureOr> runAndReorder(RewriterBase &rewriter); + + /// Scans the block to populate the `stages` vector. + FailureOr> collectStages(); + +private: + /// Internal node structure for the dependency graph. 
+ struct StageNode { + int id; + StageInfo *stageInfo; + int level = 0; // Topological level (depth) + + // Memory dependencies + llvm::SetVector readValues; + llvm::SetVector writeValues; + + // SSA Value dependencies + llvm::SetVector producedValues; // Values defined in this stage + llvm::SetVector consumedValues; // Values used in this stage + }; + + Block *block; + AliasAnalysis &aliasAnalysis; + std::vector stages; + std::vector nodes; + + /// Collects SSA definitions/uses and Memory Read/Write effects for each + /// stage. + void collectEffects(); + + /// Builds the directed graph edges based on SSA and Memory conflicts. + void buildDependencyGraph(); + + /// Computes the topological level of each node using DFS. + /// Returns failure if a cycle is detected. + LogicalResult computeStageLevels(); + + /// Sorts the `stages` vector based on the computed topological levels. + void reorderStagesLogical(); + + /// Moves the operations in the IR to match the logical order of `stages`. + void materializeScheduleToIR(); +}; + +} // namespace dicp +} // namespace mlir + +#endif // DICP_DIALECT_LINALGEXT_TRANSFORMS_STAGEDEPENDENCYANALYZER_H \ No newline at end of file diff --git a/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageUtils.h b/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageUtils.h new file mode 100644 index 00000000..03bf9f0c --- /dev/null +++ b/compiler/include/dicp/Dialect/LinalgExt/Analysis/StageUtils.h @@ -0,0 +1,105 @@ +#ifndef DICP_LINALGEXT_STAGEUTILS_H +#define DICP_LINALGEXT_STAGEUTILS_H + +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +namespace mlir { 
+namespace dicp { +namespace LinalgExt { + +/// Represents a subset of operations within a stage, bounded by synchronization +/// points (e.g., HIVM Sync ops). +struct SubStage { + unsigned index = 0; + int stageId = -1; + SmallVector ops; + + bool isValid() const { return !ops.empty(); } +}; + +/// Represents a logical execution stage consisting of multiple SubStages. +struct Stage { + int id = -1; + SmallVector subStages; + StageType type = StageType::Vector; + Stage() = default; + Stage(int id, StageType type) : id(id), type(type) { + subStages.push_back(SubStage{0, id, {}}); + } + + void addOp(Operation *op) { + if (subStages.empty()) + subStages.push_back(SubStage{0, id, {}}); + subStages.back().ops.push_back(op); + } + bool isValid() const { return id != -1 && !subStages.empty(); } + + /// Returns the total number of operations across all substages. + size_t getTotalOpCount() const { + size_t count = 0; + for (const auto &ss : subStages) + count += ss.ops.size(); + return count; + } +}; + +/// Check if an operation belongs to a specific stage ID. +bool isOpInStage(Operation *op, int stageId); + +/// Shared utility class for analyzing blocks and partitioning them into stages. +class StagePartitioner { +public: + /// Identifies all blocks containing HIVM sync operations. + static SmallVector findBlocksWithHivmSyncOps(ModuleOp module); + + /// Analyzes a block for stage dependencies and tags operations with stage + /// attributes. \returns success if analysis succeeds, failure if a cycle is + /// detected. + static LogicalResult analyzeAndTagBlock(Block *block, MLIRContext *ctx, + bool &anyStageFound); + + /// Extracts all unique stage IDs present in the block in deterministic order. + static SetVector getStageIdsInBlock(Block *block); + + /// Partitions the block for the given stageId into a Stage object containing + /// SubStages. This is the primary entry point for stage-based decomposition. 
+ static Stage partition(Block *block, int stageId); + + /// Collects all stages present in the block. + static SmallVector getAllStagesInBlock(Block *block); +}; + + +class CubeVectorSplitter { +public: + /// Segments the provided block into a sequence of Cube and Vector stages. + /// + /// \param block The block to analyze. + /// \param stages Output vector to hold the resulting stages. + /// \return failure() if validation fails (e.g., illegal nesting). + static LogicalResult splitBlock(Block &block, + llvm::SmallVectorImpl &stages); + + /// Finds the core computation block based on the "mix_mode" attribute. + /// Uses a maximum-density strategy for "mix" mode, and an outermost-level + /// strategy for "aiv" / "aic" single modes. + static Block *findTargetBlock(func::FuncOp funcOp); +}; + + +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +#endif // DICP_LINALGEXT_STAGEUTILS_H \ No newline at end of file diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h index 7ae43b6c..f46969e8 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h @@ -14,6 +14,9 @@ class FuncOp; namespace mlir::dicp::LinalgExt { +#define GEN_PASS_DECL +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" + std::unique_ptr> createLinalgIfToSelectPass(); std::unique_ptr> createLinalgGenericToSCFPass(); @@ -23,6 +26,24 @@ std::unique_ptr> createScalarTo1DTensorPass(); std::unique_ptr> createNormalizeSliceOpsPass(); +std::unique_ptr> +createNPUUnroolPipelinePass(); + +std::unique_ptr> +createNPUVectorTileTaggingPass(const NPUVectorTileTaggingOptions &options = {}); +std::unique_ptr> +createNPUVectorTileTaggingPass(unsigned vectorTile); + +std::unique_ptr> +createNPUVectorTileTransformPass(); + +std::unique_ptr> createDeLinalgizePass(); + +std::unique_ptr> createFuseLoopPass(); +std::unique_ptr> 
createLoopUnrollStagePass(); + +std::unique_ptr> createShrinkBuffersPass(); + #define GEN_PASS_REGISTRATION #include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td index c486210a..9b8df24f 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td @@ -68,4 +68,102 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> { let dependentDialects = ["mlir::tensor::TensorDialect"]; } +def NPUUnroolPipeline : Pass<"npu-unrool-pipeline", "func::FuncOp"> { + let summary = "DLC Pipelines."; + let constructor = "mlir::dicp::LinalgExt::createNPUUnroolPipelinePass()"; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::tensor::TensorDialect", + "mlir::bufferization::BufferizationDialect", + "mlir::func::FuncDialect" + ]; +} + + +def NPUVectorTileTagging : Pass<"npu-tile-loop-tagging", "ModuleOp"> { + let summary = "Normalize and tag operations for NPU tiling and fusion"; + let description = [{ + This pass performs IR normalization (elementwise to generic, copy lowering) + and analyzes stages to identify tiling anchors and fusion candidates. + It marks operations with attributes to guide the subsequent transform pass. 
+ }]; + let constructor = "mlir::dicp::LinalgExt::createNPUVectorTileTaggingPass()"; + let options = [ + Option<"tiledMixVectorLoopNumber", "vector-tile", "unsigned", + /*default=*/"2", "Trip count for vector loop tiling"> + ]; + let dependentDialects = [ + "mlir::linalg::LinalgDialect", + "mlir::tensor::TensorDialect", + "mlir::memref::MemRefDialect", + "mlir::bufferization::BufferizationDialect", + "mlir::scf::SCFDialect" + ]; +} + +def NPUVectorTileTransform : Pass<"npu-tile-loop-transform", "ModuleOp"> { + let summary = "Apply tiling and fusion using Transform Dialect based on tags"; + let description = [{ + This pass reads the tags generated by the tagging pass and executes a + Transform Dialect sequence to perform the actual tiling and fusion. + }]; + let constructor = "mlir::dicp::LinalgExt::createNPUVectorTileTransformPass()"; + let dependentDialects = [ + "mlir::transform::TransformDialect", + "mlir::linalg::LinalgDialect", + "mlir::scf::SCFDialect" + ]; +} + +def DeLinalgize : Pass<"de-linalgize", "mlir::ModuleOp"> { + let summary = "De-linalgize Linalg operations back to specific dialects."; + let description = [{ + This pass restores high-level operations from their `linalg` representations: + 1. Converts `linalg.generic` back to elementwise operations (e.g., `arith` or `math`). + 2. Converts `linalg.copy` back to `bufferization.materialize_in_destination` + or `memref.copy`. + It is essentially the inverse of normalization/generalization passes. 
+ }]; + + let constructor = "mlir::dicp::LinalgExt::createDeLinalgizePass()"; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::memref::MemRefDialect", + "mlir::bufferization::BufferizationDialect", + "mlir::linalg::LinalgDialect" + ]; +} + +def FuseLoop : Pass<"fuse-loop", "mlir::ModuleOp"> { + let summary = "Fuse cube and vector loops"; + let constructor = "mlir::dicp::LinalgExt::createFuseLoopPass()"; + let dependentDialects = ["mlir::transform::TransformDialect"]; +} + +def LoopUnrollStage : Pass<"dicp-loop-unroll-stage", "mlir::func::FuncOp"> { + let summary = "Unroll loops containing operations marked with DICP stage attributes."; + let description = [{ + This pass identifies `scf.for` loops that contain operations with a specific + DICP stage prefix and performs full unrolling when constant bounds are available. + After processing, it cleans up the internal stage attributes. + }]; + let constructor = "mlir::dicp::LinalgExt::createLoopUnrollStagePass()"; + let dependentDialects = [ + "mlir::scf::SCFDialect", + "mlir::func::FuncDialect" + ]; +} + +def ShrinkBuffers : Pass<"shrink-buffers", "mlir::func::FuncOp"> { + let summary = "Shrink memref.alloc and tensor.empty based on consistent slicing usage."; + let description = [{ + Runs two rewrite patterns (ShrinkAllocWithSlicing and ShrinkEmptyTensorWithSlicing) + to reduce allocation/empty sizes when all slices agree on a smaller size. 
+ }]; + let constructor = "mlir::dicp::LinalgExt::createShrinkBuffersPass()"; +} + #endif diff --git a/compiler/include/dicp/TransformOps/CMakeLists.txt b/compiler/include/dicp/TransformOps/CMakeLists.txt new file mode 100644 index 00000000..c953afa4 --- /dev/null +++ b/compiler/include/dicp/TransformOps/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS DicpTransformOps.td) +mlir_tablegen(DicpTransformOps.h.inc -gen-op-decls) +mlir_tablegen(DicpTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(DICPTransformOpsIncGen) +add_dependencies(mlir-headers DICPTransformOpsIncGen) diff --git a/compiler/include/dicp/TransformOps/DicpTransformOps.h b/compiler/include/dicp/TransformOps/DicpTransformOps.h new file mode 100644 index 00000000..0652f88a --- /dev/null +++ b/compiler/include/dicp/TransformOps/DicpTransformOps.h @@ -0,0 +1,40 @@ +//===- DicpTransformOps.h - DICP transform ops -------------*- C++-*-===// +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef DICP_DICPTRANSFORMOPS_H +#define DICP_DICPTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/IR/TransformTypes.h" + +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// DICP Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "dicp/TransformOps/DicpTransformOps.h.inc" + +namespace mlir { +namespace dicp { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace dicp +} // namespace mlir + +#endif // DICP_DicpTransformOps_H diff --git a/compiler/include/dicp/TransformOps/DicpTransformOps.td b/compiler/include/dicp/TransformOps/DicpTransformOps.td new file mode 100644 index 00000000..6c9fdc71 --- /dev/null +++ b/compiler/include/dicp/TransformOps/DicpTransformOps.td @@ -0,0 +1,179 @@ +#ifndef DICP_DICPTRANSFORMOPS +#define DICP_DICPTRANSFORMOPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + + + +//===----------------------------------------------------------------------===// +// ReverseOp +//===----------------------------------------------------------------------===// + +def ReverseOp : Op, + TransformOpInterface]> { + let description = [{ + This transform op gets and reverses the list of operations held by the + input `target` handle. + + This transform reads the `target` handle and produces the `result` handle. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$result); + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + }]; + let assemblyFormat = [{ + $target attr-dict `:` functional-type($target, $result) + }]; +} + +def ForwardInitToIterArgOp : Op, + TransformOpInterface]> { + let description = [{ + Target must be a handle to `scf.for` operations. + + This transform analyzes the `scf.for` loop body. If it finds an `extract_slice` + using the loop's init_arg that is structurally equivalent to the `insert_slice` + yielded into the corresponding iter_arg, it replaces the `extract_slice` source + with the iter_arg. + + This enables In-Place updates and removes dependencies on the init_arg. + }]; + + // 输入是一个指向 scf.for 的 Handle + let arguments = (ins TransformHandleTypeInterface:$target); + // 输出是处理后的 scf.for Handle (通常还是原来的 Op,但内容变了) + let results = (outs TransformHandleTypeInterface:$result); + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + }]; + + let assemblyFormat = [{ + $target attr-dict `:` functional-type($target, $result) + }]; +} + +//===----------------------------------------------------------------------===// +// ExtendedFuseIntoContainingOp +//===----------------------------------------------------------------------===// +def ExtendedFuseIntoContainingOp : + Op, + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait, + AttrSizedResultSegments, + FunctionalStyleTransformOpTrait]> { + let summary = "Fuse a producer into a containing operation."; + + let description = [{ + Fuses the `producer_op` into the 
`containing_op`. + Returns a handle to the fused ops and the `new_containing_op`. + + The producer is typically a slice of a tileable op (i.e., implements + TilingInterface). In that case, this transform computes the accessed + producer slice inside of the containing op ("tile and fuse") and if required, + creates a new containing op with outputs from the fused producer. Otherwise, + the entire producer is cloned inside the containing op ("clone and fuse"). + + Each containing op handle must be associated with exactly one payload op. The + producer op handle may be associated with multiple payload ops. This + transform fuses producers one-by-one, always picking an unspecified producer + that has at least one use inside the containing op among the + producers. A producer can be listed multiple times in the handle. + + If the `producer_op` has uses that are post-dominated by the `containing_op`, + then it is fused into `containing_op` completely to avoid recomputation. + This behavior can be disabled by setting `duplicate_producer` to true. + + Note: If a producer has multiple uses inside the containing op, a union + of the requested regions is computed, and each consumer will only access the + region it needs via slicing. + + #### Return modes + + If at least one producer could not be fused, this operation produces a + silenceable failure. This is the case when tiling fails or when no + producer op could be found among the remaining producers that has at least + one use within the containing op. I.e., "producers" that are not consumed + within the containing op are rejected by this operation. + + This operation consumes the producer handle. + This operation only reads the containing op handle. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$producer_op, + Variadic:$containing_op, + DefaultValuedOptionalAttr:$duplicate_producer); + + let results = (outs Variadic:$fused_op, + Variadic:$new_containing_op); + + let builders = [ + OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure fuseIntoOneContaining( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &results, + ::mlir::transform::TransformState &state, + size_t index, + ::mlir::Operation* containingOp); + }]; + + let hasCustomAssemblyFormat = 1; +} + +def ExtendedLoopFuseSiblingOp : Op]> { + let summary = "Fuse a loop into another loop, assuming the fusion is legal."; + + let description = [{ + Fuses the `target` loop into the `source` loop assuming they are + independent of each other. In the fused loop, the arguments, body and + results of `target` are placed _before_ those of `source`. + + For fusion of two `scf.for` loops, the bounds and step size must match. For + fusion of two `scf.forall` loops, the bounds and the mapping must match. + Otherwise a silencable failure is produced. + + The `target` and `source` handles must refer to exactly one operation, + otherwise a definite failure is produced. It is the responsibility of the + user to ensure that the `target` and `source` loops are independent of each + other -- this op will only perform rudimentary legality checks. + + #### Return modes + + This operation consumes the `target` and `source` handles and produces the + `fused_loop` handle, which points to the fused loop. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + TransformHandleTypeInterface:$source); + let results = (outs TransformHandleTypeInterface:$fused_loop); + let assemblyFormat = "$target `into` $source attr-dict " + " `:` functional-type(operands, results)"; +} + +#endif // DICP_DicpTransformOps diff --git a/compiler/include/dicp/TransformOps/Transforms.h b/compiler/include/dicp/TransformOps/Transforms.h new file mode 100644 index 00000000..9b36c4fc --- /dev/null +++ b/compiler/include/dicp/TransformOps/Transforms.h @@ -0,0 +1,66 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" + +#ifndef BISHENGIR_TRANSFORMS_TRANSFORMS_H +#define BISHENGIR_TRANSFORMS_TRANSFORMS_H + +#define DICP_STAGE_PREFIX "dicp.stage." 
+namespace mlir::dicp { + +static const llvm::StringLiteral kDicpStagePrefix = DICP_STAGE_PREFIX; +static const llvm::StringLiteral kStageOpToTileAttr = + DICP_STAGE_PREFIX "op_to_tile.stage_{0}_sub_{1}_u{2}_"; +static const llvm::StringLiteral kStageProducerToFuseAttr = + DICP_STAGE_PREFIX "producer_to_fuse.stage_{0}_sub_{1}_u{2}"; +static const llvm::StringLiteral kStageProducerAllocToFuseAttr = + DICP_STAGE_PREFIX "alloc_producer"; +static const llvm::StringLiteral kCrossTillUnitAttr = + DICP_STAGE_PREFIX "till_unit_has_cross_user"; +static const llvm::StringLiteral kHadFusedAttr = + DICP_STAGE_PREFIX "op_had_fused"; + +static const llvm::StringLiteral kNPUStageAttrName = "dicp.npu.stage"; +static const llvm::StringLiteral kOriginalOpNameAttr = "dicp.original_op_name"; + +void unionProducerUsers(mlir::RewriterBase &rewriter, mlir::Diagnostic &diag, + mlir::Operation *producerOp, + mlir::Operation *containingOp); +std::tuple, mlir::Operation *> + +tileAndFuseAllSubsetOps(mlir::RewriterBase &rewriter, mlir::Diagnostic &diag, + mlir::Operation *producerOp, + mlir::Operation *containingOp, bool duplicateProducer); + +llvm::SmallVector +tileAndFuseAllSubsetOpsThroughContainingOpBlockArgument( + mlir::RewriterBase &rewriter, mlir::Diagnostic &diag, + mlir::Operation *producerOp, LoopLikeOpInterface containingOp); + +mlir::Operation *cloneAndFuseAllSubsetOps(mlir::RewriterBase &rewriter, + mlir::Diagnostic &diag, + mlir::Operation *producerOp, + mlir::Operation *containingOp); + +/// Callback function type for generating transform dialect operations. +/// \param builder The OpBuilder to use. +/// \param loc The location for generated ops. +/// \param rootHandle The handle to the root operation (usually the module or +/// block). +using TransformGenerationCallback = + std::function; + +/// Shared utility to apply a unified transform sequence to a module. +class TransformApplier { +public: + /// Applies a transformation defined by the generator callback to the module. 
+ /// Uses a transactional approach (clones module) to ensure safety. + static void apply(ModuleOp module, TransformGenerationCallback generator); +}; + +} // namespace mlir::dicp + +#endif // BISHENGIR_TRANSFORMS_TRANSFORMS_H \ No newline at end of file diff --git a/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp b/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp index b62ba68d..f69ea288 100644 --- a/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp +++ b/compiler/lib/Conversion/LinalgToLinked/TritonOpConverter.cpp @@ -405,12 +405,6 @@ ScanConverter::convertToTargetOp(triton::ScanOp op, Value scanInput = op.getOperand(0); - scanInput.dump(); - - for (Value operand : op->getOperands()) { - operand.dump(); - } - auto srcType = mlir::dyn_cast(scanInput.getType()); if (!srcType) { return rewriter.notifyMatchFailure( diff --git a/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp b/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp index 3f69b378..36957c91 100644 --- a/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp +++ b/compiler/lib/Conversion/TritonToUnstructure/UnstructureConversionPass.cpp @@ -1,6 +1,8 @@ #include "dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h" #include "dicp/Utils/Utils.h" + #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "bishengir/Dialect/Annotation/IR/Annotation.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -680,7 +682,7 @@ void replacePtrLoopArguments(Operation *rootOp, op.getLoc(), rewriter.getI32Type(), ValueRange({})) ->getResult(0); if (auto forOp = dyn_cast(op.getOperation())) { - newOp = rewriter.create( + auto createdFor = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), constructOperands(forOp.getInitArgs(), tempVar, mapping), @@ -701,6 +703,13 @@ void replacePtrLoopArguments(Operation 
*rootOp, yieldOp.getLoc(), constructOperands(yieldOp.getOperands(), tempVar, mapping)); }); + + // propagate Triton-specific loop attribute if present on the old for + if (forOp->hasAttr(triton::kNumStagesAttrName)) + createdFor->setAttr(triton::kNumStagesAttrName, + forOp->getAttr(triton::kNumStagesAttrName)); + + newOp = createdFor; } else if (auto whileOp = dyn_cast(op.getOperation())) { newOp = rewriter.create( whileOp.getLoc(), constructTypes(whileOp->getResultTypes()), diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Analysis/CMakeLists.txt new file mode 100644 index 00000000..ce4bbc0b --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(LinalgExtAnalysis + DimAnalyzer.cpp + StageDependencyAnalyzer.cpp + StageUtils.cpp + + LINK_LIBS PUBLIC + + MLIRAffineDialect + MLIRArithDialect + MLIRDialectUtils + MLIRFuncDialect + MLIRLinalgDialect + MLIRLinalgUtils + MLIRMemRefDialect + MLIRPass + MLIRShapeDialect + MLIRTensorDialect + MLIRTensorUtils +) \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/DimAnalyzer.cpp b/compiler/lib/Dialect/LinalgExt/Analysis/DimAnalyzer.cpp new file mode 100644 index 00000000..618ed105 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/DimAnalyzer.cpp @@ -0,0 +1,646 @@ +#include "dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "npu-stage-dim-analyzer" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +std::string mlir::dicp::toString(DimKind k) 
{ + switch (k) { + case DimKind::Unknown: + return "Unknown"; + case DimKind::Parallel: + return "Parallel"; + case DimKind::Reduction: + return "Reduction"; + case DimKind::Broadcast: + return "Broadcast"; + case DimKind::Complex: + return "Complex"; + } + return "INVALID"; +} + +//===----------------------------------------------------------------------===// +// DimensionDisjointSet Implementation +//===----------------------------------------------------------------------===// + +int64_t DimensionDisjointSet::allocate(size_t n) { + size_t start = parent.size(); + resize(start + n); + LDBG(" [DSU] Allocated " << n << " new dims. Range: [" << start << ", " + << start + n - 1 << "]"); + return static_cast(start); +} + +int64_t DimensionDisjointSet::find(int64_t i) { + if (i < 0 || i >= (int64_t)parent.size()) + return -1; + // Path compression: Point directly to the root to speed up future lookups. + if (parent[i] == i) + return i; + return parent[i] = find(parent[i]); +} + +void DimensionDisjointSet::unionSets(int64_t i, int64_t j) { + int64_t rootI = find(i); + int64_t rootJ = find(j); + if (rootI != -1 && rootJ != -1 && rootI != rootJ) { + DimKind kI = kind[rootI]; + DimKind kJ = kind[rootJ]; + + // Merge properties based on priority logic (e.g., Reduction takes + // precedence). + DimKind mergedKind = mergeKinds(kI, kJ); + + // Union by attaching I to J (could be optimized with rank/size). + parent[rootI] = rootJ; + kind[rootJ] = mergedKind; + + LDBG(" [DSU] Union(ID:" << i << " [" << toString(kI) << "] -> ID:" << j + << " [" << toString(kJ) + << "]) => Merged Kind: " << toString(mergedKind)); + } +} + +void DimensionDisjointSet::setKind(int64_t i, DimKind k) { + int64_t root = find(i); + if (root != -1) { + DimKind oldK = kind[root]; + // Update the kind, ensuring we don't downgrade a strong property (like + // Reduction). 
+ kind[root] = mergeKinds(kind[root], k); + if (oldK != kind[root]) { + LDBG(" [DSU] SetKind ID:" << i << " (Root:" << root << ") changed from " + << toString(oldK) << " to " + << toString(kind[root])); + } + } +} + +DimKind DimensionDisjointSet::getKind(int64_t i) { + int64_t root = find(i); + return (root != -1) ? kind[root] : DimKind::Unknown; +} + +void DimensionDisjointSet::resize(size_t n) { + size_t oldSize = parent.size(); + if (n > oldSize) { + parent.resize(n); + // Initialize new elements to point to themselves (roots) with Unknown kind. + std::iota(parent.begin() + oldSize, parent.end(), oldSize); + kind.resize(n, DimKind::Unknown); + } +} + +DimKind DimensionDisjointSet::mergeKinds(DimKind a, DimKind b) { + if (a == b) + return a; + // Complex is the strongest property: if a dimension is complex anywhere, it's + // complex everywhere. + if (a == DimKind::Complex || b == DimKind::Complex) + return DimKind::Complex; + // Reduction is stronger than Parallel/Broadcast: forces serialization/atomic + // handling. + if (a == DimKind::Reduction || b == DimKind::Reduction) + return DimKind::Reduction; + // Broadcast + Parallel is treated as Parallel for tiling purposes. + // (Tiling a broadcasted loop is valid and often efficient). + if ((a == DimKind::Broadcast && b == DimKind::Parallel) || + (a == DimKind::Parallel && b == DimKind::Broadcast)) + return DimKind::Parallel; + // If one is Unknown, take the known one. + return (a != DimKind::Unknown) ? a : b; +} + +//===----------------------------------------------------------------------===// +// DimAnalyzer Implementation +//===----------------------------------------------------------------------===// + +DimAnalyzer::DimAnalyzer(const StageInfo &stage) : stage_(stage) { + // Populate the set for fast O(1) membership checks during traversal. 
+ for (auto *op : stage_.ops) { + stageOps_.insert(op); + } +} + +std::vector DimAnalyzer::getOrAllocateDims(Value v) { + if (valueDims_.count(v)) + return valueDims_[v]; + + auto type = dyn_cast(v.getType()); + if (!type || !type.hasRank()) { + LDBG(" [Warn] Skipping unranked/non-shaped value: " << v); + return {}; + } + + int64_t rank = type.getRank(); + int64_t startId = dsu_.allocate(rank); + std::vector dims(rank); + std::iota(dims.begin(), dims.end(), startId); + + // Default assumption: Dimensions are Parallel unless proven otherwise. + // This helps when operations (like elementwise) don't impose constraints. + for (auto id : dims) + dsu_.setKind(id, DimKind::Parallel); + + valueDims_[v] = dims; + return dims; +} + +void DimAnalyzer::bindDimensions(Value v1, Value v2) { + auto d1 = getOrAllocateDims(v1); + auto d2 = getOrAllocateDims(v2); + if (d1.empty() || d2.empty()) + return; + + if (d1.size() != d2.size()) { + LDBG(" [Warn] Rank mismatch binding " << v1 << " <-> " << v2); + return; + } + // 1-to-1 binding of dimensions (e.g., for Copy, Cast, or Elementwise). + for (size_t i = 0; i < d1.size(); ++i) { + dsu_.unionSets(d1[i], d2[i]); + } +} + +SmallVector DimAnalyzer::analyzeAndGetTilingDims() { + LDBG("\n>>> [Analysis] Starting Analysis for Stage ID: " << stage_.id); + // 1. Build the constraint graph via BFS traversal. + processBFS(); + + // 2. Identify Anchor Op. + // Heuristic: The last LinalgOp in the stage is usually the "Compute" or + // "Write" op. Tiling decisions should be based on this op's loop structure. + linalg::LinalgOp anchorOp; + for (auto it = stage_.ops.rbegin(); it != stage_.ops.rend(); ++it) { + if (auto op = dyn_cast(*it)) { + anchorOp = op; + break; + } + } + + if (!anchorOp) { + LDBG(">>> [Analysis] No LinalgOp anchor found. Tiling unknown."); + return {}; + } + + LDBG(">>> [Analysis] Anchor Op: " << anchorOp->getName()); + + // 3. Map Anchor Loops to Global Dimension IDs. 
+ SmallVector chosenLoops; + auto iterTypes = anchorOp.getIteratorTypesArray(); + auto maps = anchorOp.getIndexingMapsArray(); + std::vector loopToDSU(iterTypes.size(), -1); + + // Iterate over operands to find which Value Dimension corresponds to which + // Loop. + auto operands = anchorOp->getOperands(); + int mapIdx = 0; + for (auto val : operands) { + if (mapIdx >= (int)maps.size()) + break; + if (!isa(val.getType())) { + mapIdx++; + continue; + } + + auto valDims = getOrAllocateDims(val); + AffineMap map = maps[mapIdx++]; + + // Analyze the AffineMap: (d0, d1) -> (d0, d1) + // If result[i] is a simple DimExpr(d_k), then Loop k corresponds to Value + // Dim i. + for (unsigned dimIdx = 0; dimIdx < map.getNumResults(); ++dimIdx) { + if (dimIdx >= valDims.size()) + continue; + if (auto dimExpr = dyn_cast(map.getResult(dimIdx))) { + unsigned loopPos = dimExpr.getPosition(); + if (loopPos < loopToDSU.size()) { + // Link the loop to the global DSU ID of the operand dimension. + loopToDSU[loopPos] = valDims[dimIdx]; + } + } + } + } + + // 4. Evaluate Loops for Tiling. + LDBG(">>> [Analysis] Loop Classification:"); + for (size_t i = 0; i < loopToDSU.size(); ++i) { + DimKind k = DimKind::Unknown; + if (loopToDSU[i] != -1) { + // Get the global property from DSU (propagated from all ops in the + // stage). + k = dsu_.getKind(loopToDSU[i]); + } else { + // Fallback: If loop isn't linked to any data dimension (rare), rely on + // local iterator type. + if (linalg::isReductionIterator(iterTypes[i])) + k = DimKind::Reduction; + else if (linalg::isParallelIterator(iterTypes[i])) + k = DimKind::Parallel; + } + + LDBG(" Loop " << i << ": " << toString(k)); + + // Policy: We only auto-tile global Parallel loops. + // (Future work: support Tiling Reduction if atomic updates are supported). 
+ if (k == DimKind::Parallel) { + chosenLoops.push_back(i); + } + } + return chosenLoops; +} + +void DimAnalyzer::processBFS() { + BFSQueue bfsQueue; + VisitedSet visited; + DenseSet definedInStage; + + // Identify all values defined within the stage to find boundary inputs. + for (auto *op : stage_.ops) + for (auto res : op->getResults()) + definedInStage.insert(res); + + // 1. Seeds: Operands used in stage but defined externally (Inputs). + for (auto *op : stage_.ops) { + for (auto operand : op->getOperands()) { + if (!definedInStage.contains(operand)) { + if (visited.insert(operand).second) { + bfsQueue.push(operand); + getOrAllocateDims(operand); // Pre-allocate IDs for inputs. + } + } + } + } + + // 2. Seeds: Internal roots (Fallback). + // If the graph is fully internal or disconnected, start from the first op. + if (bfsQueue.empty() && !stage_.ops.empty()) { + for (auto res : stage_.ops[0]->getResults()) { + bfsQueue.push(res); + visited.insert(res); + } + } + + // Standard BFS Traversal + while (!bfsQueue.empty()) { + Value current = bfsQueue.front(); + bfsQueue.pop(); + + for (Operation *user : current.getUsers()) { + // Only process users that are part of the current stage. + if (!stageOps_.contains(user)) + continue; + + // Dispatch processing to specific Op handler. + // This establishes constraints between 'current' and 'user's results. + processOperation(user, current, bfsQueue, visited); + + // Enqueue results for downstream propagation. + for (Value result : user->getResults()) { + if (visited.insert(result).second) { + bfsQueue.push(result); + getOrAllocateDims(result); + } + } + } + } +} + +bool DimAnalyzer::processOperation(Operation *op, Value current, BFSQueue &q, + VisitedSet &v) { + // Dispatcher: Directs operation to the specific semantic handler. 
+ if (auto matmulOp = dyn_cast(op)) + processMatmulOp(matmulOp); + else if (auto reduceOp = dyn_cast(op)) + processReduceOp(reduceOp); + else if (auto transOp = dyn_cast(op)) + processTransposeOp(transOp); + else if (auto bcastOp = dyn_cast(op)) + processBroadcastOp(bcastOp); + else if (auto linalgOp = dyn_cast(op)) + processLinalgOpGeneric(linalgOp); + + // Tensor manipulation ops + else if (auto castOp = dyn_cast(op)) + bindDimensions(castOp.getSource(), castOp.getDest()); + else if (isa(op)) + processReshapeOp(op); + else if (auto concatOp = dyn_cast(op)) + processConcatOp(concatOp); + else if (auto padOp = dyn_cast(op)) + processPadOp(padOp); + else if (auto extSlice = dyn_cast(op)) + processExtractSliceOp(extSlice); + else if (auto insSlice = dyn_cast(op)) + processInsertSliceOp(insSlice); + + // Bufferization & MemRef ops + else if (auto copyOp = dyn_cast(op)) + processMemrefCopyOp(copyOp, current, q, v); + else if (isa(op)) + processMemrefCastOp(op); + else if (auto toTensor = dyn_cast(op)) + processBufferizationToTensor(toTensor); + else if (auto matOp = dyn_cast(op)) + processMaterializeOp(matOp, current, q, v); + + // Elementwise ops (Arith, Math) + else if (isa(op->getDialect())) + processElementwise(op, current); + else { + // Default fallback: assume 1-to-1 preservation if results exist. + if (op->getNumResults() > 0) + bindDimensions(current, op->getResult(0)); + } + return true; +} + +//===----------------------------------------------------------------------===// +// Specific Handlers +//===----------------------------------------------------------------------===// + +void DimAnalyzer::processMemrefCopyOp(memref::CopyOp op, Value current, + BFSQueue &q, VisitedSet &v) { + LDBG(" [Op] Processing MemRef Copy"); + Value src = op.getSource(); + Value dst = op.getTarget(); + bindDimensions(src, dst); + + // Special Case: Copy sends data to 'dst', which is an operand (outs), not a + // result. We must explicitly enqueue 'dst' to continue BFS. 
+ if (current == src) { + if (v.insert(dst).second) { + q.push(dst); + getOrAllocateDims(dst); + LDBG(" -> Enqueued Copy Destination: " << dst); + } + } +} + +void DimAnalyzer::processMaterializeOp( + bufferization::MaterializeInDestinationOp op, Value current, BFSQueue &q, + VisitedSet &v) { + LDBG(" [Op] Processing MaterializeInDestination"); + Value src = op.getSource(); + Value dst = op.getDest(); + bindDimensions(src, dst); + + // Similar to Copy: Propagate to destination buffer. + if (current == src) { + if (v.insert(dst).second) { + q.push(dst); + getOrAllocateDims(dst); + LDBG(" -> Enqueued Materialize Destination: " << dst); + } + } +} + +void DimAnalyzer::processMemrefCastOp(Operation *op) { + LDBG(" [Op] Processing MemRef Cast/Reinterpret"); + Value src = op->getOperand(0); + Value dst = op->getResult(0); + + auto srcType = dyn_cast(src.getType()); + auto dstType = dyn_cast(dst.getType()); + + if (srcType && srcType.hasRank() && dstType && dstType.hasRank()) { + if (srcType.getRank() == dstType.getRank()) { + bindDimensions(src, dst); + } else { + // Rank changing casts (e.g. collapse/expand via reinterpret) break strict + // 1-to-1 binding. We treat dst dims as new/separate. + LDBG(" Rank change detected, breaking strict binding."); + getOrAllocateDims(dst); + } + } else { + getOrAllocateDims(dst); + } +} + +void DimAnalyzer::processBufferizationToTensor(bufferization::ToTensorOp op) { + LDBG(" [Op] Processing ToTensor"); + // Converts MemRef to Tensor. Dimensions are strictly preserved. 
+ Value memrefValue = op.getOperand(); + Value tensorResult = op.getResult(); + bindDimensions(memrefValue, tensorResult); +} + +void DimAnalyzer::processTransposeOp(linalg::TransposeOp op) { + LDBG(" [Op] Processing TransposeOp"); + Value input = op.getInput(); + Value result = op.getResult()[0]; + auto perm = op.getPermutation(); + + auto inputDims = getOrAllocateDims(input); + auto resDims = getOrAllocateDims(result); + + if (inputDims.empty() || resDims.empty()) + return; + + // Bind Input[Perm[i]] <-> Result[i] + for (size_t i = 0; i < perm.size(); ++i) { + int64_t srcIdx = perm[i]; + if (srcIdx < (int)inputDims.size() && i < resDims.size()) { + dsu_.unionSets(inputDims[srcIdx], resDims[i]); + } + } +} + +void DimAnalyzer::processMatmulOp(linalg::MatmulOp op) { + LDBG(" [Op] Processing MatmulOp"); + // Standard Matmul: [M, K] * [K, N] -> [M, N] + Value lhs = op.getInputs()[0]; + Value rhs = op.getInputs()[1]; + Value out = op.getResults()[0]; + + auto lhsDims = getOrAllocateDims(lhs); + auto rhsDims = getOrAllocateDims(rhs); + auto outDims = getOrAllocateDims(out); + + // Allocate implicit loops for M, N, K and set their properties. + int64_t loopM = dsu_.allocate(1); + int64_t loopN = dsu_.allocate(1); + int64_t loopK = dsu_.allocate(1); + dsu_.setKind(loopM, DimKind::Parallel); + dsu_.setKind(loopN, DimKind::Parallel); + dsu_.setKind(loopK, DimKind::Reduction); + + // Bind operand dimensions to these loops. 
+ // Assumes standard layout: LHS=[..., M, K], RHS=[..., K, N], Out=[..., M, N] + if (lhsDims.size() >= 2 && rhsDims.size() >= 2 && outDims.size() >= 2) { + dsu_.unionSets(lhsDims[lhsDims.size() - 2], loopM); + dsu_.unionSets(lhsDims[lhsDims.size() - 1], loopK); + dsu_.unionSets(rhsDims[rhsDims.size() - 2], loopK); + dsu_.unionSets(rhsDims[rhsDims.size() - 1], loopN); + dsu_.unionSets(outDims[outDims.size() - 2], loopM); + dsu_.unionSets(outDims[outDims.size() - 1], loopN); + } +} + +void DimAnalyzer::processReduceOp(linalg::ReduceOp op) { + LDBG(" [Op] Processing ReduceOp"); + Value input = op.getInputs()[0]; + Value output = op.getResults()[0]; + auto inputDims = getOrAllocateDims(input); + auto outputDims = getOrAllocateDims(output); + auto reduceIndices = op.getDimensions(); + std::set reduceSet(reduceIndices.begin(), reduceIndices.end()); + + int outIdx = 0; + for (size_t i = 0; i < inputDims.size(); ++i) { + if (reduceSet.count(i)) { + // Input dimension is being reduced -> Mark as Reduction. + dsu_.setKind(inputDims[i], DimKind::Reduction); + } else if (outIdx < (int)outputDims.size()) { + // Input dimension is preserved -> Bind to Output dimension. + dsu_.unionSets(inputDims[i], outputDims[outIdx++]); + } + } +} + +void DimAnalyzer::processBroadcastOp(linalg::BroadcastOp op) { + LDBG(" [Op] Processing BroadcastOp"); + auto inDims = getOrAllocateDims(op.getInput()); + auto resDims = getOrAllocateDims(op.getResult()[0]); + auto broadcastIndices = op.getDimensions(); + std::set bcastSet(broadcastIndices.begin(), broadcastIndices.end()); + + int inIdx = 0; + for (size_t i = 0; i < resDims.size(); ++i) { + if (bcastSet.count(i)) { + // New dimension added by broadcast -> Mark as Broadcast. + dsu_.setKind(resDims[i], DimKind::Broadcast); + } else if (inIdx < (int)inDims.size()) { + // Existing dimension -> Bind to input. 
+ dsu_.unionSets(resDims[i], inDims[inIdx++]); + } + } +} + +void DimAnalyzer::processReshapeOp(Operation *op) { + LDBG(" [Op] Processing Reshape"); + bool isExpand = isa(op); + auto srcDims = getOrAllocateDims(op->getOperand(0)); + auto dstDims = getOrAllocateDims(op->getResult(0)); + + SmallVector indices; + if (isExpand) { + indices = cast(op).getReassociationIndices(); + } else { + indices = cast(op).getReassociationIndices(); + } + + // Map between Collapsed (1 dim) and Expanded (N dims). + auto &collapsed = isExpand ? srcDims : dstDims; + auto &expanded = isExpand ? dstDims : srcDims; + + if (indices.size() != collapsed.size()) + return; + + // Bind the single collapsed dimension to ALL corresponding expanded + // dimensions. This is a conservative approach: it effectively groups them all + // into one equivalence class. + for (size_t i = 0; i < indices.size(); ++i) { + int64_t colID = collapsed[i]; + for (int64_t expIdx : indices[i]) { + if (expIdx < (int64_t)expanded.size()) + dsu_.unionSets(colID, expanded[expIdx]); + } + } +} + +void DimAnalyzer::processElementwise(Operation *op, Value current) { + LDBG(" [Op] Processing Elementwise"); + // Elementwise ops (Add, Sub, etc.) strictly preserve shape. + // Bind input dimensions to result dimensions 1-to-1. + if (op->getNumResults() > 0) + bindDimensions(current, op->getResult(0)); +} + +void DimAnalyzer::processConcatOp(tensor::ConcatOp op) { + // Concat preserves all dimensions except the concatenation axis. + // Even on the concat axis, the logical meaning of the dimension usually + // matches (e.g., stacking Batches). We bind inputs to output 1-to-1. + Value result = op.getResult(); + for (Value operand : op.getOperands()) + bindDimensions(operand, result); +} + +void DimAnalyzer::processPadOp(tensor::PadOp op) { + // Padding extends the size but preserves the logical axis. 
+ bindDimensions(op.getSource(), op.getResult()); +} + +void DimAnalyzer::processExtractSliceOp(tensor::ExtractSliceOp op) { + auto srcDims = getOrAllocateDims(op.getSource()); + auto dstDims = getOrAllocateDims(op.getResult()); + auto dropped = op.getDroppedDims(); + int dstIdx = 0; + for (size_t i = 0; i < srcDims.size(); ++i) { + // If dimension is NOT dropped (rank-reduced), bind it to the next output + // dimension. + if (!dropped.test(i) && dstIdx < (int)dstDims.size()) { + dsu_.unionSets(srcDims[i], dstDims[dstIdx++]); + } + // Dropped dimensions are effectively ignored for tiling propagation of the + // result. + } +} + +void DimAnalyzer::processInsertSliceOp(tensor::InsertSliceOp op) { + // InsertSlice modifies 'Dest'. The Result shape matches 'Dest'. + bindDimensions(op.getDest(), op.getResult()); +} + +void DimAnalyzer::processLinalgOpGeneric(linalg::LinalgOp op) { + LDBG(" [Op] Processing Generic: " << op->getName()); + auto maps = op.getIndexingMapsArray(); + auto iterTypes = op.getIteratorTypesArray(); + + // Allocate IDs for the op's loop iterators. + int64_t loopStart = dsu_.allocate(op.getNumLoops()); + + // Set properties based on iterator types (Parallel vs Reduction). + for (int i = 0; i < (int)iterTypes.size(); ++i) { + DimKind k = linalg::isReductionIterator(iterTypes[i]) ? DimKind::Reduction + : DimKind::Parallel; + dsu_.setKind(loopStart + i, k); + } + + // Bind Operands to Loops using AffineMaps. + auto operands = op->getOperands(); + int mapIdx = 0; + for (auto val : operands) { + if (mapIdx >= (int)maps.size()) + break; + if (!isa(val.getType())) { + mapIdx++; + continue; + } + + AffineMap map = maps[mapIdx++]; + auto valDims = getOrAllocateDims(val); + + // If map is (d0, d1) -> (d0, d1), bind ValDim[0] to Loop[0], etc. 
+ for (unsigned d = 0; d < map.getNumResults(); ++d) { + if (d >= valDims.size()) + continue; + if (auto dimExpr = dyn_cast(map.getResult(d))) { + dsu_.unionSets(valDims[d], loopStart + dimExpr.getPosition()); + } + } + } +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.cpp b/compiler/lib/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.cpp new file mode 100644 index 00000000..04e3adbb --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.cpp @@ -0,0 +1,295 @@ +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" + +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "npu-stage-dep-analyzer" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; + +// Helper to check if an operation is a Matmul (Cube unit op) +static bool isCubeOp(Operation *op) { + // We check if the operation name contains "matmul". + // This covers linalg.matmul, linalg.batch_matmul, or vendor specific ops like + // npu.matmul / aclnn.matmul. 
+ StringRef opName = op->getName().getStringRef(); + return opName.contains_insensitive("matmul"); +} + +static std::string getStageTypeStr(StageType type) { + switch (type) { + case StageType::Vector: + return "Vector"; + case StageType::Cube: + return "Cube"; + default: + return "Unknown"; + } +} + +FailureOr> StageDependencyAnalyzer::collectStages() { + LDBG(">>> [Analysis] Collecting Stages..."); + std::vector collectedStages; + StageInfo currentStage; + currentStage.id = 0; + // Default type is Vector, upgrades to Cube if a matmul is found + currentStage.type = StageType::Vector; + + for (Operation &op : block->without_terminator()) { + // If the current operation is a SyncBlockWaitOp, it marks the start of a + // new stage. We complete the current stage (if it's not empty) and start a + // new one. The SyncBlockWaitOp will become the first operation of the new + // stage. + if (isa(op)) { + if (!currentStage.ops.empty()) { + collectedStages.push_back(currentStage); + currentStage = StageInfo(); + currentStage.id = collectedStages.size(); + currentStage.type = StageType::Vector; // Reset type for new stage + } + } + + currentStage.ops.push_back(&op); + + // Mark the stage if it contains a sync wait operation + if (isa(op)) { + currentStage.hasSync = true; + } + + // Check if this op upgrades the stage to a Cube stage + if (isCubeOp(&op)) { + currentStage.type = StageType::Cube; + } + } + + if (!currentStage.ops.empty()) { + collectedStages.push_back(currentStage); + } + + LDBG("Collected " << collectedStages.size() << " stages."); + + // Debug: dump the ops contained in each stage (print full op IR). + LLVM_DEBUG({ + llvm::dbgs() << "[" DEBUG_TYPE "] Detailed stage contents:\n"; + for (const auto &stage : collectedStages) { + llvm::dbgs() << "[" DEBUG_TYPE << "] Stage " << stage.id + << " [Type: " << getStageTypeStr(stage.type) << "]" + << (stage.hasSync ? 
" (hasSync)" : "") + << " - ops: " << stage.ops.size() << "\n"; + for (Operation *op : stage.ops) { + llvm::dbgs() << " - "; + op->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + } + } + }); + + return collectedStages; +} + +void StageDependencyAnalyzer::collectEffects() { + for (auto &node : nodes) { + for (Operation *op : node.stageInfo->ops) { + // 1. SSA Def-Use (Produced Values) + for (Value res : op->getResults()) { + node.producedValues.insert(res); + } + // 1. SSA Def-Use (Consumed Values) + for (Value operand : op->getOperands()) { + // We only care about operands defined within the loop (not block args + // or invariant) + if (auto defOp = operand.getDefiningOp()) { + if (defOp->getBlock() == block) { + node.consumedValues.insert(operand); + } + } + } + + // 2. Memory Effects + if (auto memEffect = dyn_cast(op)) { + SmallVector> effects; + memEffect.getEffects(effects); + for (auto &effect : effects) { + Value val = effect.getValue(); + if (!val) + continue; + if (isa(effect.getEffect())) + node.writeValues.insert(val); + else if (isa(effect.getEffect())) + node.readValues.insert(val); + } + continue; + } + // Explicit handling for ops not implementing MemoryEffects but having + // semantics + if (auto matOp = + dyn_cast(op)) { + node.readValues.insert(matOp.getSource()); + node.writeValues.insert(matOp.getDest()); + } else if (auto copyOp = dyn_cast(op)) { + node.readValues.insert(copyOp.getSource()); + node.writeValues.insert(copyOp.getTarget()); + } + } + } +} + +void StageDependencyAnalyzer::buildDependencyGraph() { + LDBG(">>> [Analysis] Building Dependency Graph..."); + for (int i = 0; i < nodes.size(); ++i) { + for (int j = 0; j < nodes.size(); ++j) { + if (i == j) + continue; + bool hasDependency = false; + + // 1. Check SSA Dependencies (Direct Data Flow) + // If Stage J consumes a value produced by Stage I, J depends on I. 
+ for (Value consumed : nodes[j].consumedValues) { + if (nodes[i].producedValues.count(consumed)) { + hasDependency = true; + break; + } + } + + // 2. Check Memory Dependencies + if (!hasDependency) { + for (Value writeVal : nodes[i].writeValues) { + for (Value readVal : nodes[j].readValues) { + AliasResult result = aliasAnalysis.alias(writeVal, readVal); + if (result.isMust() || result.isPartial()) { + hasDependency = true; + LDBG(" MEM DEPENDENCY: Stage " << i << " -> Stage " << j); + break; + } + } + if (hasDependency) + break; + } + } + + if (hasDependency) { + nodes[i].stageInfo->succs.insert(j); + nodes[j].stageInfo->preds.insert(i); + } + } + } +} + +LogicalResult StageDependencyAnalyzer::computeStageLevels() { + std::vector visitState(nodes.size(), + 0); // 0: unvisited, 1: visiting, 2: visited + std::function dfs = [&](int u) -> LogicalResult { + visitState[u] = 1; + int maxPredLevel = -1; + for (int v : nodes[u].stageInfo->preds) { + if (visitState[v] == 1) { + llvm::errs() << "Error: Cycle detected involving stages " << u + << " and " << v << "\n"; + return failure(); + } + if (visitState[v] == 0) { + if (failed(dfs(v))) + return failure(); + } + if (nodes[v].level > maxPredLevel) + maxPredLevel = nodes[v].level; + } + nodes[u].level = maxPredLevel + 1; + visitState[u] = 2; + return success(); + }; + + for (int i = 0; i < nodes.size(); ++i) { + if (visitState[i] == 0) { + if (failed(dfs(i))) + return failure(); + } + } + return success(); +} + +void StageDependencyAnalyzer::reorderStagesLogical() { + std::vector sortedNodes = nodes; + std::stable_sort(sortedNodes.begin(), sortedNodes.end(), + [](const StageNode &a, const StageNode &b) { + if (a.level != b.level) + return a.level < b.level; + return a.id < b.id; + }); + std::vector newStages; + newStages.reserve(stages.size()); + LDBG(">>> [Analysis] Reordered Stages (Logical Order):"); + for (const auto &node : sortedNodes) { + LDBG(" Stage ID: " << node.id << ", Level: " << node.level + << ", Type: " << 
getStageTypeStr(node.stageInfo->type)); + newStages.push_back(*node.stageInfo); + } + stages = std::move(newStages); +} + +void StageDependencyAnalyzer::materializeScheduleToIR() { + LDBG(">>> [Analysis] Materializing Schedule to IR (Physical Move)..."); + Operation *terminator = block->getTerminator(); + for (const auto &stage : stages) { + for (Operation *op : stage.ops) { + if (op == terminator) + continue; + op->moveBefore(terminator); + } + } +} + +FailureOr> +StageDependencyAnalyzer::runAndReorder(RewriterBase &rewriter) { + // 1. Collect Stages + auto stagesOrErr = collectStages(); + if (failed(stagesOrErr)) + return failure(); + + // Move collected stages to member variable + stages = std::move(*stagesOrErr); + + // Initialize graph nodes. + // Note: We do this here instead of in collectStages because 'nodes' stores + // pointers to elements of 'stages'. We must ensure 'stages' is in its final + // location (member variable) before taking addresses. + nodes.resize(stages.size()); + for (size_t i = 0; i < stages.size(); ++i) { + nodes[i].id = i; + nodes[i].stageInfo = &stages[i]; + } + + // 2. Run Analysis + collectEffects(); + buildDependencyGraph(); + if (failed(computeStageLevels())) + return failure(); + + // 3. 
Reorder + reorderStagesLogical(); + + LDBG(">>> [Result] Final Stage Dependency Summary:"); + LLVM_DEBUG(for (const auto &stage + : stages) { + llvm::dbgs() << "[" DEBUG_TYPE "] Stage " << stage.id << " (" + << getStageTypeStr(stage.type) << "):\n"; + llvm::dbgs() << " Predecessors (Depends on): { stage: "; + for (int p : stage.preds) + llvm::dbgs() << p << " "; + llvm::dbgs() << "}\n"; + llvm::dbgs() << " Successors (Depended by): { stage: "; + for (int s : stage.succs) + llvm::dbgs() << s << " "; + llvm::dbgs() << "}\n"; + }); + + materializeScheduleToIR(); + return stages; +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Analysis/StageUtils.cpp b/compiler/lib/Dialect/LinalgExt/Analysis/StageUtils.cpp new file mode 100644 index 00000000..9e0f11df --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Analysis/StageUtils.cpp @@ -0,0 +1,558 @@ +#include "dicp/Dialect/LinalgExt/Analysis/StageUtils.h" +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +// External dialect dependency - assuming HIVM dialect is available +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + +#define DEBUG_TYPE "dicp-stage-utils" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace mlir::dicp; +using namespace mlir::dicp::LinalgExt; + 
+//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +bool mlir::dicp::LinalgExt::isOpInStage(Operation *op, int stageId) { + if (auto attr = op->getAttrOfType(kNPUStageAttrName)) { + return attr.getInt() == stageId; + } + return false; +} + +//===----------------------------------------------------------------------===// +// StagePartitioner Implementation +//===----------------------------------------------------------------------===// + +SmallVector +StagePartitioner::findBlocksWithHivmSyncOps(ModuleOp module) { + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE + "] Searching for blocks with HIVM sync ops...\n"); + + SetVector blockSet; + module.walk([&](Operation *op) { + // Using TypeSwitch for extensible sync op detection + llvm::TypeSwitch(op) + .Case( + [&](auto) { blockSet.insert(op->getBlock()); }); + }); + + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] Found " << blockSet.size() + << " blocks requiring stage partitioning.\n"); + return blockSet.takeVector(); +} + +LogicalResult StagePartitioner::analyzeAndTagBlock(Block *block, + MLIRContext *ctx, + bool &anyStageFound) { + Operation *parentOp = block->getParentOp(); + if (!parentOp) { + LLVM_DEBUG(llvm::dbgs() + << "[" DEBUG_TYPE "] Error: Block has no parent operation.\n"); + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE + "] Analyzing dependencies for block in parent: " + << parentOp->getName() << "\n"); + + // Perform Alias Analysis to ensure conservative dependency tracking + AliasAnalysis aliasAnalysis(parentOp); + StageDependencyAnalyzer analyzer(block, aliasAnalysis); + + FailureOr> stagesOrErr = analyzer.collectStages(); + if (failed(stagesOrErr)) { + return parentOp->emitError( + "Dependency cycle or analysis failure detected in stage analysis."); + } + + anyStageFound = false; + unsigned tagCount = 0; + + for (const auto &stage : *stagesOrErr) { 
+ // Only tag operations belonging to Vector-type stages + if (stage.type == StageType::Vector) { + IntegerAttr attr = IntegerAttr::get(IntegerType::get(ctx, 32), stage.id); + for (Operation *op : stage.ops) { + op->setAttr(kNPUStageAttrName, attr); + tagCount++; + } + anyStageFound = true; + } + } + + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] Successfully tagged " << tagCount + << " operations across stages.\n"); + return success(); +} + +SetVector StagePartitioner::getStageIdsInBlock(Block *block) { + SetVector ids; + if (!block) + return ids; + + for (Operation &op : *block) { + if (auto attr = op.getAttrOfType(kNPUStageAttrName)) { + ids.insert(static_cast(attr.getInt())); + } + } + return ids; +} + +Stage StagePartitioner::partition(Block *block, int stageId) { + Stage stage; + stage.id = stageId; + + SmallVector currentOps; + unsigned subIdx = 0; + + LLVM_DEBUG( + llvm::dbgs() << "[" DEBUG_TYPE + "] Partitioning block into substages for Stage ID: " + << stageId << "\n"); + + auto finalizeSubStage = [&](SmallVectorImpl &ops) { + if (ops.empty()) + return; + + SubStage subStage; + subStage.index = subIdx++; + subStage.stageId = stageId; + subStage.ops = std::move(ops); + stage.subStages.push_back(std::move(subStage)); + ops.clear(); + + LLVM_DEBUG(llvm::dbgs() + << " -> Created SubStage " << subStage.index << " with " + << stage.subStages.back().ops.size() << " ops.\n"); + }; + + for (Operation &op : *block) { + // Filter by stage ID + if (!isOpInStage(&op, stageId)) + continue; + + // Synchronization operations act as hard boundaries for sub-stages. + // This prevents hoisting/sinking across sync points during later + // scheduling. + bool isSync = isa(op); + + if (isSync) { + finalizeSubStage(currentOps); + // Note: Sync ops themselves are currently excluded from substage op lists + // as they serve as delimiters. 
+ continue; + } + + currentOps.push_back(&op); + } + + // Handle trailing operations + finalizeSubStage(currentOps); + + LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] Stage " << stageId + << " partitioned into " << stage.subStages.size() + << " substages.\n"); + return stage; +} + +SmallVector StagePartitioner::getAllStagesInBlock(Block *block) { + SmallVector result; + SetVector stageIds = getStageIdsInBlock(block); + + for (int id : stageIds) { + result.push_back(partition(block, id)); + } + + return result; +} + +namespace { + +//===----------------------------------------------------------------------===// +// Helper Utilities +//===----------------------------------------------------------------------===// + +static bool isMatMul(Operation *op) { + return isa(op); +} + +static bool isCubeSatge(Operation *op) { + return isa(op->getDialect()); +} + +static bool isVectorOp(Operation *op) { + if (!isa(op->getDialect())) { + return false; + } + return llvm::any_of(op->getOperandTypes(), + [](Type t) { return isa(t); }); +} + +//===----------------------------------------------------------------------===// +// BlockSegmenter +//===----------------------------------------------------------------------===// + +class BlockSegmenter { +public: + explicit BlockSegmenter(Block &block) : block(block) {} + + LogicalResult run(SmallVectorImpl &stages); + +private: + struct OpInfo { + StageType type = StageType::Vector; + int64_t index = -1; + int64_t closestDistance = std::numeric_limits::max(); + }; + + LogicalResult validate(); + void indexOperations(); + void classify(); + void buildStages(SmallVectorImpl &stages); + void deduplicate(SmallVectorImpl &stages); + + // Classification Helpers + void processAnchor(Operation *anchor); + void propagateSlice(Operation *anchor, bool backward); + void tryClaim(Operation *op, int64_t anchorIndex); + + Block █ + DenseMap opInfoMap; +}; + +//===----------------------------------------------------------------------===// +// Validation 
Logic +//===----------------------------------------------------------------------===// + +LogicalResult BlockSegmenter::validate() { + LDBG("Validating block constraints..."); + + for (Operation &op : block) { + auto walkRes = op.walk([&](Operation *nested) -> WalkResult { + auto loop = dyn_cast(nested); + if (!loop) + return WalkResult::advance(); + + bool hasMatmul = false; + bool hasForbidden = false; + + loop->walk([&](Operation *inner) { + if (inner == loop) + return; + if (isMatMul(inner)) + hasMatmul = true; + else if (isVectorOp(inner)) + hasForbidden = true; + }); + + if (hasMatmul && hasForbidden) { + LDBG("ERROR: Loop mixes matmul and forbidden vector ops:\n" << *loop); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (walkRes.wasInterrupted()) + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Core Pipeline Implementation +//===----------------------------------------------------------------------===// + +void BlockSegmenter::indexOperations() { + LDBG("Indexing operations in block..."); + int64_t idx = 0; + for (Operation &op : block) { + auto &info = opInfoMap[&op]; + info.index = idx++; + LDBG(" Indexed op #" << info.index << " : " << op); + } +} + +void BlockSegmenter::classify() { + LDBG("Classifying Cube vs Vector operations..."); + for (Operation &op : block) { + if (isMatMul(&op)) { + LDBG("Found Cube anchor (matmul): " << op); + processAnchor(&op); + } + } +} + +void BlockSegmenter::processAnchor(Operation *anchor) { + OpInfo &info = opInfoMap[anchor]; + info.type = StageType::Cube; + info.closestDistance = 0; + + LDBG(" Anchor forced to Cube stage"); + + // Propagate Cube type through dialect-specific slices + propagateSlice(anchor, /*backward=*/true); + propagateSlice(anchor, /*backward=*/false); +} + +void BlockSegmenter::propagateSlice(Operation *anchor, bool backward) { + SetVector slice; + auto filter = [&](Operation *op) 
{ + return op->getBlock() == &block && isCubeSatge(op); + }; + + if (backward) { + BackwardSliceOptions opt; + opt.filter = filter; + (void)getBackwardSlice(anchor, &slice, opt); + LDBG(" Backward slice size: " << slice.size()); + } else { + ForwardSliceOptions opt; + opt.filter = filter; + getForwardSlice(anchor, &slice, opt); + LDBG(" Forward slice size: " << slice.size()); + } + + int64_t anchorIdx = opInfoMap[anchor].index; + for (Operation *op : slice) + tryClaim(op, anchorIdx); +} + +void BlockSegmenter::tryClaim(Operation *op, int64_t anchorIndex) { + OpInfo &info = opInfoMap[op]; + int64_t dist = std::abs(info.index - anchorIndex); + + if (dist < info.closestDistance) { + LDBG(" Claiming op " << op->getName() << " as Cube (dist = " << dist + << ")"); + info.type = StageType::Cube; + info.closestDistance = dist; + } else { + LDBG(" Skipping op " << op->getName() << " (closer anchor exists)"); + } +} + +void BlockSegmenter::buildStages(SmallVectorImpl &stages) { + LDBG("Building execution stages..."); + int stageId = 0; + + auto ensureStage = [&](StageType type) { + if (stages.empty() || stages.back().type != type) { + stages.emplace_back(stageId++, type); + LDBG(" Created new stage " + << stages.back().id + << " type = " << (type == StageType::Cube ? "Cube" : "Vector")); + } + }; + + for (Operation &op : block) { + OpInfo &info = opInfoMap[&op]; + + // Final sanity check for MatMul classification + if (isMatMul(&op) && info.type != StageType::Cube) { + LDBG(" WARNING: Matmul not classified as Cube. 
Forcing."); + info.type = StageType::Cube; + } + + ensureStage(info.type); + stages.back().addOp(&op); + LDBG(" Added op " << op << " to stage " << stages.back().id); + } +} + +void BlockSegmenter::deduplicate(SmallVectorImpl &stages) { + LDBG("Running cross-stage deduplication..."); + + // Track which stage (index in 'stages' vector) owns which operation + DenseMap ownerStage; + + for (int i = 0, e = stages.size(); i < e; ++i) { + for (auto &sub : stages[i].subStages) { + for (Operation *op : sub.ops) { + if (!ownerStage.count(op)) { + ownerStage[op] = i; + } else { + int prevOwner = ownerStage[op]; + // Use the distance metric to resolve ownership if duplicated + if (opInfoMap[op].closestDistance < opInfoMap[op].closestDistance) { + ownerStage[op] = i; + } + } + } + } + } + + // Filter stages based on ownership + for (int i = 0, e = stages.size(); i < e; ++i) { + for (auto &sub : stages[i].subStages) { + SmallVector filtered; + for (Operation *op : sub.ops) { + if (ownerStage.lookup(op) == i) { + filtered.push_back(op); + } else { + LDBG(" Removing duplicated op " << op->getName() << " from stage " + << i); + } + } + sub.ops.swap(filtered); + } + } + + // Clean up empty stages + llvm::erase_if(stages, [](const Stage &s) { + bool isEmpty = + llvm::all_of(s.subStages, [](auto &sub) { return sub.ops.empty(); }); + if (isEmpty) + LDBG(" Removing empty stage " << s.id); + return isEmpty; + }); + + LDBG("Deduplication complete. Final stage count = " << stages.size()); +} + +LogicalResult BlockSegmenter::run(SmallVectorImpl &stages) { + LDBG("=== Starting Cube/Vector Segmentation (Block-level) ==="); + + if (failed(validate())) + return failure(); + + indexOperations(); + classify(); + buildStages(stages); + deduplicate(stages); + + LDBG("=== Segmentation Finished. 
Total stages = " << stages.size() << " ==="); + return success(); +} + +//===----------------------------------------------------------------------===// +// Target Block Search Helpers +//===----------------------------------------------------------------------===// + +/// Calculates the nesting depth of a block. +/// FuncBody = 0, scf.for loop body = 1, nested scf.for = 2, etc. +static unsigned getBlockNestingLevel(Block *block) { + unsigned level = 0; + Region *region = block->getParent(); + while (region) { + Operation *parentOp = region->getParentOp(); + if (!parentOp) + break; + + // If the parent operation is a Function, we consider this the top level + // (0). + if (isa(parentOp)) + return level; + + level++; + region = parentOp->getParentRegion(); + } + return level; +} + +} // namespace +//===----------------------------------------------------------------------===// +// CubeVectorSplitter Public API +//===----------------------------------------------------------------------===// + + +/** + * Identifies the primary IR Block to be processed based on the function's + * "mix_mode" attribute. + * * Selection Strategies: + * - "mix": Returns the Block with the highest density of anchor operations + * (MatMul or Vector ops). + * - "aiv"/"aic": Returns the outermost Block (minimum nesting) containing + * the mode-specific operations. + */ +Block *CubeVectorSplitter::findTargetBlock(func::FuncOp funcOp) { + // 1. Extract configuration with StringRef to avoid allocations. + StringRef mixMode = "mix"; + if (auto attr = funcOp->getAttrOfType("mix_mode")) + mixMode = attr.getValue(); + + // Define predicates for operation filtering. + auto isCube = [](Operation *op) { return isMatMul(op); }; + auto isVector = [](Operation *op) { return isVectorOp(op); }; + + // 2. State tracking using LLVM-optimized containers. 
+ DenseMap anchorCounts; + Block *bestBlock = nullptr; + unsigned maxAnchors = 0; + unsigned minLevel = std::numeric_limits::max(); + + bool isMixMode = (mixMode == "mix"); + bool isAIV = (mixMode == "aiv"); + + // 3. Single-pass IR traversal. + funcOp.walk([&](Operation *op) { + // Determine if the op is relevant to the current mode. + bool match = isMixMode ? (isCube(op) || isVector(op)) + : (isAIV ? isVector(op) : isCube(op)); + if (!match) + return; + + Block *currentBlock = op->getBlock(); + + if (isMixMode) { + // Strategy: Maximize operation density. + unsigned count = ++anchorCounts[currentBlock]; + if (count > maxAnchors) { + maxAnchors = count; + bestBlock = currentBlock; + } + } else { + // Strategy: Find the outermost scope (minimum nesting depth). + unsigned currentLevel = getBlockNestingLevel(currentBlock); + if (currentLevel < minLevel) { + minLevel = currentLevel; + bestBlock = currentBlock; + } + } + }); + + // 4. Diagnostics and Validation. + if (!bestBlock) { + funcOp.emitError() << "Target block discovery failed for mode: '" + << mixMode << "'. No matching operations found."; + return nullptr; + } + + // Optional: Warn if multiple blocks exist (Mix Mode), selecting the first found. + if (isMixMode && anchorCounts.size() > 1) { + llvm::dbgs() << "[Warning] Multiple candidate blocks found in mix mode. 
" + << "Selecting block with " << maxAnchors << " anchors.\n"; + } + bestBlock->dump(); + return bestBlock; +} + +//===----------------------------------------------------------------------===// +// Public API Implementation +//===----------------------------------------------------------------------===// + +LogicalResult CubeVectorSplitter::splitBlock(Block &block, + SmallVectorImpl &stages) { + BlockSegmenter segmenter(block); + return segmenter.run(stages); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 0b28548a..4b7fba0c 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -4,9 +4,17 @@ add_triton_library(LinalgExtTransforms ScalarTo1DTensorPass.cpp RemoveSingleIterationLoop.cpp TensorTransform.cpp + NPUUnroolPipeline.cpp + NPUVectorTileTransform.cpp + NPUTileLoopTagging.cpp + DeLinalgize.cpp + FuseLoop.cpp + LoopUnrollStage.cpp + ShrinkBuffers.cpp DEPENDS LinalgExtTransformsIncGen + LinalgExtAnalysis LINK_LIBS PUBLIC TritonTilingExtIR @@ -18,11 +26,21 @@ add_triton_library(LinalgExtTransforms MLIRTensorDialect MLIRTransforms MLIRSupport + MLIRAnalysis + MLIRSCFUtils + MLIRSCFTransforms + MLIRLinalgTransformOps + MLIRLinalgUtils + MLIRTransformUtils + MLIRTransformDialectTransforms + MLIRTensorTransforms TritonAnalysis TritonIR TritonTransforms TritonSharedAnalysis + DICPTransformOps + TritonArithToLinalg StructuredToMemref TritonToStructured diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/DeLinalgize.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/DeLinalgize.cpp new file mode 100644 index 00000000..d77df13e --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/DeLinalgize.cpp @@ -0,0 +1,236 @@ +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include 
"mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "de-linalgize" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace mlir::dicp; + +namespace mlir::dicp::LinalgExt { +#define GEN_PASS_DEF_DELINALGIZE +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace mlir::dicp::LinalgExt + +namespace { + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +/** + * @brief Restores linalg.generic ops back to arith/math elementwise ops. + * + * This pattern targets ops previously converted to generics for fusion or + * tiling purposes, restoring them to their high-level functional form for + * backend-specific code generation. 
+ */ +struct GenericToElementwisePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!genericOp->hasAttr(kNPUStageAttrName)) + return failure(); + + LDBG("Analyzing generic op for restoration: " << genericOp); + + if (!mlir::linalg::isElementwise(genericOp)) { + LDBG(" -> Failed: Not an identity-mapped elementwise generic."); + return failure(); + } + + Block &body = genericOp.getRegion().front(); + if (body.getOperations().size() != 2) { // [ScalarOp, YieldOp] + LDBG(" -> Failed: Body size is " << body.getOperations().size() + << ", expected 2."); + return failure(); + } + + Operation *scalarOp = &body.front(); + StringRef opName = scalarOp->getName().getStringRef(); + + // Only restore arith and math dialects + if (!opName.starts_with("arith.") && !opName.starts_with("math.")) { + LDBG(" -> Failed: Inner op " << opName << " is not arith/math."); + return failure(); + } + + // Map block arguments back to generic operands + SmallVector newOperands; + for (Value operand : scalarOp->getOperands()) { + auto arg = dyn_cast(operand); + if (!arg || arg.getOwner() != &body) { + LDBG(" -> Failed: Operand is not a block argument of the generic."); + return failure(); + } + + unsigned argIdx = arg.getArgNumber(); + if (argIdx >= genericOp.getNumDpsInputs()) { + LDBG(" -> Failed: Operation uses output/accumulator as input."); + return failure(); + } + newOperands.push_back(genericOp.getDpsInputOperand(argIdx)->get()); + } + + LDBG(" -> Restoring to " << opName); + + // Filter attributes: keep scalar op attributes and carry over NPU stage + SmallVector newAttrs; + for (auto attr : scalarOp->getAttrs()) + newAttrs.push_back(attr); + + Operation *newOp = rewriter.create( + genericOp.getLoc(), rewriter.getStringAttr(opName), newOperands, + genericOp.getResultTypes(), scalarOp->getAttrs()); + + rewriter.replaceOp(genericOp, 
newOp->getResults()); + return success(); + } +}; + +/** + * @brief Restores linalg.copy to memref.copy or materialization. + * + * Specifically handles the restoration of bufferization artifacts that were + * lowered to linalg.copy for generic transformation passes. + */ +struct LinalgCopyToOriginalPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::CopyOp copyOp, + PatternRewriter &rewriter) const override { + auto originalNameAttr = + copyOp->getAttrOfType(kOriginalOpNameAttr); + if (!originalNameAttr) + return failure(); + + StringRef originalName = originalNameAttr.getValue(); + Location loc = copyOp.getLoc(); + Value source = copyOp.getInputs().front(); + Value dest = copyOp.getOutputs().front(); + + LDBG("Restoring linalg.copy to original op: " << originalName); + + // Propagate attributes excluding our internal marker + SmallVector filteredAttrs; + for (auto attr : copyOp->getAttrs()) { + if (attr.getName() != kOriginalOpNameAttr) + filteredAttrs.push_back(attr); + } + + if (originalName == + bufferization::MaterializeInDestinationOp::getOperationName()) { + Value tensorSource = recoverTensorSource(source, rewriter, loc); + if (!tensorSource) { + LDBG(" -> Failed to recover tensor source for materialization."); + return failure(); + } + + auto matOp = rewriter.create( + loc, tensorSource, dest); + matOp->setAttrs(filteredAttrs); + rewriter.eraseOp(copyOp); + return success(); + } + + if (originalName == memref::CopyOp::getOperationName()) { + auto memrefCopy = rewriter.create(loc, source, dest); + memrefCopy->setAttrs(filteredAttrs); + rewriter.eraseOp(copyOp); + return success(); + } + + return failure(); + } + +private: + /** + * @brief Iteratively traces a memref back to a tensor source. + * Handles SubView chains by creating corresponding ExtractSliceOps. 
+ */ + Value recoverTensorSource(Value val, PatternRewriter &rewriter, + Location loc) const { + if (auto toBuffer = val.getDefiningOp()) + return toBuffer.getTensor(); + + if (auto subview = val.getDefiningOp()) { + Value parentTensor = + recoverTensorSource(subview.getSource(), rewriter, loc); + if (!parentTensor) + return nullptr; + + return rewriter.create( + loc, parentTensor, subview.getMixedOffsets(), subview.getMixedSizes(), + subview.getMixedStrides()); + } + + // Handle generic view-like interfaces if necessary + if (auto viewOp = val.getDefiningOp()) { + return recoverTensorSource(viewOp.getViewSource(), rewriter, loc); + } + + return nullptr; + } +}; + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +struct DeLinalgizePass + : public mlir::dicp::LinalgExt::impl::DeLinalgizeBase { + using DeLinalgizeBase::DeLinalgizeBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add( + context); + + // Use GreedyPatternRewriteDriver to handle potential chains of restorations + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + return; + } + + { + PassManager pm(&getContext(), module.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, module))) { + LDBG("Final cleanup pipeline failed."); + signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createDeLinalgizePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/FuseLoop.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/FuseLoop.cpp new file mode 100644 index 00000000..b0f2ffaf --- /dev/null +++ 
b/compiler/lib/Dialect/LinalgExt/Transforms/FuseLoop.cpp @@ -0,0 +1,307 @@ +#include "dicp/Dialect/LinalgExt/Analysis/StageUtils.h" +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "dicp/TransformOps/DicpTransformOps.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" + +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + +#include + +#define DEBUG_TYPE "npu-loop-fusion" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; +using namespace LinalgExt; + +namespace mlir::dicp::LinalgExt { +#define GEN_PASS_DEF_FUSELOOP +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace mlir::dicp::LinalgExt + +namespace { + +static constexpr llvm::StringRef kFuseLoopTagAttr = "fuse_loop_tag"; + +/// Helper struct to hold information about a group of loops to be fused. +struct FusionGroup { + int32_t stageId; + SmallVector loopTags; +}; + +//===----------------------------------------------------------------------===// +// Dependency Analysis Utilities +//===----------------------------------------------------------------------===// + +/// Checks if `opB` can be safely moved UP to `opA`'s position. 
+/// This requires that all values used by `opB` (operands and captured values) +/// are defined by operations that properly dominate `opA`. +static bool canMoveUpTo(Operation *opA, Operation *opB, + DominanceInfo &domInfo) { + auto isSafeOperand = [&](Value val) { + Operation *defOp = val.getDefiningOp(); + if (!defOp) + return true; // Block arguments inherently dominate everything in the + // block. + bool dominates = + domInfo.properlyDominates(defOp, opA, /*enclosingOpOk=*/false); + + if (!dominates) { + LDBG("Cannot move Up: " + << *defOp << " does NOT dominate target position " << *opA); + } + return dominates; + }; + + // Check explicit operands. + if (!llvm::all_of(opB->getOperands(), isSafeOperand)) + return false; + + // Check values implicitly captured in regions. + bool regionsSafe = true; + visitUsedValuesDefinedAbove(opB->getRegions(), [&](OpOperand *operand) { + if (regionsSafe && !isSafeOperand(operand->get())) + regionsSafe = false; + }); + + return regionsSafe; +} + +/// Checks if `opA` can be safely moved DOWN to `opB`'s position. +/// This requires that all users of `opA`'s results are strictly after `opB`. +static bool canMoveDownTo(Operation *opA, Operation *opB, + DominanceInfo &domInfo) { + for (Operation *user : opA->getUsers()) { + if (!domInfo.properlyDominates(opB, user, /*enclosingOpOk=*/false)) { + LDBG("Cannot move Down: User " + << *user << " appears before the target move-to position " << *opB); + return false; + } + } + return true; +} + +/// Checks if two operations in the same block can be safely moved to be +/// adjacent without violating SSA data dependencies. +static bool canBeMadeAdjacent(Operation *opA, Operation *opB, + DominanceInfo &domInfo) { + if (opA == opB || opA->getBlock() != opB->getBlock()) + return false; + + // Enforce structural order (opA before opB) for simpler reasoning. 
+ if (opA->isBeforeInBlock(opB)) { + if (canMoveDownTo(opA, opB, domInfo)) + return true; + + } else { + if (canMoveUpTo(opA, opB, domInfo)) + return true; + } + + LDBG(" [Reject] Ops " + << *opA << " and " << *opB + << " cannot be made adjacent due to SSA dependencies."); + return false; +} + +/// Identifies candidate loops within a SubStage and clusters them into +/// FusionGroups based on dependency analysis. +static SmallVector groupAndTagLoops(SubStage &subStage, + Builder &builder) { + SmallVector loops; + for (Operation *op : subStage.ops) { + if (isa(op)) + loops.push_back(op); + } + + if (loops.size() < 2) + return {}; + + DominanceInfo domInfo(loops.front()->getParentOp()); + SmallVector fusionGroups; + + SmallVector currentOps; + SmallVector currentTags; + + // Finalizes the current cluster of compatible loops. + auto finalizeGroup = [&]() { + if (currentOps.size() >= 2) { + fusionGroups.push_back({subStage.stageId, std::move(currentTags)}); + } else if (!currentOps.empty()) { + // Revert: remove tags from isolated operations. + for (auto [op, tag] : llvm::zip(currentOps, currentTags)) + op->removeAttr(tag); + } + currentOps.clear(); + currentTags.clear(); + }; + + int globalLoopIdx = 0; + for (Operation *loop : loops) { + if (currentOps.empty()) { + currentOps.push_back(loop); + } else { + // Check N-to-N compatibility for multi-way sibling fusions. + bool isCompatible = llvm::all_of(currentOps, [&](Operation *existingOp) { + if (!canBeMadeAdjacent(existingOp, loop, domInfo)) { + LDBG(" [Break] Dependency conflict between " + << loop->getName() << " and group member " + << existingOp->getName()); + return false; + } + return true; + }); + + if (isCompatible) { + currentOps.push_back(loop); + } else { + finalizeGroup(); + currentOps.push_back(loop); + } + } + + // Attach a unique tag to the loop for the Transform Dialect to match. 
+ StringAttr tag = builder.getStringAttr( + llvm::formatv("{0}_{1}_{2}_{3}", kFuseLoopTagAttr, subStage.stageId, + subStage.index, globalLoopIdx++) + .str()); + loop->setAttr(tag, builder.getUnitAttr()); + currentTags.push_back(tag); + } + + finalizeGroup(); + return fusionGroups; +} + +//===----------------------------------------------------------------------===// +// Transform Dialect Command Generation +//===----------------------------------------------------------------------===// + +/// Dispatches a MatchOp returning an opaque transform handle for a tagged op. +static Value getMatchHandle(ImplicitLocOpBuilder &b, Value root, + StringAttr tag) { + auto handleType = b.getType(); + auto attrDict = b.getDictionaryAttr(b.getNamedAttr(tag, b.getUnitAttr())); + + return b + .create( + handleType, root, /*ops=*/nullptr, /*interface=*/nullptr, + /*combined_attr=*/attrDict, /*filter_catalogs=*/nullptr, + /*filter_on_op_names=*/nullptr) + .getResult(); +} + +/// Generates transform dialect IR to perform sibling fusion on a group of +/// loops. +static void fuseLoopsByTags(ImplicitLocOpBuilder &b, Value root, + const FusionGroup &group) { + if (group.loopTags.size() < 2) + return; + + SmallVector loopHandles; + loopHandles.reserve(group.loopTags.size()); + for (StringAttr tag : group.loopTags) + loopHandles.push_back(getMatchHandle(b, root, tag)); + + // Sequentially fuse sibling loops into the first loop. + Value fusedLoop = loopHandles.front(); + auto handleType = b.getType(); + + for (size_t i = 1; i < loopHandles.size(); ++i) { + auto fuseOp = b.create( + handleType, /*target=*/fusedLoop, /*source=*/loopHandles[i]); + fusedLoop = fuseOp.getFusedLoop(); + } + + // Annotate the final fused loop with the hardware Stage ID. 
+ auto stageAttr = b.getI32IntegerAttr(group.stageId); + auto paramType = transform::ParamType::get(b.getContext(), b.getI32Type()); + auto stageParam = b.create(paramType, stageAttr); + + b.create(fusedLoop, kNPUStageAttrName, + stageParam.getResult()); +} + +//===----------------------------------------------------------------------===// +// Pass Orchestrator +//===----------------------------------------------------------------------===// + +struct FuseLoopPass + : public mlir::dicp::LinalgExt::impl::FuseLoopBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *ctx = module.getContext(); + Builder builder(ctx); + + // 1. Identify synchronization boundaries. + SmallVector blocks = + StagePartitioner::findBlocksWithHivmSyncOps(module); + if (blocks.empty()) { + LDBG("No synchronization barriers detected. Skipping FuseLoopPass."); + return; + } + + // 2. Perform staging analysis. + bool anyStageFound = false; + for (Block *block : blocks) { + if (failed(StagePartitioner::analyzeAndTagBlock(block, ctx, + anyStageFound))) { + LDBG("Fatal error during staging topology analysis. Aborting."); + return signalPassFailure(); + } + } + + if (!anyStageFound) + return; + + // 3. Group loops within each substage. + SmallVector allFusionGroups; + for (Block *block : blocks) { + for (int32_t stageId : StagePartitioner::getStageIdsInBlock(block)) { + auto subStages = StagePartitioner::partition(block, stageId).subStages; + for (SubStage &subStage : subStages) { + SmallVector groups = groupAndTagLoops(subStage, builder); + allFusionGroups.append(groups.begin(), groups.end()); + } + } + } + + if (allFusionGroups.empty()) + return; + + // 4. Apply transformations via Transform Dialect. 
+ TransformApplier::apply(module, + [&](OpBuilder &b, Location loc, Value root) { + ImplicitLocOpBuilder implicitB(loc, b); + for (const FusionGroup &group : allFusionGroups) + fuseLoopsByTags(implicitB, root, group); + }); + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createFuseLoopPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/LoopUnrollStage.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/LoopUnrollStage.cpp new file mode 100644 index 00000000..0e068a07 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/LoopUnrollStage.cpp @@ -0,0 +1,303 @@ +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "dicp-loop-unroll" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace mlir::dicp; + +namespace mlir::dicp::LinalgExt { +#define GEN_PASS_DEF_LOOPUNROLLSTAGE +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace mlir::dicp::LinalgExt + +namespace { + +// --- Utilities --- + +/** + * @brief Cleans up internal DICP attributes from all operations in the + * function. + * + * Removes attributes starting with `kDicpStagePrefix` to ensure clean IR output + * after the pass completes. 
+ */ +static void cleanupInternalAttributes(func::FuncOp func) { + MLIRContext *ctx = func.getContext(); + int totalRemoved = 0; + + func.walk([&](Operation *op) { + if (op->getAttrs().empty()) + return; + + // Identify attributes to keep + SmallVector filteredAttrs; + bool needsUpdate = false; + + for (NamedAttribute attr : op->getAttrs()) { + if (attr.getName().strref().starts_with(kDicpStagePrefix)) { + needsUpdate = true; + totalRemoved++; + } else { + filteredAttrs.push_back(attr); + } + } + + if (needsUpdate) { + op->setAttrs(DictionaryAttr::get(ctx, filteredAttrs)); + } + }); + + if (totalRemoved > 0) { + LDBG("Cleaned up " << totalRemoved << " internal stage attributes."); + } +} + +// --- Patterns --- + +/// Helper: Rewrite logic (slightly adapted to accept PatternRewriter &). +static LogicalResult rewriteForallToFor(PatternRewriter &rewriter, + scf::ForallOp forallOp, + SmallVectorImpl &loops) { + Location loc = forallOp.getLoc(); + + // 1. Gather Bounds and Steps + SmallVector lbs = forallOp.getLowerBound(rewriter); + SmallVector ubs = forallOp.getUpperBound(rewriter); + SmallVector steps = forallOp.getStep(rewriter); + + // 2. Prepare Init Args for the outermost loop (from shared_outs) + SmallVector currentIterOperands(forallOp.getOutputs().begin(), + forallOp.getOutputs().end()); + + LLVM_DEBUG(llvm::dbgs() << "[ForallToFor] Creating loop nest of rank " + << lbs.size() << "\n"); + + // 3. 
Create the Loop Nest + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(forallOp); + + for (auto [lb, ub, step] : llvm::zip(lbs, ubs, steps)) { + auto loop = + rewriter.create(loc, lb, ub, step, currentIterOperands); + + if (!loop.getBody()->empty()) { + rewriter.eraseOp(loop.getBody()->getTerminator()); + } + + // The iter_args of this loop become the init_args for the next inner loop + currentIterOperands.assign(loop.getRegionIterArgs().begin(), + loop.getRegionIterArgs().end()); + // Set insertion point to the beginning of the body for the next + // loop/content + rewriter.setInsertionPointToStart(loop.getBody()); + loops.push_back(loop); + } + + if (loops.empty()) { + // Zero-rank forall -> nothing to do; simply erase or signal failure. + return forallOp.emitError("expected rank > 0 for scf.forall"); + } + + scf::ForOp innermostLoop = cast(loops.back()); + + // 4. Inline the Forall Body + Block *forallBody = forallOp.getBody(); + Block *innermostBlock = innermostLoop.getBody(); + + IRMapping mapping; + // Map induction vars (first N block args of forall) + llvm::SmallVector ivs; + ivs.reserve(loops.size()); + for (Operation *op : loops) + ivs.push_back(cast(op).getInductionVar()); + + for (auto [forallArg, newIv] : + llvm::zip(forallBody->getArguments().take_front(lbs.size()), ivs)) { + mapping.map(forallArg, newIv); + } + + // Map the remaining forall block args (shared outputs) to innermost + // iter_args. + for (auto [forallArg, iterArg] : + llvm::zip(forallBody->getArguments().drop_front(lbs.size()), + innermostLoop.getRegionIterArgs())) { + mapping.map(forallArg, iterArg); + } + + SmallVector mappedArgs; + mappedArgs.reserve(forallBody->getNumArguments()); + for (Value arg : forallBody->getArguments()) + mappedArgs.push_back(mapping.lookup(arg)); + + rewriter.mergeBlocks(forallBody, innermostBlock, mappedArgs); + + // 5. 
Handle the Terminator (scf.forall.in_parallel) + auto inParallelOp = + dyn_cast(innermostBlock->getTerminator()); + if (!inParallelOp) { + return forallOp.emitError("expected scf.forall.in_parallel terminator"); + } + + LLVM_DEBUG(llvm::dbgs() << "[ForallToFor] Processing terminator: " + << *inParallelOp << "\n"); + + // Convert parallel updates to sequential updates on the iter_args by + // chaining. + DenseMap accumulatorMap; + for (Value iterArg : innermostLoop.getRegionIterArgs()) + accumulatorMap[iterArg] = iterArg; + + Block ¶llelBody = inParallelOp.getRegion().front(); + rewriter.setInsertionPoint(inParallelOp); + + for (Operation &op : llvm::make_early_inc_range(parallelBody)) { + if (auto parallelInsert = dyn_cast(op)) { + + Value destIterArg = parallelInsert.getDest(); + + auto it = accumulatorMap.find(destIterArg); + if (it == accumulatorMap.end()) + return parallelInsert.emitError( + "parallel_insert_slice dest is not a mapped loop iter_arg"); + + Value currentAcc = it->second; + + auto insertSlice = rewriter.create( + parallelInsert.getLoc(), parallelInsert.getSource(), currentAcc, + parallelInsert.getOffsets(), parallelInsert.getSizes(), + parallelInsert.getStrides(), parallelInsert.getStaticOffsets(), + parallelInsert.getStaticSizes(), parallelInsert.getStaticStrides()); + + accumulatorMap[destIterArg] = insertSlice.getResult(); + continue; + } + + if (isa(op)) + continue; + + op.moveBefore(inParallelOp); + } + + // 6. Yield the accumulated results from innermost loop body + SmallVector yieldOperands; + for (Value iterArg : innermostLoop.getRegionIterArgs()) + yieldOperands.push_back(accumulatorMap[iterArg]); + + rewriter.create(loc, yieldOperands); + + // Remove the old terminator + rewriter.eraseOp(inParallelOp); + + // 7. 
Generate yields for outer loops (inner loop results become yielded + // values) + for (size_t i = loops.size() - 1; i > 0; --i) { + auto innerLoop = cast(loops[i]); + auto outerLoop = cast(loops[i - 1]); + + rewriter.setInsertionPointToEnd(outerLoop.getBody()); + rewriter.create(loc, innerLoop.getResults()); + } + + DictionaryAttr forallAttrs = forallOp->getAttrDictionary(); + loops.front()->setAttrs(forallAttrs); + // 8. Replace Forall uses with the results of the outermost loop + rewriter.replaceOp(forallOp, loops.front()->getResults()); + LLVM_DEBUG(llvm::dbgs() << "[ForallToFor] Transformation complete.\n"); + return success(); +} +/// OpRewritePattern that wraps the above helper. +struct ForallToForPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForallOp forallOp, + PatternRewriter &rewriter) const override { + SmallVector generatedLoops; + if (failed(rewriteForallToFor(rewriter, forallOp, generatedLoops))) + return failure(); + return success(); + } +}; + +// --- Pass Definition --- + +struct LoopUnrollStagePass + : public mlir::dicp::LinalgExt::impl::LoopUnrollStageBase< + LoopUnrollStagePass> { + using LoopUnrollStageBase::LoopUnrollStageBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + + LDBG("--- Starting LoopUnrollStagePass on @" << func.getName() << " ---"); + + // --------------------------------------------------------- + // Phase 1: Normalize scf.forall -> scf.for + // --------------------------------------------------------- + { + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + // We use GreedyPatternRewriteDriver to handle potential nesting + // and ensure the IR is in a stable state before collection. 
+ if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + return signalPassFailure(); + } + } + // --------------------------------------------------------- + // Phase 2: Collect Candidate Loops + // --------------------------------------------------------- + // We collect loops into a vector first. If we unroll while walking, + // we risk invalidating the iterator or missing nested loops. + // func.walk defaults to PostOrder, which is ideal (inner loops first). + SmallVector candidateLoops; + func.walk([&](scf::ForOp forOp) { + if (forOp->hasAttr(kNPUStageAttrName)) { + candidateLoops.push_back(forOp); + } + }); + + if (candidateLoops.empty()) { + LDBG("No scf.for loops found with stage attributes."); + cleanupInternalAttributes(func); + return; + } + + LDBG("Collected " << candidateLoops.size() << " candidate scf.for loops."); + return; + // --------------------------------------------------------- + // Phase 3: Unroll Loops + // --------------------------------------------------------- + int successCount = 0; + + for (auto [index, forOp] : llvm::enumerate(candidateLoops)) { + if (failed(loopUnrollFull(forOp))) + return forOp.emitError("Failed to unroll loop ") << index, + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createLoopUnrollStagePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/NPUTileLoopTagging.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/NPUTileLoopTagging.cpp new file mode 100644 index 00000000..40cf3fd7 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/NPUTileLoopTagging.cpp @@ -0,0 +1,615 @@ +#include "dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h" +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" +#include "dicp/Dialect/LinalgExt/Analysis/StageUtils.h" +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "dicp/TransformOps/DicpTransformOps.h" +#include 
"dicp/TransformOps/Transforms.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Block.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FormatVariadic.h" + +#define DEBUG_TYPE "npu-tile-loop-tagging" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; +using namespace LinalgExt; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_NPUVECTORTILETAGGING +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +namespace { + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +/// Check if the operation is a candidate for elementwise-to-generic conversion. +static bool isConvertibleElementwiseOp(Operation *op) { + if (!op->hasAttr(kNPUStageAttrName) || isa(op)) + return false; + return OpTrait::hasElementwiseMappableTraits(op) && + llvm::all_of(op->getOperandTypes(), llvm::IsaPred); +} + +/// Create empty tensors for linalg outputs if matching operands aren't found. 
+static SmallVector +getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { + Location loc = op->getLoc(); + ValueRange operands = op->getOperands(); + return llvm::map_to_vector(op->getResultTypes(), [&](Type t) -> Value { + auto it = + llvm::find_if(operands, [&](Value v) { return v.getType() == t; }); + if (it != operands.end()) + return *it; + LDBG("getOrCreateOperandsMatchingResultTypes: Creating empty tensor for " + "type " + << t); + return b.create( + loc, tensor::getMixedSizes(b, loc, operands.front()), + cast(t).getElementType()); + }); +} + +/// Retrieve the loop or tensor rank of an operation. +static int64_t getRank(Operation *op) { + return TypeSwitch(op) + .Case( + [](auto linalgOp) { return linalgOp.getNumLoops(); }) + .Default([](Operation *op) -> int64_t { + if (op->getNumResults() > 0) + if (auto type = dyn_cast(op->getResult(0).getType())) + return type.getRank(); + return 0; + }); +} + +/// Calculate tile size based on the target trip count. +static int64_t calculateTileSize(Operation *anchorOp, int64_t dimIdx, + int64_t tripCount) { + if (tripCount <= 0) { + LDBG("calculateTileSize: tripCount is invalid: " << tripCount); + return -1; + } + auto getDimSize = [&](Value v) -> int64_t { + auto type = dyn_cast_or_null(v.getType()); + if (!type || dimIdx >= type.getRank()) + return -1; + return type.getDimSize(dimIdx); + }; + int64_t totalSize = + TypeSwitch(anchorOp) + .Case( + [&](auto op) { return getDimSize(op.getInputs().front()); }) + .Case([&](auto op) { + return op.getDpsInits().empty() + ? -1 + : getDimSize(op.getDpsInits().front()); + }) + .Default([&](auto op) { + return op->getNumResults() > 0 ? 
getDimSize(op->getResult(0)) : -1; + }); + + if (totalSize <= 0) { + LDBG("calculateTileSize: Could not determine total size for dim " + << dimIdx << " on op " << anchorOp->getName()); + return -1; + } + if (totalSize % tripCount != 0) { + LDBG("calculateTileSize: Total size " + << totalSize << " is not divisible by tripCount " << tripCount); + return -1; + } + return totalSize / tripCount; +} + +//===----------------------------------------------------------------------===// +// Data Structures +//===----------------------------------------------------------------------===// + +/// Metadata for a single tiling group (anchor and fused producers). +struct TilingUnit { + Operation *anchorOp = nullptr; + SmallVector tileSizes; + int64_t tilingDimIndex = -1; + int64_t rank = 0; + std::vector producerOps; + + std::string anchorTag; + std::string producerComputeTag; + std::string producerAllocTag; + std::string crossUserTag; +}; + +/// Represents a sub-stage containing multiple tiling units. +struct TiledSubStage { + SubStage base; + std::vector units; + explicit TiledSubStage(SubStage s) : base(std::move(s)) {} +}; + +//===----------------------------------------------------------------------===// +// Normalization Patterns +//===----------------------------------------------------------------------===// + +/// Sink bufferization.to_tensor immediately after its alloc operand. +struct SinkToTensorToAlloc + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(bufferization::ToTensorOp op, + PatternRewriter &rewriter) const override { + Operation *allocOp = op.getOperand().getDefiningOp(); + if (!isa_and_nonnull(allocOp) || + op->getPrevNode() == allocOp) + return failure(); + LDBG("SinkToTensorToAlloc: Sinking to_tensor after alloc " << *allocOp); + rewriter.modifyOpInPlace(op, [&]() { op->moveAfter(allocOp); }); + return success(); + } +}; + +/// Convert elementwise operations to linalg.generic to enable tiling. 
+struct ConvertElementwiseToGenericPattern : public RewritePattern { + ConvertElementwiseToGenericPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), 1, ctx) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!isConvertibleElementwiseOp(op)) + return failure(); + LDBG("ConvertElementwiseToGenericPattern: Converting op " << op->getName()); + auto rank = getRank(op); + SmallVector maps(op->getNumResults() + op->getNumOperands(), + rewriter.getMultiDimIdentityMap(rank)); + SmallVector iterTypes(rank, + utils::IteratorType::parallel); + auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); + auto genericOp = rewriter.create( + op->getLoc(), op->getResultTypes(), op->getOperands(), outputs, maps, + iterTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + auto resTypes = llvm::map_to_vector(op->getResultTypes(), [](Type t) { + return cast(t).getElementType(); + }); + auto *scalarOp = b.create(loc, op->getName().getIdentifier(), + args.take_front(op->getNumOperands()), + resTypes, op->getAttrs()); + scalarOp->removeAttr(kNPUStageAttrName); + b.create(loc, scalarOp->getResults()); + }); + if (auto attr = op->getAttr(kNPUStageAttrName)) + genericOp->setAttr(kNPUStageAttrName, attr); + rewriter.replaceOp(op, genericOp.getResults()); + return success(); + } +}; + +/// Lower various copy ops to linalg.copy for unified tiling treatment. 
+template +struct ConvertCopyLikeOpToLinalgCopy : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + auto stageAttr = op->template getAttrOfType(kNPUStageAttrName); + if (!stageAttr) + return failure(); + + LDBG("ConvertCopyLikeOpToLinalgCopy: Converting " << op->getName()); + Value src, dst; + if constexpr (std::is_same_v) { + src = op.getSource(); + dst = op.getDest(); + } else { + src = op.getSource(); + dst = op.getTarget(); + } + auto toMemref = [&](Value v) -> Value { + if (!isa(v.getType())) + return v; + auto type = + MemRefType::get(cast(v.getType()).getShape(), + cast(v.getType()).getElementType()); + auto res = + rewriter.create(op.getLoc(), type, v); + res->setAttr(kNPUStageAttrName, stageAttr); + return res; + }; + auto copy = rewriter.create(op.getLoc(), toMemref(src), + toMemref(dst)); + copy->setAttrs(op->getAttrs()); + copy->setAttr(kOriginalOpNameAttr, + rewriter.getStringAttr(op->getName().getStringRef())); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TilingProcessor +//===----------------------------------------------------------------------===// + +/// Analyzes stages to identify tiling anchors and assign unique tags. +class TilingProcessor { +public: + explicit TilingProcessor(int64_t tripCount, bool enableFallback = true) + : tripCount(tripCount), enableFallback(enableFallback) {} + + /// Analyze the sub-stage and tag operations for subsequent transformation. 
+ LogicalResult analyzeAndTag(TiledSubStage &ts) const { + LDBG("[TilingProcessor] Analyzing SubStage " + << ts.base.index << " in Stage " << ts.base.stageId); + auto candidates = collectCandidates(ts.base); + + auto dims = analyzeDims(ts.base); + if (!dims) { + LDBG(" [Warning] Dim analysis failed for SubStage " << ts.base.index); + return failure(); + } + LDBG(" [Info] Analyzed tiling dims: " << dims->front()); + + SetVector claimedOps; + + // 1. Process high-priority candidates + if (!candidates.empty()) { + for (auto &cand : candidates) { + if (claimedOps.contains(cand.op)) { + LDBG(" [Skip] Candidate " << cand.op->getName() + << " already claimed by previous unit."); + continue; + } + + llvm::SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&](Operation *op) { + return llvm::is_contained(ts.base.ops, op) && + op->getBlock() == cand.op->getBlock(); + }; + (void)getBackwardSlice(cand.op, &slice, opt); + slice.remove(cand.op); + + SmallVector producers; + for (Operation *p : slice) { + if (!claimedOps.contains(p)) + producers.push_back(p); + } + + if (auto unit = tryCreateUnit(cand.op, producers, *dims, ts)) { + tagUnit(*unit); + claimedOps.insert(cand.op); + claimedOps.insert(producers.begin(), producers.end()); + ts.units.push_back(std::move(*unit)); + LDBG(" [Success] Created Unit " + << ts.units.size() - 1 << " with Anchor " << cand.op->getName()); + } else { + LDBG(" [Fail] Could not create unit for anchor " + << cand.op->getName()); + } + } + } else { + LDBG(" [Info] No high-priority tiling candidates found."); + } + + // 2. Process Fallback: Cover remaining ops if enabled + if (enableFallback) { + processFallback(ts, claimedOps, *dims); + } + + return ts.units.empty() ? 
failure() : success(); + } + +private: + int64_t tripCount; + bool enableFallback; + + enum class Priority { Normalized = 1, Yield = 2, Fallback = 3 }; + struct Candidate { + Operation *op; + Priority prio; + size_t irIdx; + }; + + /// Identify dimensions suitable for tiling. + std::optional> analyzeDims(const SubStage &ss) const { + StageInfo info{ss.stageId, ss.ops}; + auto dims = DimAnalyzer(info).analyzeAndGetTilingDims(); + if (dims.empty()) + return std::nullopt; + std::sort(dims.begin(), dims.end()); + return dims; + } + + /// Collect potential anchor operations for tiling, sorted by priority. + std::vector collectCandidates(const SubStage &ss) const { + DenseMap best; + auto update = [&](Operation *o, Priority p, size_t i) { + auto &entry = best[o]; + if (!entry.op || p < entry.prio || (p == entry.prio && i < entry.irIdx)) + entry = {o, p, i}; + }; + + for (auto [i, op] : llvm::enumerate(ss.ops)) { + if (auto copy = dyn_cast(op)) + if (copy->hasAttr(kOriginalOpNameAttr)) + update(op, Priority::Normalized, i); + } + + if (!ss.ops.empty()) { + if (auto yield = dyn_cast( + ss.ops.back()->getBlock()->getTerminator())) { + for (Value v : yield.getOperands()) { + Operation *def = v.getDefiningOp(); + if (!def || !llvm::is_contained(ss.ops, def)) + continue; + + bool feedsNorm = llvm::any_of(def->getUsers(), [](Operation *u) { + return isa(u) && u->hasAttr(kOriginalOpNameAttr); + }); + if (isa(def) && !feedsNorm) { + auto it = llvm::find(ss.ops, def); + update(def, Priority::Yield, std::distance(ss.ops.begin(), it)); + } + } + } + } + + for (int i = ss.ops.size() - 1; i >= 0; --i) { + if (isa(ss.ops[i])) { + update(ss.ops[i], Priority::Fallback, i); + break; + } + } + + std::vector res; + for (auto &kv : best) + res.push_back(kv.second); + llvm::sort(res, [](const Candidate &a, const Candidate &b) { + return std::tie(a.prio, a.irIdx) < std::tie(b.prio, b.irIdx); + }); + LDBG("collectCandidates: Found " << res.size() << " potential anchors."); + return res; + } + + 
/// Attempt to construct a tiling unit for a given anchor. + std::optional tryCreateUnit(Operation *anchor, + ArrayRef producers, + ArrayRef dims, + const TiledSubStage &ts) const { + int64_t dimIdx = dims.front(); + int64_t tileSize = calculateTileSize(anchor, dimIdx, tripCount); + if (tileSize <= 0) { + LDBG("tryCreateUnit: Invalid tile size calculated for " + << anchor->getName()); + return std::nullopt; + } + + TilingUnit u; + u.anchorOp = anchor; + u.tilingDimIndex = dimIdx; + u.tileSizes = {tileSize}; + u.rank = getRank(anchor); + u.producerOps.assign(producers.begin(), producers.end()); + + auto fmt = [&](StringRef pattern) { + return llvm::formatv(pattern.data(), ts.base.stageId, ts.base.index, + ts.units.size()) + .str(); + }; + u.anchorTag = fmt(kStageOpToTileAttr); + u.producerComputeTag = fmt(kStageProducerToFuseAttr); + u.producerAllocTag = kStageProducerAllocToFuseAttr.str(); + u.crossUserTag = kCrossTillUnitAttr.str(); + return u; + } + + /// Attach string attributes to operations to guide the transform dialect. + /// This also packs the necessary tile sizes into a dictionary attribute + /// to pass data elegantly to the subsequent Transform pass. 
+ void tagUnit(TilingUnit &u) const { + MLIRContext *ctx = u.anchorOp->getContext(); + OpBuilder b(ctx); + + LDBG("tagUnit: Tagging anchor " << u.anchorOp->getName() << " with " + << u.anchorTag); + // Mark the anchor with a UnitAttr for the Transform pass to MatchOp against + u.anchorOp->setAttr(u.anchorTag, b.getUnitAttr()); + + // Calculate full tile sizes array + SmallVector sizes(u.rank, 0); + if (u.tilingDimIndex >= 0 && u.tilingDimIndex < u.rank && + !u.tileSizes.empty()) { + sizes[u.tilingDimIndex] = u.tileSizes.front(); + } + + // Embed metadata into a dictionary attribute to safely pass to the + // Transform pass + SmallVector meta; + meta.push_back(b.getNamedAttr("anchor_tag", b.getStringAttr(u.anchorTag))); + meta.push_back( + b.getNamedAttr("compute_tag", b.getStringAttr(u.producerComputeTag))); + meta.push_back(b.getNamedAttr("tile_sizes", b.getDenseI64ArrayAttr(sizes))); + u.anchorOp->setAttr("npu.tiling_meta", b.getDictionaryAttr(meta)); + LDBG("tagUnit: Attached npu.tiling_meta to anchor."); + + SetVector currentUnitOps(u.producerOps.begin(), + u.producerOps.end()); + currentUnitOps.insert(u.anchorOp); + + for (Operation *p : u.producerOps) { + if (llvm::any_of(p->getUsers(), [&](Operation *user) { + return !currentUnitOps.contains(user); + })) { + LDBG("tagUnit: Producer " + << p->getName() + << " has users outside unit, tagging as cross-user."); + p->setAttr(u.crossUserTag, b.getUnitAttr()); + } + + bool isMem = + TypeSwitch(p) + .Case( + [](auto) { return true; }) + .Case([](auto op) { + return isa_and_nonnull( + op.getOperand().getDefiningOp()); + }) + .Default(false); + + StringRef tag = isMem ? u.producerAllocTag : u.producerComputeTag; + LDBG("tagUnit: Tagging producer " << p->getName() << " with " << tag); + p->setAttr(tag, b.getUnitAttr()); + } + } + + /// Scans for remaining operations in the substage that haven't been claimed + /// by any tiling unit. 
It traverses backwards to find the last valid + /// operation (lowest priority) and creates a fallback tiling unit. + void processFallback(TiledSubStage &ts, SetVector &claimedOps, + ArrayRef dims) const { + LDBG("processFallback: Checking for uncovered operations in SubStage " + << ts.base.index); + + for (Operation *op : llvm::reverse(ts.base.ops)) { + if (claimedOps.contains(op)) + continue; + + if (!isa(op)) + continue; + + LDBG("processFallback: Found candidate anchor: " << op->getName()); + + llvm::SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&](Operation *p) { + return llvm::is_contained(ts.base.ops, p) && + p->getBlock() == op->getBlock(); + }; + (void)getBackwardSlice(op, &slice, opt); + slice.remove(op); + + SmallVector producers; + for (Operation *p : slice) { + if (!claimedOps.contains(p)) + producers.push_back(p); + } + + if (auto unit = tryCreateUnit(op, producers, dims, ts)) { + tagUnit(*unit); + claimedOps.insert(op); + claimedOps.insert(producers.begin(), producers.end()); + ts.units.push_back(std::move(*unit)); + LDBG("processFallback: Successfully created fallback Unit " + << ts.units.size() - 1 << " with Anchor " << op->getName()); + break; + } else { + LDBG("processFallback: Failed to create unit for " << op->getName()); + } + } + } +}; + +//===----------------------------------------------------------------------===// +// Pass Main Entry +//===----------------------------------------------------------------------===// + +/// Normalization and Tagging Pass for NPU loop tiling. 
+class NPUVectorTileTaggingPass
+    : public mlir::dicp::LinalgExt::impl::NPUVectorTileTaggingBase<
+          NPUVectorTileTaggingPass> {
+public:
+  using NPUVectorTileTaggingBase::NPUVectorTileTaggingBase;
+
+  /// Normalizes copy-like/elementwise ops into tileable linalg forms, then
+  /// analyzes each stage and attaches transform-dialect tags so a follow-up
+  /// pass can tile and fuse the marked anchors.
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    // NOTE(review): assumes the module contains at least one func.func —
+    // begin() is dereferenced unchecked; verify against callers.
+    Block *targetBlock = CubeVectorSplitter::findTargetBlock(
+        *(module.getOps<func::FuncOp>().begin()));
+    llvm::SmallVector stages;
+    (void)CubeVectorSplitter::splitBlock(*targetBlock, stages);
+
+    // NOTE(review): a stray `return;` here previously disabled everything
+    // below (debug leftover); removed so normalization and tagging run.
+
+    int64_t tripCount = static_cast<int64_t>(tiledMixVectorLoopNumber);
+    LDBG("Run NPUVectorTileTaggingPass with tripCount=" << tripCount);
+
+    auto blocks = StagePartitioner::findBlocksWithHivmSyncOps(module);
+    if (blocks.empty()) {
+      LDBG("No blocks with HIVM sync ops found. Skipping pass.");
+      return;
+    }
+
+    // Stage analysis failures are treated as a graceful skip (no
+    // signalPassFailure), mirroring the other early-exit paths above.
+    bool anyStage = false;
+    for (Block *b : blocks) {
+      if (failed(StagePartitioner::analyzeAndTagBlock(b, ctx, anyStage))) {
+        LDBG("StagePartitioner failed to analyze/tag block.");
+        return;
+      }
+    }
+    if (!anyStage) {
+      LDBG("No stages identified in the module. Skipping.");
+      return;
+    }
+
+    // Normalize operations to a tileable form.
+    RewritePatternSet patterns(ctx);
+    patterns.add,
+                 ConvertCopyLikeOpToLinalgCopy>(ctx);
+    if (failed(applyPatternsGreedily(module, std::move(patterns)))) {
+      LDBG("Greedy pattern rewrite for normalization failed.");
+      return;
+    }
+
+    // Analyze stages and assign transform tags.
+ TilingProcessor processor(tripCount); + for (Block *block : blocks) { + for (int stageId : StagePartitioner::getStageIdsInBlock(block)) { + for (auto &subStage : + StagePartitioner::partition(block, stageId).subStages) { + TiledSubStage ts(std::move(subStage)); + if (failed(processor.analyzeAndTag(ts))) { + LDBG("Processor failed to analyze SubStage " + << ts.base.index << " of Stage " << stageId); + } + } + } + } + + LDBG("NPUVectorTileTaggingPass completed successfully."); + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createNPUVectorTileTaggingPass( + const NPUVectorTileTaggingOptions &options) { + return std::make_unique(options); +} + +std::unique_ptr> +mlir::dicp::LinalgExt::createNPUVectorTileTaggingPass(unsigned vectorTile) { + NPUVectorTileTaggingOptions opt; + opt.tiledMixVectorLoopNumber = vectorTile; + return std::make_unique(opt); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/NPUUnroolPipeline.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/NPUUnroolPipeline.cpp new file mode 100644 index 00000000..20819446 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/NPUUnroolPipeline.cpp @@ -0,0 +1,505 @@ +#include "dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h" +#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h" +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include 
"bishengir/Dialect/HIVM/IR/HIVM.h" + +#include + +#define DEBUG_TYPE "npu-unroll-pipeline" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; +using namespace LinalgExt; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_NPUUNROOLPIPELINE +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +namespace { + +LogicalResult verifyLoopForPipelining(scf::ForOp forOp) { + auto lbOpt = getConstantIntValue(forOp.getLowerBound()); + auto ubOpt = getConstantIntValue(forOp.getUpperBound()); + auto stepOpt = getConstantIntValue(forOp.getStep()); + + if (!lbOpt.has_value() || !ubOpt.has_value() || !stepOpt.has_value()) { + LDBG("Verification FAILED: Loop bounds or step are dynamic."); + return failure(); + } + + int64_t step = stepOpt.value(); + if (step == 0) { + LDBG("Verification FAILED: Infinite loop (step = 0)."); + return failure(); + } + + int64_t lb = lbOpt.value(); + int64_t ub = ubOpt.value(); + if (step > 0 && lb >= ub) { + LDBG("Verification FAILED: Loop body is never executed."); + return failure(); + } + + int64_t tripCount = (ub - lb + step - 1) / step; + LDBG("Verification PASSED. Static Trip Count: " << tripCount); + return success(); +} + +// Marks operations that define yielded values for tensor/memref iter_args +// This allows us to track loop-carried dependencies across unrolled iterations. +static LogicalResult markYieldSources(scf::ForOp forOp) { + auto yieldOp = cast(forOp.getBody()->getTerminator()); + + for (auto [idx, iterArg] : llvm::enumerate(forOp.getRegionIterArgs())) { + Value yieldVal = yieldOp.getOperand(idx); + + // Only strictly necessary for SSA values (tensors/scalars), but harmless + // for others. + if (auto defOp = yieldVal.getDefiningOp()) { + std::string attrName = "dicp.yield_for_iter_arg." 
+ std::to_string(idx); + // We assume one op might feed multiple yield args, though rare. + // Ideally we check if attr exists, but simple overwrite is okay for 1:1. + defOp->setAttr( + attrName, + IntegerAttr::get(IntegerType::get(forOp.getContext(), 32), idx)); + LDBG(" Marked op '" << defOp->getName() + << "' as yield source for iter_arg " << idx); + } + } + return success(); +} + +static Operation *getYieldSourceForIterArg(scf::ForOp forOp, int iterArgIdx) { + // Linear scan is acceptable for loop bodies which are typically small-ish + for (Operation &op : forOp.getBody()->without_terminator()) { + std::string attrName = + "dicp.yield_for_iter_arg." + std::to_string(iterArgIdx); + if (op.hasAttr(attrName)) { + return &op; + } + } + return nullptr; +} + +class NPUUnrollPipeline { +public: + NPUUnrollPipeline(scf::ForOp forOp, int unrollFactor, + const std::vector &orderedStages) + : forOp(forOp), unrollFactor(unrollFactor), stages(orderedStages) {} + + LogicalResult run(RewriterBase &rewriter); + +private: + scf::ForOp forOp; + int unrollFactor; + const std::vector &stages; + int maxFlagPerIter = 0; + + // Map: OriginalValue -> Vector of Unrolled Values (one per iteration) + DenseMap> valueMapping; + // Map: OriginalOp -> Vector of Unrolled Ops (one per iteration) + // Needed to find cloned yield sources. 
+ DenseMap> opMapping; + + void calculateMaxFlagStride(); + void prepareInitialMappings(RewriterBase &rewriter); + void updateHivmFlag(Operation *op, int iterIdx, RewriterBase &rewriter); + + Value getUnrolledValue(Value originalVal, int iterIdx); + Operation *cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + int iterIdx); +}; + +void NPUUnrollPipeline::calculateMaxFlagStride() { + int maxFlag = -1; + for (const auto &stage : stages) { + for (Operation *op : stage.ops) { + if (auto syncSetOp = dyn_cast(op)) { + int flag = getConstantIntValue(syncSetOp.getFlagId()).value_or(-1); + if (flag > maxFlag) + maxFlag = flag; + } else if (auto syncWaitOp = dyn_cast(op)) { + int flag = getConstantIntValue(syncWaitOp.getFlagId()).value_or(-1); + if (flag > maxFlag) + maxFlag = flag; + } + } + } + this->maxFlagPerIter = (maxFlag < 0) ? 0 : (maxFlag + 1); + LDBG("Flag Stride calculated: " << maxFlagPerIter); +} + +void NPUUnrollPipeline::prepareInitialMappings(RewriterBase &rewriter) { + LDBG(">>> [Unroll] Preparing Initial Mappings (Constants & IVs)..."); + Location loc = forOp.getLoc(); + Value lb = forOp.getLowerBound(); + Value step = forOp.getStep(); + Value iv = forOp.getInductionVar(); + Type ivType = iv.getType(); + + valueMapping[iv].resize(unrollFactor); + auto iterArgs = forOp.getRegionIterArgs(); + for (Value arg : iterArgs) { + valueMapping[arg].resize(unrollFactor, nullptr); + } + + for (int i = 0; i < unrollFactor; ++i) { + // 1. IV Calculation + Value idxVal = rewriter.create(loc, i); + Value idxValTyped = idxVal; + if (ivType != idxVal.getType()) + idxValTyped = rewriter.create(loc, ivType, idxVal); + + Value stepOffset = rewriter.create(loc, step, idxValTyped); + Value currentIV = rewriter.create(loc, lb, stepOffset); + valueMapping[iv][i] = currentIV; + + // 2. Simple IterArg Calculation (e.g. 
arithmetic induction) + auto yieldOp = cast(forOp.getBody()->getTerminator()); + for (auto it : llvm::enumerate(iterArgs)) { + Value iterArg = it.value(); + Value yieldVal = yieldOp.getOperand(it.index()); + + Operation *defOp = yieldVal.getDefiningOp(); + bool isSimpleIV = false; + int64_t stepConst = 0; + + if (auto addOp = dyn_cast_or_null(defOp)) { + Value lhs = addOp.getLhs(); + Value rhs = addOp.getRhs(); + Value constOp = nullptr; + if (lhs == iterArg) + constOp = rhs; + else if (rhs == iterArg) + constOp = lhs; + + if (constOp) { + if (auto cst = constOp.getDefiningOp()) { + stepConst = cst.value(); + isSimpleIV = true; + } else if (auto cst = constOp.getDefiningOp()) { + stepConst = cst.value(); + isSimpleIV = true; + } + } + } + + if (isSimpleIV) { + Value initVal = forOp.getInitArgs()[it.index()]; + Value kVal = rewriter.create(loc, i); + Value kValTyped = kVal; + if (iterArg.getType() != kVal.getType()) + kValTyped = + rewriter.create(loc, iterArg.getType(), kVal); + + Value stepVal; + if (iterArg.getType().isIndex()) + stepVal = rewriter.create(loc, stepConst); + else + stepVal = rewriter.create( + loc, iterArg.getType(), stepConst); + + Value offset = rewriter.create(loc, kValTyped, stepVal); + Value currVal = rewriter.create(loc, initVal, offset); + valueMapping[iterArg][i] = currVal; + } + } + } +} + +Value NPUUnrollPipeline::getUnrolledValue(Value originalVal, int iterIdx) { + // 1. Check existing mapping (simple IVs or previously cloned ops) + if (valueMapping.count(originalVal)) { + if (iterIdx >= 0 && iterIdx < valueMapping[originalVal].size()) { + Value mapped = valueMapping[originalVal][iterIdx]; + if (mapped) + return mapped; + } + } + + // 2. 
Handle BlockArguments (IterArgs) + if (auto arg = dyn_cast(originalVal)) { + if (arg.getOwner() == forOp.getBody()) { + // IV is handled in prepareInitialMappings (Slot 0 of args) + if (arg.getArgNumber() == 0) + return nullptr; + + // IterArgs start at index 1 + int iterArgIdx = arg.getArgNumber() - 1; + + // Case 2a: Iteration 0 uses the Loop Init Args (Full unroll) + if (iterIdx == 0) { + return forOp.getInitArgs()[iterArgIdx]; + } + + // Case 2b: Iteration K > 0 uses Yield result from K-1 + // Strategy: Find the op marked as yield source and look up its clone. + Operation *yieldSourceOp = getYieldSourceForIterArg(forOp, iterArgIdx); + if (yieldSourceOp) { + // The YieldOp operand tells us which result of the source op is used + auto yieldOp = cast(forOp.getBody()->getTerminator()); + Value yieldOperand = yieldOp.getOperand(iterArgIdx); + + if (auto res = dyn_cast(yieldOperand)) { + // If yield operand is a direct result of the marked op + if (res.getOwner() == yieldSourceOp) { + int resIdx = res.getResultNumber(); + // Check if the source op for the previous iteration was cloned + if (opMapping.count(yieldSourceOp) && + iterIdx - 1 < opMapping[yieldSourceOp].size()) { + Operation *prevClone = opMapping[yieldSourceOp][iterIdx - 1]; + if (prevClone) { + return prevClone->getResult(resIdx); + } else { + LDBG(" WARNING: Yield source clone missing for iter " + << iterIdx - 1); + } + } + } + } else if (auto argOperand = dyn_cast(yieldOperand)) { + // The yield operand is an IterArg itself (Pass-through) + // Recursively resolve it + return getUnrolledValue(argOperand, iterIdx - 1); + } + } + + // Fallback: If no complex logic found, try recursive lookup on yield + // operand (This handles cases where the yield val is invariant or defined + // elsewhere) + Value directYieldVal = + cast(forOp.getBody()->getTerminator()) + .getOperand(iterArgIdx); + return getUnrolledValue(directYieldVal, iterIdx - 1); + } + } + + // 3. 
Invariant or Global values + return originalVal; +} + +Operation *NPUUnrollPipeline::cloneAndUpdateOperands(RewriterBase &rewriter, + Operation *op, + int iterIdx) { + IRMapping mapper; + + // Walk the operation to identify and map all externally defined values used + // within 'op' or its nested regions. This ensures that when 'op' is cloned, + // any references to values defined in the original loop scope are correctly + // remapped to their unrolled counterparts for the current iteration. + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + // Skip if already mapped. + if (mapper.contains(operand)) + continue; + + bool isExternal = false; + // Check if the operand is a BlockArgument defined outside of 'op'. + if (auto arg = dyn_cast(operand)) { + Operation *parentOp = arg.getOwner()->getParentOp(); + // It is external if the parent op is neither 'op' nor a descendant of 'op'. + if (parentOp != op && !op->isAncestor(parentOp)) + isExternal = true; + } + // Check if the operand is an OpResult defined outside of 'op'. + else if (auto defOp = operand.getDefiningOp()) { + // It is external if the defining op is neither 'op' nor a descendant of 'op'. + if (defOp != op && !op->isAncestor(defOp)) + isExternal = true; + } + + if (isExternal) { + // Retrieve the unrolled value for the current iteration. + Value replacement = getUnrolledValue(operand, iterIdx); + if (replacement) { + mapper.map(operand, replacement); + } + } + } + }); + + // Clone the operation using the populated mapper. + // This handles deep cloning and operand remapping for both the op and its + // nested regions (like scf.for body). 
+ Operation *clone = rewriter.clone(*op, mapper); + + // Record the cloned op in the mapping for future lookups (Yield Source + // resolution) + if (opMapping[op].size() <= iterIdx) + opMapping[op].resize(unrollFactor); + opMapping[op][iterIdx] = clone; + + return clone; +} + +void NPUUnrollPipeline::updateHivmFlag(Operation *op, int iterIdx, + RewriterBase &rewriter) { + if (maxFlagPerIter == 0) + return; + auto update = [&](auto syncOp) { + if (auto attr = syncOp.getStaticFlagIdAttr()) { + int64_t newFlag = attr.getInt() + iterIdx * maxFlagPerIter; + syncOp.setStaticFlagIdAttr(rewriter.getI64IntegerAttr(newFlag)); + } + }; + if (auto setOp = dyn_cast(op)) + update(setOp); + else if (auto waitOp = dyn_cast(op)) + update(waitOp); +} + +LogicalResult NPUUnrollPipeline::run(RewriterBase &rewriter) { + calculateMaxFlagStride(); + + // Resize mappings + for (Operation &op : forOp.getBody()->without_terminator()) { + opMapping[&op].resize(unrollFactor, nullptr); + for (Value res : op.getResults()) { + valueMapping[res].resize(unrollFactor, nullptr); + } + } + + rewriter.setInsertionPoint(forOp); + prepareInitialMappings(rewriter); + + LDBG(">>> [Unroll] Starting Clone (Stage-Major Order)..."); + + for (const auto &stage : stages) { + LDBG(" Processing Stage " << stage.id); + for (int iterIdx = 0; iterIdx < unrollFactor; ++iterIdx) { + for (Operation *op : stage.ops) { + if (isa(op)) + continue; + + Operation *clonedOp = cloneAndUpdateOperands(rewriter, op, iterIdx); + + LLVM_DEBUG({ + llvm::dbgs() << "[" DEBUG_TYPE "] [Stg " << stage.id << "][Iter " + << iterIdx << "] Cloned Op: "; + clonedOp->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + updateHivmFlag(clonedOp, iterIdx, rewriter); + + // Update value mapping for results + for (auto it : llvm::zip(op->getResults(), clonedOp->getResults())) { + Value originalRes = std::get<0>(it); + Value newRes = std::get<1>(it); + if (iterIdx < valueMapping[originalRes].size()) + valueMapping[originalRes][iterIdx] = newRes; + 
} + } + } + } + + LDBG(">>> [Unroll] Replacing Loop Results..."); + Operation *terminator = forOp.getBody()->getTerminator(); + SmallVector finalResults; + + // The final results correspond to the yield values of the LAST iteration + int lastIter = unrollFactor - 1; + + for (Value operand : terminator->getOperands()) { + Value remapped = getUnrolledValue(operand, lastIter); + if (!remapped) + remapped = operand; + finalResults.push_back(remapped); + } + + if (forOp.getNumResults() != finalResults.size()) { + return forOp.emitError("Unroll result count mismatch"); + } + + rewriter.replaceOp(forOp, finalResults); + LDBG("<<< Pass Complete."); + return success(); +} + +struct NPUUnroolPipelinePass + : public mlir::dicp::LinalgExt::impl::NPUUnroolPipelineBase< + NPUUnroolPipelinePass> { + NPUUnroolPipelinePass() = default; + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + + SmallVector loops; + func.walk([&](scf::ForOp loop) { + if (loop->hasAttr(mlir::triton::kNumStagesAttrName)) + loops.push_back(loop); + }); + + if (loops.size() != 1) { + LDBG("The number of candidate loops is not one."); + return; + } + + scf::ForOp targetLoop = loops[0]; + if (failed(verifyLoopForPipelining(targetLoop))) { + LDBG("Loop verification failed, skipping."); + return; + } + + int numStages = mlir::cast( + targetLoop->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + + LDBG("Processing Loop with num_stages = " << numStages); + + mlir::IRRewriter rewriter(func.getContext()); + AliasAnalysis &aa = getAnalysis(); + // 1. Analyze and Reorder Stages (Topological Sort) + StageDependencyAnalyzer analyzer(targetLoop.getBody(), aa); + auto orderedStagesOrFailure = analyzer.runAndReorder(rewriter); + + if (failed(orderedStagesOrFailure)) { + LDBG("Failed to reorder stages (cyclic dependency detected)."); + signalPassFailure(); + return; + } + // 2. 
Mark Yield Sources for complex iter_args + if (failed(markYieldSources(targetLoop))) { + signalPassFailure(); + return; + } + + // 3. Execute Unroll (Stage-Major) + NPUUnrollPipeline unroller(targetLoop, numStages, + orderedStagesOrFailure.value()); + if (failed(unroller.run(rewriter))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createNPUUnroolPipelinePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/NPUVectorTileTransform.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/NPUVectorTileTransform.cpp new file mode 100644 index 00000000..27d9e730 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/NPUVectorTileTransform.cpp @@ -0,0 +1,190 @@ +#include "dicp/Dialect/LinalgExt/Analysis/StageUtils.h" +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "dicp/TransformOps/DicpTransformOps.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/Support/FormatVariadic.h" + +#define DEBUG_TYPE "npu-tile-loop-transform" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; +using namespace LinalgExt; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_NPUVECTORTILETRANSFORM +#include 
"dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +namespace { + +//===----------------------------------------------------------------------===// +// Data Structures +//===----------------------------------------------------------------------===// + +/// Extracted metadata from the tagging pass. +struct TilingMeta { + std::string anchorTag; + std::string computeTag; + SmallVector tileSizes; +}; + +//===----------------------------------------------------------------------===// +// TransformGenerator +//===----------------------------------------------------------------------===// + +/// Generates Transform Dialect IR based on the analyzed tiling units. +class TransformGenerator { +public: + TransformGenerator(OpBuilder &b, Location loc, Value root) + : b(b), loc(loc), root(root) {} + + /// Process all tiling metadata extracted from the module. + void generate(ArrayRef metas) { + for (const auto &meta : metas) { + generateUnit(meta); + } + } + +private: + OpBuilder &b; + Location loc; + Value root; + + /// Emit tiling and fusion sequence for a single unit. 
+ void generateUnit(const TilingMeta &meta) { + LDBG("generateUnit: Emitting transform sequence for anchor tag: " + << meta.anchorTag); + Value anchor = getMatch(meta.anchorTag); + + auto tile = b.create( + loc, anchor, meta.tileSizes, transform::TileSizesSpec(), nullptr); + Value loop = tile.getForallOp(); + + Value producers = getMatch(meta.computeTag, true); + + auto foreachOp = b.create( + loc, + TypeRange{b.getType(), + b.getType()}, + producers, false); + + OpBuilder::InsertionGuard g(b); + b.createBlock(&foreachOp.getBody(), {}, {b.getType()}, + {loc}); + + ImplicitLocOpBuilder ib(loc, b); + // Correct number of args for fusion + auto fused = ib.create( + ib.getType(), ib.getType(), + foreachOp.getBody().getArgument(0), loop); + + auto apply = ib.create( + fused.getNewContainingOp().front(), [](OpBuilder &pb, Location ploc) { + pb.create(ploc); + }); + apply.setApplyCse(true); + ib.create(fused.getResults()); + } + + /// Helper to create a transform.match op for a given attribute tag. + Value getMatch(StringRef attr, bool reverse = false) { + auto match = b.create( + loc, b.getType(), root, nullptr, nullptr, + b.getDictionaryAttr(b.getNamedAttr(attr, b.getUnitAttr())), nullptr, + nullptr); + return reverse ? b.create( + loc, b.getType(), match) + .getResult() + : match.getResult(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Main Entry +//===----------------------------------------------------------------------===// + +/// Main pass for applying the tiling and fusion sequence via Transform Dialect. +class NPUVectorTileTransformPass + : public mlir::dicp::LinalgExt::impl::NPUVectorTileTransformBase< + NPUVectorTileTransformPass> { +public: + using NPUVectorTileTransformBase::NPUVectorTileTransformBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *ctx = &getContext(); + LDBG("Running NPUVectorTileTransformPass..."); + + // 1. 
Collect metadata from the tagging pass and clean up attributes. + SmallVector metaList; + module.walk([&](Operation *op) { + if (auto dict = op->getAttrOfType("npu.tiling_meta")) { + TilingMeta meta; + meta.anchorTag = + cast(dict.get("anchor_tag")).getValue().str(); + meta.computeTag = + cast(dict.get("compute_tag")).getValue().str(); + auto sizesAttr = cast(dict.get("tile_sizes")); + meta.tileSizes = llvm::to_vector(sizesAttr.asArrayRef()); + metaList.push_back(meta); + + LDBG("Found tiling meta for anchor: " << meta.anchorTag); + // Clean up the metadata attribute to avoid polluting final IR. + op->removeAttr("npu.tiling_meta"); + } + }); + + if (metaList.empty()) { + LDBG("No tiling metadata found in the module. Skipping transform " + "application."); + return; + } + + // 2. Apply transformation via Transform Dialect interpreter. + LDBG("Applying Transform Dialect generation for " << metaList.size() + << " tiling units."); + TransformApplier::apply( + module, [&](OpBuilder &b, Location loc, Value root) { + TransformGenerator(b, loc, root).generate(metaList); + }); + + // 3. Final IR cleanup. 
+ LDBG("Running final cleanup pipeline (CSE/Canonicalizer)."); + PassManager pm(ctx, module.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, module))) { + LDBG("Final cleanup pipeline failed."); + signalPassFailure(); + } + + LDBG("NPUVectorTileTransformPass completed successfully."); + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createNPUVectorTileTransformPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/ShrinkBuffers.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/ShrinkBuffers.cpp new file mode 100644 index 00000000..f8dacaf1 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/ShrinkBuffers.cpp @@ -0,0 +1,379 @@ +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "buffer-shrink" +#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] " << X << "\n") + +using namespace mlir; +using namespace dicp; +using namespace LinalgExt; + +namespace mlir::dicp::LinalgExt { +#define GEN_PASS_DEF_SHRINKBUFFERS +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace mlir::dicp::LinalgExt + +namespace { + +//===----------------------------------------------------------------------===// +// Helper: Slice Parameter Analysis +//===----------------------------------------------------------------------===// + +/// 
Holds extracted static parameters for a slice/view operation.
+struct SliceParams {
+  // Static offsets/sizes/strides of the slice, one entry per dimension.
+  SmallVector offsets;
+  SmallVector sizes;
+  SmallVector strides;
+  // Element type of the sliced value, used for consistency checking.
+  Type elementType;
+
+  /// Checks if two SliceParams represent the exact same region and type.
+  bool operator==(const SliceParams &other) const {
+    return offsets == other.offsets && sizes == other.sizes &&
+           strides == other.strides && elementType == other.elementType;
+  }
+
+  // Defined in terms of operator== so the two can never disagree.
+  bool operator!=(const SliceParams &other) const { return !(*this == other); }
+};
+
+/// Tries to extract static slice parameters from an operation.
+/// Returns failure if the op is not a supported slice/view or has dynamic
+/// shapes.
+// NOTE(review): the dyn_cast target's template argument is lost in this
+// rendering — presumably OffsetSizeAndStrideOpInterface, which both
+// tensor.extract_slice and memref.subview implement; confirm against the
+// original source. Also assumes the op has at least one result (getResult(0)).
+LogicalResult getStaticSliceParams(Operation *op, SliceParams &params) {
+  // Support both tensor.extract_slice and memref.subview
+  auto iface = dyn_cast(op);
+  if (!iface)
+    return failure();
+
+  // 1. Check if all offsets, sizes, and strides are static.
+  // We check the "static" arrays for any dynamic placeholders.
+  auto staticOffsets = iface.getStaticOffsets();
+  auto staticSizes = iface.getStaticSizes();
+  auto staticStrides = iface.getStaticStrides();
+
+  auto isDynamic = [](int64_t v) { return ShapedType::isDynamic(v); };
+
+  if (llvm::any_of(staticOffsets, isDynamic) ||
+      llvm::any_of(staticSizes, isDynamic) ||
+      llvm::any_of(staticStrides, isDynamic)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "  [Shrink] Op has dynamic shapes: " << *op << "\n");
+    return failure();
+  }
+
+  // 2. Use .assign() instead of '=' for SmallVector
+  params.offsets.assign(staticOffsets.begin(), staticOffsets.end());
+  params.sizes.assign(staticSizes.begin(), staticSizes.end());
+  params.strides.assign(staticStrides.begin(), staticStrides.end());
+
+  // Get element type for consistency checking
+  if (auto shapedType = dyn_cast(op->getResult(0).getType())) {
+    params.elementType = shapedType.getElementType();
+  } else {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Validates that the slice is contiguous (unit strides).
+/// Shrinking a buffer with non-unit strides into a dense buffer changes data
+/// layout, which breaks consumers expecting specific pointer arithmetic.
+bool isContiguousSlice(const SliceParams &params) {
+  return llvm::all_of(params.strides, [](int64_t s) { return s == 1; });
+}
+
+
+//===----------------------------------------------------------------------===//
+// Pattern 1: ShrinkTensorEmpty
+//===----------------------------------------------------------------------===//
+
+/// Pattern to shrink `tensor.empty` size based on its usage.
+///
+/// Matches:
+///   %0 = tensor.empty() : tensor<100x100xf32>
+///   %1 = tensor.extract_slice %0[0,0][10,10][1,1] ...
+///
+/// And/Or:
+///   %0 = tensor.empty() ...
+///   %buf = bufferization.to_memref %0 ...
+///   %view = memref.subview %buf[0,0][10,10][1,1] ...
+///
+/// Requirement: All users must point to the *same* static contiguous
+/// sub-region.
+struct ShrinkTensorEmpty : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::EmptyOp emptyOp,
+                                PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Analyze ShrinkTensorEmpty: " << emptyOp << "\n");
+
+    SliceParams commonParams;
+    bool isFirst = true;
+
+    SmallVector extractSliceUsers;
+    SmallVector subviewUsers;
+
+    // 1. Analyze all direct users of tensor.empty
+    for (Operation *user : emptyOp->getUsers()) {
+      // Case A: Direct use by tensor.extract_slice
+      if (isa(user)) {
+        SliceParams params;
+        if (failed(getStaticSliceParams(user, params)))
+          return failure();
+
+        if (isFirst) {
+          commonParams = params;
+          isFirst = false;
+        } else if (commonParams != params) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "  [Shrink] Inconsistent extract_slice params.\n");
+          return failure();
+        }
+        extractSliceUsers.push_back(user);
+        continue;
+      }
+
+      // Case B: Direct use by bufferization.to_memref (or to_buffer)
+      // Note: We check for ToBufferOp. If your dialect uses a different cast
+      // (e.g., CastOp), adapt here.
+      if (isa(user)) {
+        // Check users of the buffer cast
+        for (Operation *bufUser : user->getUsers()) {
+          if (isa(bufUser)) {
+            SliceParams params;
+            if (failed(getStaticSliceParams(bufUser, params)))
+              return failure();
+
+            if (isFirst) {
+              commonParams = params;
+              isFirst = false;
+            } else if (commonParams != params) {
+              LLVM_DEBUG(
+                  llvm::dbgs()
+                  << "  [Shrink] Inconsistent subview params via to_memref.\n");
+              return failure();
+            }
+            subviewUsers.push_back(bufUser);
+          } else {
+            LLVM_DEBUG(llvm::dbgs() << "  [Shrink] to_memref has invalid user: "
+                                    << *bufUser << "\n");
+            return failure();
+          }
+        }
+        continue;
+      }
+
+      // Invalid user found
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  [Shrink] Unhandled user: " << *user << "\n");
+      return failure();
+    }
+
+    if (isFirst) {
+      LLVM_DEBUG(llvm::dbgs() << "  [Shrink] No valid users found.\n");
+      return failure();
+    }
+
+    // 2. Safety Check: Ensure the slice is contiguous (stride 1).
+    if (!isContiguousSlice(commonParams)) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+          << "  [Shrink] Slice is not contiguous. Cannot shrink safely.\n");
+      return failure();
+    }
+
+    // 3. Rewrite
+    // Create new smaller tensor.empty
+    // The result type is inferred from the slice size and element type.
+    auto newRankedType =
+        RankedTensorType::get(commonParams.sizes, commonParams.elementType);
+
+    // Bail out BEFORE mutating any IR: a rank-reducing extract_slice (or any
+    // other result-type mismatch) cannot be replaced by the new tensor
+    // directly. The previous version acknowledged this case in a comment but
+    // replaced anyway, producing type-incorrect IR; and a pattern must never
+    // return failure after it has started modifying the IR, so the check has
+    // to happen up front.
+    for (Operation *op : extractSliceUsers) {
+      if (op->getResult(0).getType() != newRankedType) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "  [Shrink] extract_slice result type mismatch "
+                      "(rank-reducing slice?). Cannot shrink safely.\n");
+        return failure();
+      }
+    }
+
+    // We replace the consumers first, then the producer.
+    rewriter.setInsertionPoint(emptyOp);
+    auto newEmptyOp = rewriter.create(
+        emptyOp.getLoc(), newRankedType, ValueRange{}); // No dynamic sizes
+
+    // Replace all extract_slice users with the new empty tensor
+    // (Since the new tensor *is* the slice, the extract operation is
+    // redundant/identity; type equality was verified above).
+    for (Operation *op : extractSliceUsers)
+      rewriter.replaceOp(op, newEmptyOp);
+
+    // Handle to_memref -> subview chains
+    // We need a new to_memref converting the NEW empty tensor to a NEW memref
+    if (!subviewUsers.empty()) {
+      auto newMemRefType =
+          MemRefType::get(commonParams.sizes, commonParams.elementType);
+      auto newToMemref = rewriter.create(
+          emptyOp.getLoc(), newMemRefType, newEmptyOp);
+
+      for (Operation *subview : subviewUsers) {
+        // Check if we need a cast. The subview result might have a specific
+        // layout map (offset: ?). The new alloc is canonical (offset: 0).
+        Value replacement = newToMemref;
+        if (subview->getResult(0).getType() != newMemRefType) {
+          replacement = rewriter.create(
+              subview->getLoc(), subview->getResult(0).getType(), newToMemref);
+        }
+        rewriter.replaceOp(subview, replacement);
+      }
+    }
+
+    // tensor.empty itself doesn't have side effects, so if users are gone, it's
+    // dead. However, we must ensure we don't leave dangling users of the
+    // intermediate to_memref. The standard DCE will handle the old to_memref
+    // and emptyOp if they have no users.
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pattern 2: ShrinkMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+/// Pattern to shrink `memref.alloc` size based on `subview` usage.
+///
+/// Matches:
+///   %alloc = memref.alloc() : memref<128xf32>
+///   %view = memref.subview %alloc[0][64][1] : ...
+///
+/// Requirements:
+///   - Users of %alloc are ONLY subviews.
+///   - All subviews access the *same* static contiguous region.
+struct ShrinkMemRefAlloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::AllocOp allocOp, + PatternRewriter &rewriter) const override { + LLVM_DEBUG(llvm::dbgs() + << "Analyze ShrinkMemRefAlloc: " << allocOp << "\n"); + + SliceParams commonParams; + bool isFirst = true; + SmallVector subviewUsers; + + // 1. Analyze users + for (Operation *user : allocOp->getUsers()) { + auto subview = dyn_cast(user); + if (!subview) { + LLVM_DEBUG(llvm::dbgs() << " [Shrink] Alloc has non-subview user: " + << *user << "\n"); + return failure(); + } + + SliceParams params; + if (failed(getStaticSliceParams(subview, params))) + return failure(); + + if (isFirst) { + commonParams = params; + isFirst = false; + } else if (commonParams != params) { + LLVM_DEBUG(llvm::dbgs() + << " [Shrink] Inconsistent subview definitions.\n"); + return failure(); + } + subviewUsers.push_back(subview); + } + + if (isFirst) { + return failure(); // No users + } + + // 2. Safety Check: Contiguous + if (!isContiguousSlice(commonParams)) { + LLVM_DEBUG( + llvm::dbgs() + << " [Shrink] Subview is not contiguous. Cannot shrink alloc.\n"); + return failure(); + } + + // 3. Rewrite + // Construct the new MemRef type (Canonical layout, because it's a fresh + // alloc) + auto newMemRefType = + MemRefType::get(commonParams.sizes, commonParams.elementType); + + rewriter.setInsertionPoint(allocOp); + auto newAlloc = + rewriter.create(allocOp.getLoc(), newMemRefType); + + // Propagate attributes (e.g., memory space, alignment) if necessary + if (allocOp.getAlignment().has_value()) { + newAlloc.setAlignment(allocOp.getAlignment().value()); + } + // Note: We deliberately drop extra attributes like "dicp.npu.stage" unless + // we know they remain valid. However, usually alloc attributes should be + // preserved. rewriter.replaceOpWithNewOp handles result replacement, but + // here we are changing types. 
+ + for (auto subview : subviewUsers) { + Value replacement = newAlloc; + + // If the subview result type differs from the new canonical alloc type + // (e.g. strict stride info), we must cast the new alloc to the old + // subview type to satisfy consumers. + if (subview.getType() != newMemRefType) { + replacement = rewriter.create( + subview.getLoc(), subview.getType(), newAlloc); + } + + rewriter.replaceOp(subview, replacement); + } + + // The old alloc has no users now (we replaced the subviews, which were the + // only users). + rewriter.eraseOp(allocOp); + + return success(); + } +}; + +struct ShrinkBuffersPass + : public mlir::dicp::LinalgExt::impl::ShrinkBuffersBase { + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + + // Insert your two patterns. The pattern constructors below assume they take + // MLIRContext* or PatternRewriter/OperationContext as appropriate. + patterns.add(ctx); + + // Optionally: add canonicalization / folding patterns if helpful. + // apply patterns greedily to the function. 
+ if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::dicp::LinalgExt::createShrinkBuffersPass() { + return std::make_unique(); +} diff --git a/compiler/lib/TransformOps/CMakeLists.txt b/compiler/lib/TransformOps/CMakeLists.txt new file mode 100644 index 00000000..0e954ea6 --- /dev/null +++ b/compiler/lib/TransformOps/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(DICPTransformOps + TransformsUtils.cpp + DialectExtension.cpp + DicpTransformOps.cpp + + DEPENDS + DICPTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRAffineDialect + MLIRArithDialect + MLIRFuncDialect + MLIRIndexDialect + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgTransformOps + MLIRSideEffectInterfaces + MLIRTensorDialect + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRTransformUtils + ) \ No newline at end of file diff --git a/compiler/lib/TransformOps/DialectExtension.cpp b/compiler/lib/TransformOps/DialectExtension.cpp new file mode 100644 index 00000000..0d69dabb --- /dev/null +++ b/compiler/lib/TransformOps/DialectExtension.cpp @@ -0,0 +1,54 @@ + +#include "dicp/TransformOps/DicpTransformOps.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/TypeID.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + 
+namespace { +class DCIPTransformDialectExtension + : public transform::TransformDialectExtension< + DCIPTransformDialectExtension> { +public: + using Base::Base; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DCIPTransformDialectExtension); + + void init() { + declareDependentDialect(); + declareDependentDialect(); + + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include "dicp/TransformOps/DicpTransformOps.cpp.inc" + >(); + } +}; + +} // namespace + +void mlir::dicp::registerTransformDialectExtension(DialectRegistry ®istry) { + mlir::linalg::registerTilingInterfaceExternalModels(registry); + mlir::tensor::registerTilingInterfaceExternalModels(registry); + registry.addExtensions(); +} diff --git a/compiler/lib/TransformOps/DicpTransformOps.cpp b/compiler/lib/TransformOps/DicpTransformOps.cpp new file mode 100644 index 00000000..c5bf8b25 --- /dev/null +++ b/compiler/lib/TransformOps/DicpTransformOps.cpp @@ -0,0 +1,685 @@ +#include "dicp/TransformOps/DicpTransformOps.h" +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Linalg/TransformOps/Syntax.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "dicp-transform-op" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::transform; +using namespace mlir::dicp; + +//===----------------------------------------------------------------------===// +// ReverseOp 
+//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure ReverseOp::apply(TransformRewriter &rewriter, + TransformResults &transformResults, + TransformState &state) { + SmallVector targets = + llvm::to_vector(state.getPayloadOps(getTarget())); + SmallVector reversedOperations = {targets.rbegin(), + targets.rend()}; + transformResults.set(cast(getResult()), reversedOperations); + return DiagnosedSilenceableFailure::success(); +} + +void ReverseOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getTargetMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); +} + +// ============================================================================ +// ForwardInitToIterArgOp +// ============================================================================ + +/// Checks if two operations operating on subsets (extraction vs insertion) +/// are geometrically equivalent (same offsets, sizes, and strides). +/// +/// This uses the `OffsetSizeAndStrideOpInterface` to be dialect-agnostic. +static bool areSlicesEquivalent(Operation *readOp, Operation *writeOp) { + auto readInterface = dyn_cast(readOp); + auto writeInterface = dyn_cast(writeOp); + + if (!readInterface || !writeInterface) { + LDBG(" One of the ops does not implement OffsetSizeAndStrideOpInterface."); + return false; + } + + // Compare mixed offsets, sizes, and strides. + // Note: This relies on SSA value equality for dynamic dims. 
+ bool offsetsMatch = llvm::equal(readInterface.getMixedOffsets(), + writeInterface.getMixedOffsets()); + bool sizesMatch = llvm::equal(readInterface.getMixedSizes(), + writeInterface.getMixedSizes()); + bool stridesMatch = llvm::equal(readInterface.getMixedStrides(), + writeInterface.getMixedStrides()); + + if (!offsetsMatch || !sizesMatch || !stridesMatch) { + LLVM_DEBUG({ + if (!offsetsMatch) + DBGS() << " Offsets mismatch.\n"; + if (!sizesMatch) + DBGS() << " Sizes mismatch.\n"; + if (!stridesMatch) + DBGS() << " Strides mismatch.\n"; + }); + return false; + } + + return true; +} + +/// Strategy for scf.forall: +/// The write-back occurs in the `scf.in_parallel` terminator via a +/// SubsetInsertionOp (usually tensor.parallel_insert_slice). +static Operation *findWriteBackOp(scf::ForallOp loopOp, + BlockArgument regionArg) { + scf::InParallelOp terminator = loopOp.getTerminator(); + Block &terminatorBlock = terminator.getRegion().front(); + + for (Operation &op : terminatorBlock) { + // Check if it's a subset insertion (like parallel_insert_slice) + if (auto insertOp = dyn_cast(op)) { + // For parallel insert, the destination is the BlockArgument of the loop + // that corresponds to the output. + if (insertOp.getDestinationOperand().get() == regionArg) { + return &op; + } + } + } + return nullptr; +} + +/// Strategy for scf.for: +/// The write-back is the value yielded by `scf.yield`. We need to check if +/// that yielded value is defined by a SubsetInsertionOp that inserts *into* +/// the corresponding region argument. +static Operation *findWriteBackOp(scf::ForOp loopOp, BlockArgument regionArg) { + auto yieldOp = cast(loopOp.getBody()->getTerminator()); + + // 1. Identify which result index this regionArg corresponds to. + // scf.for region args: [iv, iter_arg_0, iter_arg_1, ...] + // The iter_args start at index 1 (since index 0 is IV). 
+ unsigned iterArgIndex = regionArg.getArgNumber() - 1; // Subtract IV + + if (iterArgIndex >= yieldOp.getResults().size()) { + return nullptr; // Should not happen if IR is valid + } + + Value yieldedVal = yieldOp.getOperand(iterArgIndex); + Operation *defOp = yieldedVal.getDefiningOp(); + + if (!defOp) + return nullptr; + + // 2. Check if the yielded value comes from an insertion op + if (auto insertOp = dyn_cast(defOp)) { + // 3. Check if the insertion destination IS the region argument. + // i.e., %new = insert_slice %update into %iter_arg + if (insertOp.getDestinationOperand().get() == regionArg) { + return defOp; + } + } + + return nullptr; +} + +/// Generic processor that works for both scf.for and scf.forall. +/// It relies on the `findWriteBackOp` overload to handle structural +/// differences. +template +static void processLoop(LoopTy loopOp, RewriterBase &rewriter) { + LDBG("Processing loop: " << loopOp.getOperation()->getName()); + + auto regionIterArgs = loopOp.getRegionIterArgs(); + + // scf.forall uses getOutputs(), scf.for uses getInitArgs(). + // We use a lambda to abstract this access. + auto getInitOperands = [&](auto op) -> OperandRange { + if constexpr (std::is_same_v) + return op.getOutputs(); + else + return op.getInitArgs(); + }; + + auto initOperands = getInitOperands(loopOp); + + // Iterate over each (InitOperand, RegionIterArg) pair + for (auto it : llvm::zip(initOperands, regionIterArgs)) { + Value initVal = std::get<0>(it); + BlockArgument regionArg = std::get<1>(it); + + LDBG(" Analyzing pair: InitVal=" << initVal + << ", RegionArg=" << regionArg); + + // 1. Find the write-back operation (Insertion) + Operation *writeOp = findWriteBackOp(loopOp, regionArg); + if (!writeOp) { + LDBG(" No valid write-back (insertion) found for this argument. " + "Skipping."); + continue; + } + LDBG(" Found write-back op: " << *writeOp); + + // 2. 
Find read operations (Extraction) inside the loop body + // We look for extractions that read from the *external* 'initVal'. + SmallVector candidates; + for (Operation *user : initVal.getUsers()) { + // Ensure the user is strictly inside the loop body + if (loopOp.getBody()->findAncestorOpInBlock(*user)) { + if (isa(user)) { + candidates.push_back(user); + } + } + } + + if (candidates.empty()) { + LDBG(" No extraction users of InitVal found inside loop."); + continue; + } + + // 3. Compare and Replace + for (Operation *readOp : candidates) { + LDBG(" Checking candidate read op: " << *readOp); + + if (areSlicesEquivalent(readOp, writeOp)) { + LDBG(" MATCH! Slices are equivalent. Forwarding init arg to iter " + "arg."); + + // Transform: Replace the source of the extraction (which is currently + // the external init_arg) with the internal region_arg (iter_arg). + // This enables in-place bufferization. + + rewriter.setListener( + nullptr); // Disable listener for simple operand updates + rewriter.modifyOpInPlace(readOp, [&]() { + // Use the interface to set the source operand generically + // Note: getSourceOperand() returns an OpOperand&. 
+ auto subsetOp = cast(readOp); + subsetOp.getSourceOperand().set(regionArg); + }); + } else { + LDBG(" Mismatch: Geometry differs."); + } + } + } +} + +DiagnosedSilenceableFailure +ForwardInitToIterArgOp::apply(TransformRewriter &rewriter, + TransformResults &transformResults, + TransformState &state) { + SmallVector processedOps; + auto payloadOps = state.getPayloadOps(getTarget()); + + for (Operation *op : payloadOps) { + bool isProcessed = false; + if (!op) { + LDBG("Skipping op: " << op->getName() << " is null"); + continue; + } + + // Dispatch to the appropriate template instantiation + if (auto forallOp = dyn_cast(op)) { + processLoop(forallOp, rewriter); + isProcessed = true; + } else if (auto forOp = dyn_cast(op)) { + processLoop(forOp, rewriter); + isProcessed = true; + } + + if (!isProcessed) { + LDBG("Skipping op: " << op->getName() << " (not scf.for or scf.forall)"); + } + + // We preserve the operations in the result handle + processedOps.push_back(op); + } + + transformResults.set(cast(getResult()), processedOps); + return DiagnosedSilenceableFailure::success(); +} + +void ForwardInitToIterArgOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getTargetMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// +// ExtendedFuseIntoContainingOp +//===----------------------------------------------------------------------===// + +void transform::ExtendedFuseIntoContainingOp::build(OpBuilder &builder, + OperationState &result, + Value producerOp, + Value containingOp) { + result.addOperands({producerOp, containingOp}); + auto resultType = transform::AnyOpType::get(builder.getContext()); + result.addTypes({resultType, resultType}); +} + +bool transform::ExtendedFuseIntoContainingOp::allowsRepeatedHandleOperands() { + // Allow repeated handles since we are fusing everything anyway. 
+ return true; +} + +DiagnosedSilenceableFailure +transform::ExtendedFuseIntoContainingOp::fuseIntoOneContaining( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state, + size_t index, Operation *containingOp) { + assert(index < getFusedOp().size()); + assert(index < getNewContainingOp().size()); + + SmallVector fusedOps; + auto producerOps = state.getPayloadOps(getProducerOp()); + + LLVM_DEBUG({ + DBGS() << "=== ExtendedFuseIntoContainingOp: producerOps ===\n"; + for (Operation *op : producerOps) { + DBGS() << "producerOp @" << op << ":\n"; + op->print(DBGS()); + DBGS() << "\n----------------------------------------\n"; + } + DBGS() << "containingOp @" << containingOp << " :\n "; + containingOp->print(DBGS()); + DBGS() << "=== end producerOps ===\n"; + }); + + // If nothing to fuse, propagate success. + if (std::empty(producerOps)) { + results.set(cast(getFusedOp()[index]), + SmallVector{}); + results.set(cast(getNewContainingOp()[index]), {containingOp}); + return DiagnosedSilenceableFailure::success(); + } + + SetVector remainingProducers(producerOps.begin(), + producerOps.end()); + auto getNextProducer = [&]() -> FailureOr> { + for (const auto &it : enumerate(remainingProducers)) { + Operation *producerOp = it.value(); + // The containing op may be a user of producerOp: use isAncestor. + int64_t numUsesInContainingOp = + llvm::count_if(producerOp->getUsers(), [&](Operation *op) { + return containingOp->isAncestor(op); + }); + LLVM_DEBUG(DBGS() << "producerOp: " << *producerOp << "\n"); + LLVM_DEBUG(DBGS() << "numUsesInContainingOp: " << numUsesInContainingOp + << "\n"); + if (numUsesInContainingOp > 0) { + return std::make_pair(producerOp, it.index()); + } + } + return failure(); + }; + + // Helper function to erase producerOp from eraseRemainingProducer if no + // users. 
+ auto eraseRemainingProducer = [&](Operation *producerOp, size_t pos) { + int64_t numUsesInContainingOp = + llvm::count_if(producerOp->getUsers(), [&](Operation *op) { + return containingOp->isAncestor(op); + }); + if (numUsesInContainingOp == 0) { + remainingProducers.erase(remainingProducers.begin() + pos); + } + }; + + while (!remainingProducers.empty()) { + auto nextProducer = getNextProducer(); + if (failed(nextProducer)) { + auto diag = mlir::emitSilenceableFailure(getLoc()) + << "could not find next producer to fuse into container"; + diag.attachNote(containingOp->getLoc()) << "containing op"; + return diag; + } + + Operation *producerOp; + size_t producerIndex; + std::tie(producerOp, producerIndex) = *nextProducer; + + // Default diagnostic, to be complemented with more failure information. + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); + diag << "could not fuse " << *producerOp << " into " << *containingOp; + + // 1. Try to tile and fuse all subset users (extract slice, etc.) + // Note: unionProducerUsers is removed because tileAndFuseAllSubsetOps + // handles multiple users individually, removing the need to pre-union them. + auto [tiledOps, newContainingOp] = mlir::dicp::tileAndFuseAllSubsetOps( + rewriter, diag, producerOp, containingOp, getDuplicateProducer()); + + if (!tiledOps.empty()) { + LLVM_DEBUG(DBGS() << "\nFused direct subset ops\n" + << *containingOp << "\n"); + fusedOps.append(tiledOps); + if (newContainingOp) { + // Update handles associated with the containing op so we don't need + // to invalidate them. This supports better composability between + // tiling and fusion. 
+ LLVM_DEBUG({ + llvm::dbgs() << "[extended_fuse] replacing containing op\n"; + llvm::dbgs() << " old: "; + containingOp->print(llvm::dbgs()); + llvm::dbgs() << "\n new: "; + newContainingOp->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + LogicalResult replacementStatus = + rewriter.notifyPayloadOperationReplaced(containingOp, + newContainingOp); + (void)replacementStatus; + assert(succeeded(replacementStatus) && + "unable to update transform state mapping"); + containingOp = newContainingOp; + } + eraseRemainingProducer(producerOp, producerIndex); + continue; + } + + // 2. Try to tile and fuse subset users of the block argument + // (e.g., when the producer is passed as an init operand to scf.forall) + SmallVector tiledContainingOpOperand; + if (auto loopLike = dyn_cast(containingOp)) { + tiledContainingOpOperand = + mlir::dicp::tileAndFuseAllSubsetOpsThroughContainingOpBlockArgument( + rewriter, diag, producerOp, loopLike); + } + if (!tiledContainingOpOperand.empty()) { + LLVM_DEBUG(DBGS() << "\nFused subset ops through block argument\n" + << *containingOp); + fusedOps.append(tiledContainingOpOperand); + eraseRemainingProducer(producerOp, producerIndex); + continue; + } + + // 3. Try to clone and fuse users (element-wise fusion by cloning) + Operation *cloned = mlir::dicp::cloneAndFuseAllSubsetOps( + rewriter, diag, producerOp, containingOp); + if (cloned) { + LLVM_DEBUG(DBGS() << "\nFused uses by cloning\n" << *containingOp); + // We append the single representative fused op returned by cloneAndFuse. + // Ideally, we might want to track all cloned ops, but the interface + // returns the last one currently. 
+ fusedOps.push_back(cloned); + eraseRemainingProducer(producerOp, producerIndex); + continue; + } + + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + + results.set(cast(getFusedOp()[index]), fusedOps); + results.set(cast(getNewContainingOp()[index]), {containingOp}); + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure transform::ExtendedFuseIntoContainingOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + auto containingOps = getContainingOp(); + LLVM_DEBUG({ + for (auto containingOpHandle : containingOps) { + auto payloads = state.getPayloadOps(containingOpHandle); + DBGS() << "Containing op handle has " + << std::distance(payloads.begin(), payloads.end()) + << " payload operations\n"; + } + }); + + for (auto it : llvm::enumerate(containingOps)) { + auto containingOpPayloads = state.getPayloadOps(it.value()); + if (!llvm::hasSingleElement(containingOpPayloads)) { + return emitDefiniteFailure() + << "requires exactly one containing_op handle (got " + << llvm::range_size(containingOpPayloads) << ")"; + } + Operation *currentOp = *containingOpPayloads.begin(); + auto status = + fuseIntoOneContaining(rewriter, results, state, it.index(), currentOp); + if (!status.succeeded()) + return status; + } + return DiagnosedSilenceableFailure::success(); +} + +ParseResult ExtendedFuseIntoContainingOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand producer; + SmallVector containingOps; + FunctionType functionalType; + llvm::SMLoc producerLoc; + llvm::SMLoc containingOpsLoc; + + if (parser.getCurrentLocation(&producerLoc) || parser.parseOperand(producer)) + return ParseResult::failure(); + + if (parser.parseKeyword("into")) + return ParseResult::failure(); + + if (parser.getCurrentLocation(&containingOpsLoc) || + parser.parseOperandList(containingOps)) + return ParseResult::failure(); + + if 
(parser.parseOptionalAttrDict(result.attributes)) + return ParseResult::failure(); + + if (result.propertiesAttr) { + NamedAttrList attrs = llvm::cast(result.propertiesAttr); + attrs.append("resultSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(containingOps.size()), + static_cast(containingOps.size())})); + result.propertiesAttr = attrs.getDictionary(parser.getContext()); + } else { + result.addAttribute("resultSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(containingOps.size()), + static_cast(containingOps.size())})); + } + + if (parser.parseColonType(functionalType)) + return ParseResult::failure(); + + if (parser.resolveOperand(producer, functionalType.getInputs().front(), + result.operands) || + parser.resolveOperands(containingOps, + functionalType.getInputs().drop_front(), + containingOpsLoc, result.operands)) { + return ParseResult::failure(); + } + + result.addTypes(functionalType.getResults()); + return ParseResult::success(); +} + +void ExtendedFuseIntoContainingOp::print(OpAsmPrinter &p) { + p << ' ' << getProducerOp(); + p << ' ' << "into"; + p << ' '; + p.printOperands(getContainingOp()); + p.printOptionalAttrDict((*this)->getAttrs(), {"resultSegmentSizes"}); + p << " : "; + p.printFunctionalType(getOperands().getTypes(), getResults().getTypes()); +} + +void transform::ExtendedFuseIntoContainingOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getProducerOpMutable(), effects); + onlyReadsHandle(getContainingOpMutable(), effects); + producesHandle(getResults(), effects); + modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// +// LoopFuseSiblingOp +//===----------------------------------------------------------------------===// + +/// Check if `target` and `source` are siblings, in the context that `target` +/// is being fused into `source`. 
+/// +/// This is a simple check that just checks if both operations are in the same +/// block and some checks to ensure that the fused IR does not violate +/// dominance. +static DiagnosedSilenceableFailure isOpSibling(Operation *target, + Operation *source) { + // Check if both operations are same. + if (target == source) + return emitSilenceableFailure(source) + << "target and source need to be different loops"; + + // Check if both operations are in the same block. + if (target->getBlock() != source->getBlock()) + return emitSilenceableFailure(source) + << "target and source are not in the same block"; + + // Check if fusion will violate dominance. + DominanceInfo domInfo(source); + if (target->isBeforeInBlock(source)) { + // Since `target` is before `source`, all users of results of `target` + // need to be dominated by `source`. + for (Operation *user : target->getUsers()) { + if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { + return emitSilenceableFailure(target) + << "user of results of target should be properly dominated by " + "source"; + } + } + } else { + // Since `target` is after `source`, all values used by `target` need + // to dominate `source`. + + // Check if operands of `target` are dominated by `source`. + for (Value operand : target->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. + if (!operandOp) + continue; + + // Operand's defining operation should properly dominate `source`. + if (!domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) + return emitSilenceableFailure(target) + << "operands of target should be properly dominated by source"; + } + + // Check if values used by `target` are dominated by `source`. 
+ bool failed = false; + OpOperand *failedValue = nullptr; + visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. + failed = true; + failedValue = operand; + } + }); + + if (failed) + return emitSilenceableFailure(failedValue->getOwner()) + << "values used inside regions of target should be properly " + "dominated by source"; + } + + return DiagnosedSilenceableFailure::success(); +} + +/// Check if `target` scf.forall can be fused into `source` scf.forall. +/// +/// This simply checks if both loops have the same bounds, steps and mapping. +/// No attempt is made at checking that the side effects of `target` and +/// `source` are independent of each other. +static bool isForallWithIdenticalConfiguration(Operation *target, + Operation *source) { + auto targetOp = dyn_cast(target); + auto sourceOp = dyn_cast(source); + if (!targetOp || !sourceOp) + return false; + + return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && + targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && + targetOp.getMixedStep() == sourceOp.getMixedStep() && + targetOp.getMapping() == sourceOp.getMapping(); +} + +/// Check if `target` scf.for can be fused into `source` scf.for. +/// +/// This simply checks if both loops have the same bounds and steps. No attempt +/// is made at checking that the side effects of `target` and `source` are +/// independent of each other. 
+static bool isForWithIdenticalConfiguration(Operation *target, + Operation *source) { + auto targetOp = dyn_cast(target); + auto sourceOp = dyn_cast(source); + if (!targetOp || !sourceOp) + return false; + + return targetOp.getLowerBound() == sourceOp.getLowerBound() && + targetOp.getUpperBound() == sourceOp.getUpperBound() && + targetOp.getStep() == sourceOp.getStep(); +} + +DiagnosedSilenceableFailure +ExtendedLoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + auto sourceOps = state.getPayloadOps(getSource()); + + if (!llvm::hasSingleElement(targetOps) || + !llvm::hasSingleElement(sourceOps)) { + return emitDefiniteFailure() + << "requires exactly one target handle (got " + << llvm::range_size(targetOps) << ") and exactly one " + << "source handle (got " << llvm::range_size(sourceOps) << ")"; + } + + Operation *target = *targetOps.begin(); + Operation *source = *sourceOps.begin(); + + // Check if the target and source are siblings. + DiagnosedSilenceableFailure diag = isOpSibling(target, source); + if (!diag.succeeded()) + return diag; + + Operation *fusedLoop; + /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. 
+ if (isForWithIdenticalConfiguration(target, source)) { + fusedLoop = fuseIndependentSiblingForLoops( + cast(target), cast(source), rewriter); + } else if (isForallWithIdenticalConfiguration(target, source)) { + fusedLoop = fuseIndependentSiblingForallLoops( + cast(target), cast(source), rewriter); + } else { + return emitSilenceableFailure(target->getLoc()) + << "operations cannot be fused"; + } + + assert(fusedLoop && "failed to fuse operations"); + + results.set(cast(getFusedLoop()), {fusedLoop}); + return DiagnosedSilenceableFailure::success(); +} + +#define GET_OP_CLASSES +#include "dicp/TransformOps/DicpTransformOps.cpp.inc" diff --git a/compiler/lib/TransformOps/TransformsUtils.cpp b/compiler/lib/TransformOps/TransformsUtils.cpp new file mode 100644 index 00000000..4eceb81e --- /dev/null +++ b/compiler/lib/TransformOps/TransformsUtils.cpp @@ -0,0 +1,1190 @@ +#include "dicp/TransformOps/Transforms.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/TransformOps/Syntax.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/SubsetOpInterface.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "dicp-transform-op-utils" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::transform; +using namespace mlir::dicp; + +//===----------------------------------------------------------------------===// +// Common Utilities +//===----------------------------------------------------------------------===// + +static bool isSubsetOp(Operation *op) { + return 
isa(op) && + (isa(op) || + isa(op)); +} + +static SmallVector recursiveClone(RewriterBase &rewriter, + SmallVector values, + Operation *clonePoint) { + LDBG("Start recursiveClone"); + SmallVector newValues; + for (auto value : values) { + if (isa(value)) { + newValues.push_back(value); + continue; + } + + auto *defOperation = value.getDefiningOp(); + if (defOperation == nullptr) { + return newValues; + } + + // Clone dependency if defined before the clone point in the same block. + if (clonePoint->getBlock() == defOperation->getBlock() && + clonePoint->isBeforeInBlock(defOperation)) { + LDBG(" Cloning dependency: " << *defOperation); + auto operands = defOperation->getOperands(); + auto clonedValues = recursiveClone(rewriter, operands, clonePoint); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(clonePoint); + + IRMapping mapping; + mapping.map(operands, clonedValues); + auto *clonedOp = rewriter.clone(*defOperation, mapping); + newValues.push_back( + clonedOp->getResult(cast(value).getResultNumber())); + } else { + newValues.push_back(value); + } + } + return newValues; +} + +static bool isValidSliceOpInContainingOp(Operation *op, + Operation *containingOp) { + if (!op || !containingOp->isProperAncestor(op)) { + return false; + } + + if (!isSubsetOp(op)) { + return false; + } + + // Only support unit strides for union logic. 
+ auto sliceOp = cast(op); + auto staticStrides = sliceOp.getStaticStrides(); + if (llvm::count_if(staticStrides, [](int64_t s) { return s != 1; }) > 0) { + LDBG("Slice has non-unit stride, invalid for union: " << *op); + return false; + } + + return true; +} + +static void getFirstSliceUserInContainingOp( + Operation *producerOp, Operation *containingOp, + llvm::DenseMap *result2FirstSliceOp, + llvm::DenseMap *result2ValidNum) { + LDBG("Scanning for first slice user of producer: " << producerOp->getName()); + for (auto res : producerOp->getResults()) { + Operation *firstSliceOp = nullptr; + int validNum = 0; + for (auto user : res.getUsers()) { + if (!isValidSliceOpInContainingOp(user, containingOp)) { + continue; + } + + if (!firstSliceOp || user->isBeforeInBlock(firstSliceOp)) { + firstSliceOp = user; + } + validNum++; + } + result2ValidNum->insert(std::pair(res, validNum)); + if (firstSliceOp) { + result2FirstSliceOp->insert(std::pair(res, firstSliceOp)); + LDBG(" Found first slice: " << *firstSliceOp + << " (Total valid: " << validNum << ")"); + } + } +} + +enum class MODE { + UNION_MAX, + UNION_MIN, + COMPUTE_SLICE_MAX, + COMPUTE_SUB, + COMPUTE_DISTANCE +}; + +static SmallVector compute(RewriterBase &rewriter, MODE mode, + const SmallVectorImpl &lhs, + const SmallVectorImpl &rhs, + Location loc) { + auto symA = rewriter.getAffineSymbolExpr(0); + auto symB = rewriter.getAffineSymbolExpr(1); + auto one = rewriter.getAffineConstantExpr(1); + AffineMap map; + + switch (mode) { + case MODE::UNION_MAX: + case MODE::UNION_MIN: + map = AffineMap::get(0, 2, {symA, symB}, rewriter.getContext()); + break; + case MODE::COMPUTE_SLICE_MAX: + map = AffineMap::get(0, 2, {symA + symB - one}, rewriter.getContext()); + break; + case MODE::COMPUTE_SUB: + map = AffineMap::get(0, 2, {symA - symB}, rewriter.getContext()); + break; + case MODE::COMPUTE_DISTANCE: + map = AffineMap::get(0, 2, {symA - symB + one}, rewriter.getContext()); + break; + } + + SmallVector results; + for 
(auto it : llvm::zip(lhs, rhs)) { + auto l = std::get<0>(it); + auto r = std::get<1>(it); + Value result; + switch (mode) { + case MODE::UNION_MAX: + result = rewriter.create(loc, map, ValueRange{l, r}); + break; + case MODE::UNION_MIN: + result = rewriter.create(loc, map, ValueRange{l, r}); + break; + case MODE::COMPUTE_SLICE_MAX: + case MODE::COMPUTE_SUB: + case MODE::COMPUTE_DISTANCE: + result = + rewriter.create(loc, map, ValueRange{l, r}); + break; + } + results.push_back(result); + } + return results; +} + +SmallVector convert(SmallVectorImpl &values) { + SmallVector results; + for (auto it : values) { + results.push_back(OpFoldResult(it)); + } + return results; +} + +static SmallVector createEqualZeroOp(const SmallVector &targets, + RewriterBase &rewriter, + Location loc) { + SmallVector results; + for (Value target : targets) { + // Cast to i64 for arithmetic comparison. + Value castResult = + rewriter.create(loc, rewriter.getI64Type(), target); + Value zero = + rewriter.create(loc, rewriter.getI64Type(), 0); + Value cond = rewriter.create(loc, arith::CmpIPredicate::eq, + castResult, zero); + results.push_back(cond); + } + return results; +} + +static SmallVector createSelectOp(const SmallVector &conds, + const SmallVector &trues, + const SmallVector &falses, + RewriterBase &rewriter, Location loc) { + SmallVector results; + for (size_t i = 0; i < conds.size(); ++i) { + Value result = + rewriter.create(loc, conds[i], trues[i], falses[i]); + results.push_back(result); + } + return results; +} + +static void unionFirstProducerUser(RewriterBase &rewriter, + Operation *firstSliceOp, + SmallVector &unionOffsets, + SmallVector &unionMaxes) { + LDBG("first SliceOp \n" << *firstSliceOp); + rewriter.setInsertionPoint(firstSliceOp); + + auto sliceInterface = cast(firstSliceOp); + auto sliceOffsets = getValueOrCreateConstantIndexOp( + rewriter, firstSliceOp->getLoc(), sliceInterface.getMixedOffsets()); + auto sliceSizes = getValueOrCreateConstantIndexOp( + rewriter, 
firstSliceOp->getLoc(), sliceInterface.getMixedSizes()); + + Value source; + if (auto viewLike = dyn_cast(firstSliceOp)) { + source = viewLike.getViewSource(); + } else { + source = firstSliceOp->getOperand(0); + } + + auto srcMixedSizes = + tensor::getMixedSizes(rewriter, firstSliceOp->getLoc(), source); + auto srcSizes = getValueOrCreateConstantIndexOp( + rewriter, firstSliceOp->getLoc(), srcMixedSizes); + + auto isSizesZero = + createEqualZeroOp(sliceSizes, rewriter, firstSliceOp->getLoc()); + + // If slice size is 0, use source size as offset (MAX_VALUE) to avoid + // affecting min. + unionOffsets = createSelectOp(isSizesZero, srcSizes, sliceOffsets, rewriter, + firstSliceOp->getLoc()); + + // If slice size is 0, use slice size (0/MIN_VALUE) to avoid affecting max. + auto initMaxes = compute(rewriter, MODE::COMPUTE_SLICE_MAX, unionOffsets, + sliceSizes, firstSliceOp->getLoc()); + unionMaxes = createSelectOp(isSizesZero, sliceSizes, initMaxes, rewriter, + firstSliceOp->getLoc()); +} + +static void unionNextProducerUser(RewriterBase &rewriter, Location loc, + const SmallVector &offsets, + const SmallVector &sizes, + SmallVector &unionOffsets, + SmallVector &unionMaxes) { + LDBG("Unioning next user..."); + auto isSizesZero = createEqualZeroOp(sizes, rewriter, loc); + + // Update union offsets: min(current, new). + auto newOffsets = + createSelectOp(isSizesZero, unionOffsets, offsets, rewriter, loc); + unionOffsets = + compute(rewriter, MODE::UNION_MIN, unionOffsets, newOffsets, loc); + + // Update union maxes: max(current, new_end). 
+ auto computeMaxes = + compute(rewriter, MODE::COMPUTE_SLICE_MAX, newOffsets, sizes, loc); + auto clonedMaxes = + createSelectOp(isSizesZero, unionMaxes, computeMaxes, rewriter, loc); + unionMaxes = compute(rewriter, MODE::UNION_MAX, unionMaxes, clonedMaxes, loc); +} + +static tensor::ExtractSliceOp +sliceFromUnion(RewriterBase &rewriter, tensor::ExtractSliceOp unionSlice, + const SmallVector &unionOffsets, Operation *sliceOp) { + LDBG("Creating sliceFromUnion for: " << *sliceOp); + rewriter.setInsertionPoint(sliceOp); + + auto sliceInterface = cast(sliceOp); + + auto offsets = getValueOrCreateConstantIndexOp( + rewriter, sliceOp->getLoc(), sliceInterface.getMixedOffsets()); + auto sizes = getValueOrCreateConstantIndexOp(rewriter, sliceOp->getLoc(), + sliceInterface.getMixedSizes()); + auto isSizesZero = createEqualZeroOp(sizes, rewriter, sliceOp->getLoc()); + + // If zero-sized, reset offset to union offset to avoid out-of-bounds + // calculation. + offsets = createSelectOp(isSizesZero, unionOffsets, offsets, rewriter, + sliceOp->getLoc()); + auto newOffsets = compute(rewriter, MODE::COMPUTE_SUB, offsets, unionOffsets, + sliceOp->getLoc()); + + auto newSlice = rewriter.create( + sliceOp->getLoc(), unionSlice.getResult(), convert(newOffsets), + sliceInterface.getMixedSizes(), unionSlice.getMixedStrides()); + return newSlice; +} + +void mlir::dicp::unionProducerUsers(RewriterBase &rewriter, Diagnostic &diag, + Operation *producerOp, + Operation *containingOp) { + LDBG("unionProducerUsers entry for producer: " << *producerOp); + llvm::DenseMap result2FirstSliceOp; + llvm::DenseMap result2ValidNum; + getFirstSliceUserInContainingOp(producerOp, containingOp, + &result2FirstSliceOp, &result2ValidNum); + + for (auto produceResult : producerOp->getResults()) { + int validSliceOpNum = result2ValidNum[produceResult]; + + // Optimization primarily for > 1 user. 
+ if (validSliceOpNum < 2) { + continue; + } + + auto firstSliceOp = result2FirstSliceOp[produceResult]; + SmallVector unionOffsets; + SmallVector unionMaxes; + + LDBG("begin to union \n" << *containingOp); + unionFirstProducerUser(rewriter, firstSliceOp, unionOffsets, unionMaxes); + + for (auto *user : produceResult.getUsers()) { + if (!isValidSliceOpInContainingOp(user, containingOp) || + user == firstSliceOp) { + continue; + } + + LDBG("union slice \n" << *user); + auto sliceInterface = cast(user); + + // Clone values to ensure availability at the union point. + auto curOffsets = getValueOrCreateConstantIndexOp( + rewriter, user->getLoc(), sliceInterface.getMixedOffsets()); + auto clonedOffsets = recursiveClone(rewriter, curOffsets, firstSliceOp); + + auto curSizes = getValueOrCreateConstantIndexOp( + rewriter, user->getLoc(), sliceInterface.getMixedSizes()); + auto clonedSizes = recursiveClone(rewriter, curSizes, firstSliceOp); + + unionNextProducerUser(rewriter, user->getLoc(), clonedOffsets, + clonedSizes, unionOffsets, unionMaxes); + } + + auto unionSizes = compute(rewriter, MODE::COMPUTE_DISTANCE, unionMaxes, + unionOffsets, firstSliceOp->getLoc()); + + auto firstSliceInterface = + cast(firstSliceOp); + Value source; + if (auto viewLike = dyn_cast(firstSliceOp)) { + source = viewLike.getViewSource(); + } else { + source = firstSliceOp->getOperand(0); + } + + auto unionSlice = rewriter.create( + firstSliceOp->getLoc(), source, convert(unionOffsets), + convert(unionSizes), firstSliceInterface.getMixedStrides()); + + LDBG("insert union slice \n" << unionSlice); + + // Update users to extract from the new union slice. 
+ for (auto *user : llvm::make_early_inc_range(produceResult.getUsers())) { + if (!isValidSliceOpInContainingOp(user, containingOp) || + user == unionSlice) { + continue; + } + auto newSliceOp = + sliceFromUnion(rewriter, unionSlice, unionOffsets, user); + rewriter.replaceOp(user, newSliceOp.getResult()); + } + } +} + +/// Specific handler for scf.forall reconstruction. +static scf::ForallOp appendToForall(RewriterBase &rewriter, + scf::ForallOp forallOp, Value newOutput, + Value tiledVal, + ArrayRef offsets, + ArrayRef sizes) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forallOp); + + // 1. Create new forall op with an additional output operand + SmallVector outputs = llvm::to_vector(forallOp.getOutputs()); + outputs.push_back(newOutput); + + auto newForallOp = rewriter.create( + forallOp.getLoc(), forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), outputs, + forallOp.getMapping()); + + rewriter.eraseBlock(newForallOp.getBody()); + newForallOp.getRegion().takeBody(forallOp.getRegion()); + + // Note: ForallOp's bbArgs are [induction_vars..., output_iter_args...] + BlockArgument newBBArg = newForallOp.getBody()->addArgument( + newOutput.getType(), newOutput.getLoc()); + + // 3. Update the scf.in_parallel terminator + auto terminator = newForallOp.getTerminator(); + rewriter.setInsertionPointToEnd(terminator.getBody()); + SmallVector strides(offsets.size(), rewriter.getIndexAttr(1)); + rewriter.create( + newForallOp.getLoc(), tiledVal, newBBArg, offsets, sizes, strides); + + // 4. Update uses of the original loop results (SSA Chain Rewriting). + // The new loop returns the original results plus the appended one. + for (auto [oldRes, newRes] : + llvm::zip(forallOp.getResults(), newForallOp.getResults())) { + rewriter.replaceAllUsesWith(oldRes, newRes); + } + + // 5. Restore a valid dummy body to prevent verification failure for the old + // op. The old op is dead code but must remain valid until cleanup. 
+ Block *ghostBlock = rewriter.createBlock(&forallOp.getRegion()); + // Add induction variables + for (int i = 0; i < forallOp.getRank(); ++i) + ghostBlock->addArgument(rewriter.getIndexType(), forallOp.getLoc()); + // Add output arguments + for (Value out : forallOp.getOutputs()) + ghostBlock->addArgument(out.getType(), forallOp.getLoc()); + + rewriter.setInsertionPointToEnd(ghostBlock); + rewriter.create(forallOp.getLoc()); + + return newForallOp; +} + +/// Specific handler for scf.for reconstruction. +/// Appends a new output to the scf.for loop, moves the body, updates the yield, +/// and ensures the original loop remains syntactically valid (though dead). +static scf::ForOp appendToFor(RewriterBase &rewriter, scf::ForOp forOp, + Value newOutput, Value tiledVal, + ArrayRef offsets, + ArrayRef sizes) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + + // 1. Prepare new iter_args (original inits + new output) + SmallVector newIterArgs = llvm::to_vector(forOp.getInits()); + newIterArgs.push_back(newOutput); + + // 2. Create new scf.for op with the expanded signature + auto newForOp = rewriter.create(loc, forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep(), newIterArgs); + + // 3. Transfer the body from the old loop to the new loop + // We erase the default empty block created by the builder for newForOp first. + rewriter.eraseBlock(newForOp.getBody()); + newForOp.getRegion().takeBody(forOp.getRegion()); + + // 4. Fix the new loop body + // Add the new block argument corresponding to the new iter_arg + BlockArgument newBlockArg = + newForOp.getBody()->addArgument(newOutput.getType(), newOutput.getLoc()); + + // Replace uses of the new output *inside* the loop with the new block + // argument. This enables fusion into the new loop. 
+ rewriter.replaceUsesWithIf(newOutput, newBlockArg, [&](OpOperand &use) { + Operation *op = use.getOwner(); + return newForOp->isProperAncestor(op); + }); + + // 5. Update scf.yield terminator in the new loop + auto yieldOp = cast(newForOp.getBody()->getTerminator()); + rewriter.setInsertionPoint(yieldOp); + + // Create the InsertSliceOp to update the new iter_arg + // (tiledVal -> slice of newBlockArg) + SmallVector strides(offsets.size(), rewriter.getIndexAttr(1)); + Value updatedTensor = rewriter.create( + tiledVal.getLoc(), tiledVal, newBlockArg, offsets, sizes, strides); + + // Update yield operands: originals + updated new tensor + SmallVector newYieldOperands = llvm::to_vector(yieldOp.getOperands()); + newYieldOperands.push_back(updatedTensor); + + rewriter.create(yieldOp.getLoc(), newYieldOperands); + rewriter.eraseOp(yieldOp); + + // 6. Update uses of the original loop results (SSA Chain Rewriting). + // The new loop returns the original results plus the appended one. + for (auto [oldRes, newRes] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + rewriter.replaceAllUsesWith(oldRes, newRes); + } + + // 7. Restore a valid dummy body to prevent verification failure. + Block *ghostBlock = rewriter.createBlock(&forOp.getRegion()); + + // Add IV and dummy iter_args matching the original loop signature. + ghostBlock->addArgument(rewriter.getIndexType(), loc); + for (Value init : forOp.getInits()) + ghostBlock->addArgument(init.getType(), loc); + // Yield the dummy iter_args (all arguments except the IV at index 0). + rewriter.setInsertionPointToEnd(ghostBlock); + rewriter.create(loc, ghostBlock->getArguments().drop_front()); + return newForOp; +} + +/// Main logic to append output to a loop and update dependencies. 
+static Operation * +appendLoopResultAndFuse(RewriterBase &rewriter, Diagnostic &diag, + Operation *producerOp, Operation *containingOp, + TilingResult &tileAndFuseResult, int64_t resultNumber, + SmallVector &offsets, + SmallVector &sizes) { + + LLVM_DEBUG(llvm::dbgs() << "Checking if output appending is needed for: " + << *producerOp << "\n"); + producerOp->setAttr(kHadFusedAttr, UnitAttr::get(rewriter.getContext())); + // 1. Dominance check for users outside the loop + SetVector dominatedUsers; + DominanceInfo domInfo(containingOp); + Value producerResult = producerOp->getResult(resultNumber); + + for (Operation *user : producerResult.getUsers()) { + if (!containingOp->isAncestor(user) && + domInfo.dominates(containingOp, user)) { + LLVM_DEBUG(llvm::dbgs() << "[dominatedUsers]: " << *user << "\n"); + dominatedUsers.insert(user); + } + } + + bool hasCrossSubStageAttr = producerOp->hasAttr(kCrossTillUnitAttr); + // If no dominated users and no cross-stage attribute, we don't need to append + // the result to the loop. + if (dominatedUsers.empty() && !hasCrossSubStageAttr) + return nullptr; + + auto genericOp = dyn_cast(producerOp); + if (!genericOp) + return nullptr; + + Value newOutput = genericOp.getOutputs()[resultNumber]; + Value tiledVal = tileAndFuseResult.tiledValues[0]; + Operation *newLoop = nullptr; + + // 2. Branch based on loop type + if (auto forallOp = dyn_cast(containingOp)) { + newLoop = + appendToForall(rewriter, forallOp, newOutput, tiledVal, offsets, sizes); + } else if (auto forOp = dyn_cast(containingOp)) { + newLoop = appendToFor(rewriter, forOp, newOutput, tiledVal, offsets, sizes); + } + + if (!newLoop) + return nullptr; + + // 3. Update IR usage inside the loop + BlockArgument newBBArg = newLoop->getRegion(0).getArguments().back(); + rewriter.replaceUsesWithIf(newOutput, newBBArg, [&](OpOperand &use) { + return newLoop->isProperAncestor(use.getOwner()); + }); + + // 4. Connect external dominated users to the new loop result. 
+  // The new loop has the appended result at the end.
+  Value newLoopResult = newLoop->getResults().back();
+  if (!dominatedUsers.empty() || hasCrossSubStageAttr) {
+    rewriter.replaceUsesWithIf(
+        producerResult, newLoopResult, [&](OpOperand &use) {
+          Operation *owner = use.getOwner();
+          return !newLoop->isAncestor(owner) && !owner->hasAttr(kHadFusedAttr);
+        });
+  }
+  return newLoop;
+}
+
+static Value tryRankReduce(RewriterBase &rewriter, Location loc, Value value,
+                           Type targetType) {
+  if (value.getType() == targetType)
+    return value;
+
+  auto targetRT = dyn_cast(targetType);
+  if (!targetRT)
+    return nullptr;
+
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, loc, value, targetRT.getShape());
+
+  if (succeeded(maybeRankReduced) &&
+      maybeRankReduced->getType() == targetType) {
+    return *maybeRankReduced;
+  }
+  return nullptr;
+}
+
+static bool replaceSubsetExtraction(RewriterBase &rewriter, Operation *op,
+                                    Value tiledValue) {
+  if (op->getNumResults() != 1) {
+    return false;
+  }
+
+  Value replacement = tryRankReduce(rewriter, op->getLoc(), tiledValue,
+                                    op->getResult(0).getType());
+
+  if (!replacement) {
+    LDBG(" [SubsetExtraction] Shape mismatch: "
+         << tiledValue.getType() << " vs " << op->getResult(0).getType());
+    return false;
+  }
+
+  rewriter.replaceOp(op, replacement);
+  return true;
+}
+
+static bool replaceParallelInsertSlice(RewriterBase &rewriter,
+                                       tensor::ParallelInsertSliceOp pInsert,
+                                       Value tiledValue, Value originalValue) {
+  // ParallelInsertSliceOp rewriting is not supported yet: emit a diagnostic
+  // and return false so the caller leaves the op untouched.
+  pInsert->emitWarning()
+      << "Currently, no processing is performed on the ParallelInsertSliceOp.";
+  return false;
+}
+
+static bool replaceInsertSlice(RewriterBase &rewriter,
+                               tensor::InsertSliceOp insertOp, Value tiledValue,
+                               Value originalValue) {
+  const bool isSource = insertOp.getSource() == originalValue;
+  const bool isDest = insertOp.getDest() == originalValue;
+
+  // Nothing to do if the original value is neither source nor destination.
+ if (!isSource && !isDest) + return false; + + // Helper: check whether two values have identical ranked tensor shapes. + auto hasSameRankedShape = [](Value a, Value b) -> bool { + auto aType = dyn_cast(a.getType()); + auto bType = dyn_cast(b.getType()); + if (!aType || !bType) + return true; // Non-ranked types are considered compatible. + return aType.getShape() == bType.getShape(); + }; + + // If we are replacing the destination, require shape compatibility between + // the insert_slice result and the tiled value. + if (isDest && !hasSameRankedShape(insertOp.getResult(), tiledValue)) { + LDBG("replaceInsertSlice: result shape does not match tiledValue shape; " + "aborting replacement"); + return false; + } + + Value newSource = insertOp.getSource(); + Value newDest = insertOp.getDest(); + Location loc = insertOp.getLoc(); + + // Update source operand if needed. + if (isSource) { + if (Value reduced = tryRankReduce(rewriter, loc, tiledValue, + insertOp.getSourceType())) { + newSource = reduced; + } else { + // Source matched but cannot be rank-reduced; keep traversal going. + return true; + } + } + + // Update destination operand if needed. + if (isDest) { + if (Value reduced = + tryRankReduce(rewriter, loc, tiledValue, insertOp.getDestType())) { + newDest = reduced; + } else if (Value fallback = tryRankReduce(rewriter, loc, tiledValue, + insertOp.getSourceType())) { + // Fallback: match the source (slice) shape. + newDest = fallback; + } else { + // Destination matched but cannot be adapted. 
+ return true; + } + } + + rewriter.setInsertionPoint(insertOp); + rewriter.replaceOpWithNewOp( + insertOp, newSource, newDest, insertOp.getMixedOffsets(), + insertOp.getMixedSizes(), insertOp.getMixedStrides()); + + return true; +} + +static bool replaceSliceWithTiledValue(RewriterBase &rewriter, Operation *op, + Value tiledValue, Value originalValue) { + OpBuilder::InsertionGuard guard(rewriter); + LDBG("Replacing subset op: " << *op << " ; replace by: " << tiledValue); + + if (isa(op)) { + return replaceSubsetExtraction(rewriter, op, tiledValue); + } + + if (auto pInsert = dyn_cast(op)) { + return replaceParallelInsertSlice(rewriter, pInsert, tiledValue, + originalValue); + } + + if (auto insertOp = dyn_cast(op)) { + return replaceInsertSlice(rewriter, insertOp, tiledValue, originalValue); + } + + llvm_unreachable("Expected subset extraction or insertion op"); +} + +/// Return an IRMapping that maps operands of `producerOp` (must be a +/// linalg::LinalgOp) to the corresponding block arguments inside `loopOp`. +/// +/// - For scf.for: maps initArgs[i] -> bodyArg(i+1) (bodyArg 0 is induction +/// var). +/// - For scf.forall: maps outputs[i] -> bodyArg(numIVs + i). +/// If no mapping found (producer is not linalg or loop type unsupported) +/// an empty IRMapping is returned. +mlir::IRMapping mapProducerOperandsToLoopArgs(Operation *producerOp, + Operation *loopOp) { + mlir::IRMapping mapping; + + // require producer to be a LinalgOp + auto linalgOp = dyn_cast_or_null(producerOp); + if (!linalgOp) { + LLVM_DEBUG(llvm::dbgs() + << "[mapProducerOperandsToLoopArgs] " + "producer is not a LinalgOp. 
Returning empty map.\n"); + return mapping; + } + + // Build a temporary map from loop-defining Value -> corresponding block arg + llvm::DenseMap loopValueToArg; + + // scf.for : init args -> body arguments (body arg 0 = iv) + if (auto forOp = dyn_cast(loopOp)) { + Block &body = forOp.getRegion().front(); + auto initArgs = forOp.getInitArgs(); + // body args: [iv, loop-carried-arg0, loop-carried-arg1, ...] + for (unsigned i = 0, e = initArgs.size(); i != e; ++i) { + unsigned bodyArgIdx = i + 1; + if (bodyArgIdx < body.getNumArguments()) + loopValueToArg.try_emplace(initArgs[i], body.getArgument(bodyArgIdx)); + else + LLVM_DEBUG(llvm::dbgs() + << "[mapProducerOperandsToLoopArgs] " + "forOp body does not have expected arg index.\n"); + } + } + // scf.forall : outputs -> body args after induction vars + else if (auto forallOp = dyn_cast(loopOp)) { + Block &body = forallOp.getRegion().front(); + unsigned numIVs = forallOp.getInductionVars().size(); + auto outputs = forallOp.getOutputs(); + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + unsigned bodyArgIdx = numIVs + i; + if (bodyArgIdx < body.getNumArguments()) + loopValueToArg.try_emplace(outputs[i], body.getArgument(bodyArgIdx)); + else + LLVM_DEBUG(llvm::dbgs() + << "[mapProducerOperandsToLoopArgs] " + "forallOp body does not have expected arg index.\n"); + } + } else { + // unsupported loop type -> return empty mapping + LLVM_DEBUG( + llvm::dbgs() + << "[mapProducerOperandsToLoopArgs] " + "loopOp is not scf::ForOp or scf::ForallOp. Returning empty map.\n"); + return mapping; + } + + // Now, for each operand of the linalg op, map it if it matches a + // loop-defining value. 
+  for (Value operand : linalgOp->getOperands()) {
+    auto it = loopValueToArg.find(operand);
+    if (it != loopValueToArg.end()) {
+      mapping.map(operand, it->second);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "[mapProducerOperandsToLoopArgs] "
+                    "mapped operand "
+                 << operand << " -> loop block-arg " << it->second << "\n");
+    }
+  }
+
+  return mapping;
+}
+
+static void applyMappingToGeneratedSlices(IRMapping &mapping,
+                                          TilingResult &tilingResult,
+                                          RewriterBase &rewriter) {
+
+  for (unsigned i = 0, e = tilingResult.generatedSlices.size(); i < e; ++i) {
+    auto sliceOp = tilingResult.generatedSlices[i];
+
+    auto extract = dyn_cast(sliceOp);
+    if (!extract)
+      continue;
+
+    Value mappedBase = mapping.lookupOrNull(extract.getSource());
+    if (!mappedBase)
+      continue;
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(extract);
+    auto newSlice = rewriter.create(
+        extract.getLoc(), extract.getType(), mappedBase,
+        extract.getMixedOffsets(), extract.getMixedSizes(),
+        extract.getMixedStrides());
+
+    tilingResult.generatedSlices[i] = newSlice;
+    // Fully replace the original extract so that all of its uses are
+    // redirected to newSlice.
+    rewriter.replaceOp(extract, newSlice);
+  }
+}
+
+std::tuple, Operation *>
+mlir::dicp::tileAndFuseAllSubsetOps(RewriterBase &rewriter, Diagnostic &diag,
+                                    Operation *producerOp,
+                                    Operation *containingOp,
+                                    bool duplicateProducer) {
+  LLVM_DEBUG(DBGS() << "Try to fuse all extract uses for producer: "
+                    << *producerOp << "\n");
+
+  auto tileableProducer = dyn_cast(producerOp);
+  if (!tileableProducer) {
+    diag.attachNote(producerOp->getLoc())
+        << "producer is not a TilingInterface: " << *producerOp;
+    return {};
+  }
+
+  // Identify valid subset users inside containingOp.
+ SmallVector subsetOps; + for (Operation *user : producerOp->getUsers()) { + if (!containingOp->isProperAncestor(user) || !isSubsetOp(user)) + continue; + LDBG(" Found candidate slice user: " << *user); + subsetOps.push_back(user); + } + + llvm::sort(subsetOps, [](Operation *a, Operation *b) { + if (a->getBlock() == b->getBlock()) + return a->isBeforeInBlock(b); + return a->getBlock()->getParentOp()->isAncestor( + b->getBlock()->getParentOp()); + }); + + if (subsetOps.empty()) { + diag.attachNote(producerOp->getLoc()) + << "could not find fusion opportunity for: " << *producerOp; + return {}; + } + + // Group by result number. + std::map> resultToSubsetOps; + for (Operation *op : subsetOps) { + unsigned resNum = cast(op->getOperand(0)).getResultNumber(); + resultToSubsetOps[resNum].push_back(op); + } + + SmallVector tiledOps; + Operation *currentContainingOp = containingOp; + + for (auto &entry : resultToSubsetOps) { + unsigned resultNumber = entry.first; + SmallVector &ops = entry.second; + Operation *firstSliceOp = ops.front(); + + auto sliceInterface = cast(firstSliceOp); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(firstSliceOp); + + SmallVector offsets = sliceInterface.getMixedOffsets(); + SmallVector sizes = sliceInterface.getMixedSizes(); + + FailureOr result = tileableProducer.generateResultTileValue( + rewriter, resultNumber, offsets, sizes); + + if (failed(result)) { + diag.attachNote(tileableProducer->getLoc()) + << "failed to tile producer op: " << *tileableProducer; + return {}; + } + mlir::IRMapping mapping = + mapProducerOperandsToLoopArgs(producerOp, containingOp); + if (!mapping.getValueMap().empty()) { + applyMappingToGeneratedSlices(mapping, *result, rewriter); + } + + tiledOps.append(result->tiledOps.begin(), result->tiledOps.end()); + + // Replace all subset ops in this group. 
+ Value tiledValue = result->tiledValues[0]; + Value originalValue = producerOp->getResult(resultNumber); + for (Operation *opToReplace : ops) { + replaceSliceWithTiledValue(rewriter, opToReplace, tiledValue, + originalValue); + } + + if (duplicateProducer) { + continue; + } + + // Update containing op signature if needed. + Operation *newContainingOp = + appendLoopResultAndFuse(rewriter, diag, producerOp, currentContainingOp, + *result, resultNumber, offsets, sizes); + + if (newContainingOp) { + currentContainingOp = newContainingOp; + } + } + + return std::make_tuple(tiledOps, currentContainingOp); +} + +static BlockArgument getTiedBlockArgument(LoopLikeOpInterface loop, + OpOperand *opOperand) { + if (auto forallOp = dyn_cast(loop.getOperation())) + return forallOp.getTiedBlockArgument(opOperand); + + if (auto forOp = dyn_cast(loop.getOperation())) { + auto inits = forOp.getInits(); + auto it = llvm::find(inits, opOperand->get()); + if (it != inits.end()) { + unsigned initIdx = std::distance(inits.begin(), it); + return forOp.getRegionIterArgs()[initIdx]; + } + } + return nullptr; +} + +SmallVector +mlir::dicp::tileAndFuseAllSubsetOpsThroughContainingOpBlockArgument( + RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, + LoopLikeOpInterface containingOp) { + LLVM_DEBUG(DBGS() << "Try to fuse extract uses through block argument for: " + << *producerOp << "\n"); + + auto tileableProducer = dyn_cast(producerOp); + if (!tileableProducer) { + diag.attachNote(producerOp->getLoc()) + << "producer is not a TilingInterface: " << *producerOp; + return {}; + } + + // Find use by containing op. 
+ OpOperand *pUse = nullptr; + for (OpOperand &use : producerOp->getUses()) { + if (use.getOwner() == containingOp.getOperation()) { + pUse = &use; + break; + } + } + + if (!pUse) { + diag.attachNote(producerOp->getLoc()) + << "could not find a use by the containing op: " << *producerOp; + return {}; + } + + BlockArgument bbArg = getTiedBlockArgument(containingOp, pUse); + if (!bbArg) { + diag.attachNote(containingOp.getLoc()) + << "containing op does not have a tied block argument"; + return {}; + } + + SmallVector subsetOps; + for (Operation *user : bbArg.getUsers()) { + if (!containingOp->isProperAncestor(user) || !isSubsetOp(user)) + continue; + LDBG(" Found candidate slice user: " << *user); + subsetOps.push_back(user); + } + + llvm::sort(subsetOps, [](Operation *a, Operation *b) { + if (a->getBlock() == b->getBlock()) + return a->isBeforeInBlock(b); + return a->getBlock()->getParentOp()->isAncestor( + b->getBlock()->getParentOp()); + }); + + if (subsetOps.empty()) { + diag.attachNote(containingOp.getLoc()) + << "could not find fusion opportunity for bbArg: " << bbArg; + return {}; + } + + SmallVector tiledOps; + Operation *firstSliceOp = subsetOps.front(); + auto sliceInterface = cast(firstSliceOp); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(firstSliceOp); + + int64_t resultNumber = cast(pUse->get()).getResultNumber(); + + SmallVector destinationTensors; + if (failed(tensor::getOrCreateDestinations( + rewriter, tileableProducer->getLoc(), tileableProducer, + destinationTensors))) { + diag.attachNote(tileableProducer->getLoc()) + << "failed to get destination tensors for: " << *tileableProducer; + return {}; + } + + // Clone producer to map destination to block arg, then tile the clone. 
+ IRMapping bvm; + bvm.map(destinationTensors[resultNumber], bbArg); + auto tileableProducerClone = + cast(rewriter.clone(*tileableProducer, bvm)); + + auto scopeGuard = + llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); + + FailureOr tileAndFuseResult = + tileableProducerClone.generateResultTileValue( + rewriter, resultNumber, sliceInterface.getMixedOffsets(), + sliceInterface.getMixedSizes()); + + if (failed(tileAndFuseResult)) { + diag.attachNote(tileableProducer->getLoc()) + << "failed to tile producer op: " << *tileableProducer; + return {}; + } + + tiledOps.append(tileAndFuseResult->tiledOps.begin(), + tileAndFuseResult->tiledOps.end()); + Value tiledValueToReplace = tileAndFuseResult->tiledValues[0]; + + for (Operation *sliceOpToReplace : subsetOps) { + replaceSliceWithTiledValue(rewriter, sliceOpToReplace, tiledValueToReplace, + bbArg); + } + + // Update containing op operand to point to destination. + (void)tensor::getOrCreateDestinations(rewriter, tileableProducer->getLoc(), + tileableProducer, destinationTensors); + rewriter.modifyOpInPlace(containingOp, [&]() { + containingOp->setOperand(pUse->getOperandNumber(), + destinationTensors.front()); + }); + + return tiledOps; +} + +Operation *mlir::dicp::cloneAndFuseAllSubsetOps(RewriterBase &rewriter, + Diagnostic &diag, + Operation *producerOp, + Operation *containingOp) { + LDBG("Try to fuse all uses by cloning for: " << *producerOp); + + // If the producer has cross-substage users, cloning is not allowed because + // it would break the requirement of maintaining a single consistent tensor + // state. 
+ if (producerOp->hasAttr(kCrossTillUnitAttr)) { + diag.attachNote(producerOp->getLoc()) + << "cloning fusion is prohibited for ops with cross-substage users; " + << "use tiling-based fusion instead to maintain tensor state."; + return nullptr; + } + + SmallVector usesToFuse; + for (OpResult result : producerOp->getOpResults()) { + for (OpOperand &use : result.getUses()) { + Operation *user = use.getOwner(); + if (containingOp->isProperAncestor(user)) { + usesToFuse.push_back(&use); + } + } + } + + if (usesToFuse.empty()) { + diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning"; + return nullptr; + } + + Operation *lastFusedOp = nullptr; + for (OpOperand *use : usesToFuse) { + unsigned resultNumber = cast(use->get()).getResultNumber(); + Operation *user = use->getOwner(); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(user); + // Case 1: bufferization.to_buffer -> memref.subview + if (auto toBufferOp = dyn_cast(producerOp)) { + if (auto subViewOp = dyn_cast(user)) { + LDBG("Special Case: Pushing memref.subview across " + "bufferization.to_buffer"); + + auto sliceOp = rewriter.create( + subViewOp.getLoc(), toBufferOp.getTensor(), + subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), + subViewOp.getMixedStrides()); + + auto newToBuffer = rewriter.create( + toBufferOp.getLoc(), subViewOp.getType(), sliceOp.getResult()); + + rewriter.replaceOp(subViewOp, newToBuffer.getResult()); + lastFusedOp = newToBuffer; + continue; + } + } + // Case 2: bufferization.to_tensor -> tensor.extract_slice + if (auto toTensorOp = dyn_cast(producerOp)) { + if (auto extractSliceOp = dyn_cast(user)) { + LDBG("Special Case: Pushing tensor.extract_slice across " + "bufferization.to_tensor"); + + auto subViewOp = rewriter.create( + extractSliceOp.getLoc(), toTensorOp.getBuffer(), + extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), + extractSliceOp.getMixedStrides()); + + auto newToTensor = rewriter.create( + 
toTensorOp.getLoc(), extractSliceOp.getType(), + subViewOp.getResult()); + + rewriter.replaceOp(extractSliceOp, newToTensor.getResult()); + lastFusedOp = newToTensor; + continue; + } + } + // Default Case: Standard cloning for other ops + Operation *fusedOp = rewriter.clone(*producerOp); + rewriter.modifyOpInPlace( + user, [&] { use->set(fusedOp->getOpResult(resultNumber)); }); + lastFusedOp = fusedOp; + } + + return lastFusedOp; +} + +//===----------------------------------------------------------------------===// +// TransformApplier +//===----------------------------------------------------------------------===// + +void TransformApplier::apply(ModuleOp module, + TransformGenerationCallback generator) { + LLVM_DEBUG(llvm::dbgs() + << "[TransformApplier] Applying unified transformation...\n"); + + // Clone module to isolate transformation attempts + ModuleOp cloned = module.clone(); + MLIRContext *ctx = module.getContext(); + + if (!cloned->hasAttr("transform.with_named_sequence")) { + cloned->setAttr("transform.with_named_sequence", UnitAttr::get(ctx)); + } + + OpBuilder builder(cloned.getBodyRegion()); + Location loc = cloned.getLoc(); + std::string seqName = "__transform_main"; + Type rootType = builder.getType(); + + auto seqOp = builder.create( + loc, seqName, rootType, TypeRange{}, + [&](OpBuilder &b, Location bodyLoc, BlockArgument rootHandle) { + // Invoke the specific generator callback provided by the pass + generator(b, bodyLoc, rootHandle); + b.create(bodyLoc); + }); + + auto entryPoint = cast(seqOp.getOperation()); + transform::TransformOptions options; + + if (failed(transform::applyTransformNamedSequence(cloned, entryPoint, {}, + options))) { + LLVM_DEBUG( + llvm::dbgs() + << "[TransformApplier] Application of transform sequence failed.\n"); + cloned->emitWarning("Failed to apply unified transformation, reverting."); + cloned->erase(); + return; + } + + // Success: Replace original body + entryPoint.erase(); + 
module.getBodyRegion().getBlocks().clear(); + IRMapping map; + cloned.getBodyRegion().cloneInto(&module.getBodyRegion(), + module.getBodyRegion().begin(), map); + cloned->erase(); + LLVM_DEBUG(llvm::dbgs() + << "[TransformApplier] Transformation applied successfully.\n"); +} \ No newline at end of file diff --git a/compiler/lib/Utils/CMakeLists.txt b/compiler/lib/Utils/CMakeLists.txt index 9315bf5c..c8839039 100644 --- a/compiler/lib/Utils/CMakeLists.txt +++ b/compiler/lib/Utils/CMakeLists.txt @@ -4,4 +4,6 @@ add_triton_library(DICPUtils LINK_LIBS PUBLIC MLIRIR TritonIR + MLIRTransformDialect + MLIRTransformDialectUtils ) \ No newline at end of file diff --git a/dicp_triton.cc b/dicp_triton.cc deleted file mode 100644 index e7a286c7..00000000 --- a/dicp_triton.cc +++ /dev/null @@ -1,5 +0,0 @@ -#include - -namespace py = pybind11; - -void init_triton_dicp_triton(py::module &&m) {} diff --git a/test/ascend/passed_tests/test_cv_unroll_pipleine.py b/test/ascend/passed_tests/test_cv_unroll_pipleine.py new file mode 100644 index 00000000..fb281581 --- /dev/null +++ b/test/ascend/passed_tests/test_cv_unroll_pipleine.py @@ -0,0 +1,756 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.language.extra.deeplink as dl +import torch_npu +import triton.testing + +DEVICE = "npu" + + +def require_npu(): + try: + torch.empty(1, device=DEVICE) + except Exception: + pytest.skip("npu device not available") + + +# ------------------- Triton kernels (kept same functional implementation) ------------------- +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + qk_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, + N_CTX: tl.constexpr, + fp8_v: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + 
lo = tl.multiple_of(lo, BLOCK_M) + else: + lo, hi = 0, N_CTX + + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load(K_block_ptr) + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + + p = tl.math.exp(qk) + p_cast = p.to(tl.float16) + v = tl.load(V_block_ptr) + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + pv + + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + M, + Out, + sm_scale: tl.constexpr, + stride_qz: tl.constexpr, + stride_qh: tl.constexpr, + stride_qm: tl.constexpr, + stride_qk: tl.constexpr, + stride_kz: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kk: tl.constexpr, + stride_vz: tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vk: tl.constexpr, + stride_oz: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + stride_on: tl.constexpr, + Z: tl.constexpr, + H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + NUM_BLOCKS_PER_CORE: tl.constexpr, + NUM_BLOCKS: tl.constexpr, + NUM_BLOCKS_M: tl.constexpr, +): + pid = tl.program_id(0) + for block_idx in range(pid, NUM_BLOCKS, 24): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * 
stride_qz + off_h.to(tl.int64) * stride_qh + + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + q = tl.load(Q_block_ptr) + + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + task_m_idx, + sm_scale, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + N_CTX, + V.dtype.element_ty == tl.float8e5, + ) + + if STAGE & 2: + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + task_m_idx, + sm_scale, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 2, + offs_m, + offs_n, + N_CTX, + V.dtype.element_ty == tl.float8e5, + ) + + m_i += tl.math.log(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_split_cv( + Q, + K, + V, + M, + Out, + acc, + sm_scale, + workspace_1, + workspace_2, + workspace_3, + stride_qz: tl.constexpr, + stride_qh: 
tl.constexpr, + stride_qm: tl.constexpr, + stride_qk: tl.constexpr, + stride_kz: tl.constexpr, + stride_kh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kk: tl.constexpr, + stride_vz: tl.constexpr, + stride_vh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vk: tl.constexpr, + stride_oz: tl.constexpr, + stride_oh: tl.constexpr, + stride_om: tl.constexpr, + stride_on: tl.constexpr, + w1_stride_nb: tl.constexpr, + w1_stride_bm: tl.constexpr, + w1_stride_bn: tl.constexpr, + w2_stride_nb: tl.constexpr, + w2_stride_bm: tl.constexpr, + w2_stride_bn: tl.constexpr, + w3_stride_nb: tl.constexpr, + w3_stride_bm: tl.constexpr, + w3_stride_dm: tl.constexpr, + Z: tl.constexpr, + H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_CORES: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + NUM_BLOCKS_M = N_CTX // BLOCK_M + NUM_BLOCKS = NUM_BLOCKS_M * Z * H + pid = tl.program_id(0) + for block_idx in tl.range(pid, NUM_BLOCKS, NUM_CORES): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + 
block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + q = tl.load(Q_block_ptr) + K_block_ptr = tl.advance(K_block_ptr, (0, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, 0)) + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + + lo, hi = 0, N_CTX + for start_n in range(lo, hi, BLOCK_N * NUM_STAGES): + for i in tl.range(0, NUM_STAGES, num_stages=NUM_STAGES): + workspace_1_ptr = tl.make_block_ptr( + base=workspace_1 + + (NUM_STAGES * block_idx.to(tl.int64) + i) * w1_stride_nb, + shape=(BLOCK_M, BLOCK_N), + strides=(w1_stride_bm, w1_stride_bn), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + workspace_2_ptr = tl.make_block_ptr( + base=workspace_2 + + (NUM_STAGES * block_idx.to(tl.int64) + i) * w2_stride_nb, + shape=(BLOCK_M, BLOCK_N), + strides=(w2_stride_bm, w2_stride_bn), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + workspace_3_ptr = tl.make_block_ptr( + base=workspace_3 + + (NUM_STAGES * block_idx.to(tl.int64) + i) * w3_stride_nb, + shape=(BLOCK_M, HEAD_DIM), + strides=(w3_stride_bm, w3_stride_dm), + offsets=(0, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + with dl.async_task(scope=dl.async_task.cube): + k = tl.load(K_block_ptr) + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + tl.store(workspace_1_ptr, qk) + + dl.set_cross_flag(dl.SyncFlag.C2V, 0) + dl.wait_cross_flag(dl.SyncFlag.V2C, 1) + + p_cast = tl.load(workspace_2_ptr) + v = tl.load(V_block_ptr) + acc_l0c = tl.dot(p_cast, v) + tl.store(workspace_3_ptr, acc_l0c) + dl.set_cross_flag(dl.SyncFlag.C2V, 2) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + + with dl.async_task(scope=dl.async_task.vector): + dl.wait_cross_flag(dl.SyncFlag.C2V, 0) + + qk = tl.load(workspace_1_ptr) + qk = qk * sm_scale + m_ij 
= tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp(qk) + p_cast = p.to(Q.type.element_ty) + tl.store(workspace_2_ptr, p_cast) + dl.set_cross_flag(dl.SyncFlag.V2C, 1) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + dl.wait_cross_flag(dl.SyncFlag.C2V, 2) + acc_ptr = acc_ptr * alpha[:, None] + acc_o_ub = tl.load(workspace_3_ptr) + acc_ptr = acc_ptr + acc_o_ub + m_i = m_ij + + m_i += tl.math.log(l_i) + accumulator = acc_ptr / l_i[:, None] + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, accumulator.to(Out.type.element_ty)) + + +# ------------------- Python wrappers and Function classes ------------------- +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sm_scale, BM, BN, causal=False): + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + o = torch.empty_like(q) + stage = 3 if causal else 1 + num_cores = 24 + NUM_BLOCKS_M = triton.cdiv(q.shape[2], BM) + NUM_BLOCKS = NUM_BLOCKS_M * q.shape[0] * q.shape[1] + NUM_BLOCKS_PER_CORE = triton.cdiv(NUM_BLOCKS, num_cores) + + M = torch.empty( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + _attn_fwd[(num_cores,)]( + q, + k, + v, + M, + o, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + N_CTX=q.shape[2], + HEAD_DIM=HEAD_DIM_K, + BLOCK_M=BM, + BLOCK_N=BN, + STAGE=stage, + NUM_BLOCKS_PER_CORE=NUM_BLOCKS_PER_CORE, + NUM_BLOCKS=NUM_BLOCKS, + NUM_BLOCKS_M=NUM_BLOCKS_M, + multibuffer=True, + unit_flag=True, + debug=False, + ) + ctx.save_for_backward(q, k, v, o, M) + ctx.sm_scale = sm_scale + 
ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + +@triton.jit +def _attn_fwd_split_cv_launcher( + Q, + K, + V, + M, + o, + acc, + sm_scale, + workspace_1, + workspace_2, + workspace_3, + q_stride0, + q_stride1, + q_stride2, + q_stride3, + k_stride0, + k_stride1, + k_stride2, + k_stride3, + v_stride0, + v_stride1, + v_stride2, + v_stride3, + o_stride0, + o_stride1, + o_stride2, + o_stride3, + w1_nb, + w1_bm, + w1_bn, + w2_nb, + w2_bm, + w2_bn, + w3_nb, + w3_bm, + w3_dm, + Z: tl.constexpr, + H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_CORES: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + _attn_fwd_split_cv( + Q, + K, + V, + M, + o, + acc, + sm_scale, + workspace_1, + workspace_2, + workspace_3, + q_stride0, + q_stride1, + q_stride2, + q_stride3, + k_stride0, + k_stride1, + k_stride2, + k_stride3, + v_stride0, + v_stride1, + v_stride2, + v_stride3, + o_stride0, + o_stride1, + o_stride2, + o_stride3, + w1_nb, + w1_bm, + w1_bn, + w2_nb, + w2_bm, + w2_bn, + w3_nb, + w3_bm, + w3_dm, + Z=Z, + H=H, + N_CTX=N_CTX, + HEAD_DIM=HEAD_DIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NUM_CORES=NUM_CORES, + NUM_STAGES=NUM_STAGES, + ) + + +class AttentionSplitCV(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sm_scale, BM, BN, causal=False): + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + extra_kern_args = {} + + o = torch.empty_like(q) + N_CTX = q.shape[2] + Z, H = q.shape[0], q.shape[1] + NUM_BLOCKS_M = N_CTX // BM + NUM_BLOCKS = NUM_BLOCKS_M * Z * H + DIM = q.shape[-1] + NUM_CORES = 24 + NUM_STAGES = 4 + acc = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), + dtype=torch.float32, + device=q.device, + ) + M = torch.empty( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + 
workspace_1 = torch.empty( + (NUM_STAGES, NUM_BLOCKS, BM, BN), device=q.device, dtype=torch.float32 + ) + workspace_2 = torch.empty( + (NUM_STAGES, NUM_BLOCKS, BM, BN), device=q.device, dtype=q.dtype + ) + workspace_3 = torch.empty( + (NUM_STAGES, NUM_BLOCKS, BM, DIM), device=q.device, dtype=torch.float32 + ) + + _attn_fwd_split_cv_launcher[(NUM_CORES,)]( + q, + k, + v, + M, + o, + acc, + sm_scale, + workspace_1, + workspace_2, + workspace_3, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + workspace_1.stride(1), + workspace_1.stride(2), + workspace_1.stride(3), + workspace_2.stride(1), + workspace_2.stride(2), + workspace_2.stride(3), + workspace_3.stride(1), + workspace_3.stride(2), + workspace_3.stride(3), + q.shape[0], + q.shape[1], + N_CTX=q.shape[2], + HEAD_DIM=HEAD_DIM_K, + BLOCK_M=BM, + BLOCK_N=BN, + NUM_CORES=NUM_CORES, + NUM_STAGES=NUM_STAGES, + disable_auto_inject_block_sync=True, + disable_auto_cv_work_space_manage=True, + **extra_kern_args, + ) + + ctx.save_for_backward(q, k, v, o, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + return o + + +attention_base = _attention.apply +attention_split_cv = AttentionSplitCV.apply + +# ------------------- Tests (expanded to include all original test_op cases) ------------------- +ALL_CASES = [ + (1, 2, 2048, 128, 64, 128, False), + (4, 32, 1024, 64, 64, 256, False), + # 超长序列 + (1, 1, 1024 * 32, 128, 64, 128, False), + # 中等规模 + (4, 4, 512, 128, 16, 128, False), + (8, 32, 512, 256, 16, 128, False), + # 小序列 / tile 多样性 + (32, 32, 64, 64, 64, 16, False), + (32, 32, 128, 128, 64, 32, False), + (32, 32, 256, 128, 64, 64, False), + # 常见 LLM 配置 + (1, 8, 1024, 64, 64, 128, False), + (8, 12, 512, 64, 128, 128, False), + # 长上下文 + (1, 16, 2048, 128, 64, 128, False), + (1, 32, 4096, 128, 64, 128, False), +] + + 
+@pytest.mark.xdist_group(name="attention_ref_group") +@pytest.mark.parametrize("Z,H,N_CTX,HEAD_DIM,BM,BN,causal", ALL_CASES) +def test_attention_matches_reference_all(Z, H, N_CTX, HEAD_DIM, BM, BN, causal): + require_npu() + torch.manual_seed(20) + dtype = torch.float16 + + q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( + mean=0.0, std=0.5 + ) + k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( + mean=0.0, std=0.5 + ) + v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( + mean=0.0, std=0.5 + ) + + sm_scale = 0.5 + + ref_out = torch_npu.npu_fusion_attention( + q, + k, + v, + H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + tri_out_base = attention_base(q, k, v, sm_scale, BM, BN, causal).half() + tri_out_cv = attention_split_cv(q, k, v, sm_scale, BM, BN, causal).half() + + atol = 1e-3 + rtol = 0.0 + + assert torch.allclose(ref_out, tri_out_base, atol=atol, rtol=rtol) + assert torch.allclose(ref_out, tri_out_cv, atol=atol, rtol=rtol) + + +# # Performance test for the ultra-long sequence; asserts torch_time < tri_time* 0.30 +# @pytest.mark.xdist_group(name="test_perf_long_sequence") +# def test_perf_long_sequence(): +# require_npu() +# torch.manual_seed(20) +# Z, H, N_CTX, HEAD_DIM, BM, BN, causal = (1, 1, 1024 * 64, 128, 64, 256, False) +# dtype = torch.float16 + +# q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( +# mean=0.0, std=0.5 +# ) +# k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( +# mean=0.0, std=0.5 +# ) +# v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_( +# mean=0.0, std=0.5 +# ) + +# sm_scale = 0.5 + +# # Warmup & rep counts: keep modest to avoid extremely long CI runs; adjust as needed. 
+# warmup = 50 +# rep = 50 + +# # measure torch_npu fused attention +# torch_fn = lambda: torch_npu.npu_fusion_attention( +# q, +# k, +# v, +# H, +# padding_mask=None, +# atten_mask=None, +# scale=sm_scale, +# keep_prob=1.0, +# input_layout="BNSD", +# pre_tockens=65535, +# next_tockens=65535, +# sparse_mode=0, +# )[0] + +# tri_fn = lambda: attention_split_cv(q, k, v, sm_scale, BM, BN, causal) + +# # triton.testing.do_bench returns ms +# torch_ms = triton.testing.do_bench(torch_fn, warmup=warmup, rep=rep) +# tri_ms = triton.testing.do_bench(tri_fn, warmup=warmup, rep=rep) + +# # print for visibility when running tests +# print( +# f"torch_npu fusion ms: {torch_ms:.3f} ms; triton split_cv ms: {tri_ms:.3f} ms" +# ) + +# # require triton to be faster than 30% of torch time +# assert ( +# torch_ms > tri_ms * 0.30 +# ), f"triton({torch_ms :.3f}ms) must be < 30% of triton({tri_ms:.3f}ms)" diff --git a/tools/dicp_triton_opt/CMakeLists.txt b/tools/dicp_triton_opt/CMakeLists.txt index bd00d88a..43b43690 100644 --- a/tools/dicp_triton_opt/CMakeLists.txt +++ b/tools/dicp_triton_opt/CMakeLists.txt @@ -8,6 +8,7 @@ target_link_libraries(dicp_opt PRIVATE TritonAnalysis TritonTransforms TritonGPUTransforms + TritonNvidiaGPUTransforms TritonSharedAnalysis ${dialect_libs} ${translation_libs} @@ -22,7 +23,11 @@ target_link_libraries(dicp_opt PRIVATE LinkedToHIVM DICPLinalgExt DiscreteMaskAccessConversion + DICPTransformOps + MLIRLinalgTransformOps + BiShengIRHIVMDialect + LinalgExtAnalysis TritonToLinalg TritonTilingExtIR TritonToLinalgNPUCoversion diff --git a/tools/dicp_triton_opt/dicp_triton_opt.cpp b/tools/dicp_triton_opt/dicp_triton_opt.cpp index dccd884e..f80f9218 100644 --- a/tools/dicp_triton_opt/dicp_triton_opt.cpp +++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp @@ -9,32 +9,21 @@ #include "dicp/Dialect/LinalgExt/Transforms/Passes.h" #include "dicp/Dialect/NPU/IR/NPUDialect.h" #include "dicp/Dialect/TritonExt/Transforms/Passes.h" +#include 
"dicp/TransformOps/DicpTransformOps.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" -#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" -#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" -#include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" -#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include 
"mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" @@ -42,47 +31,53 @@ #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" -#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" +#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" -#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" -#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" #include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" #include 
"mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" -#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Pass/Pass.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h.inc" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + using namespace mlir; inline void registerDICPDialects(mlir::DialectRegistry ®istry) { @@ -105,16 +100,34 @@ inline void registerDICPDialects(mlir::DialectRegistry ®istry) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerNPUUnroolPipelinePass(); + dicp::LinalgExt::registerNPUVectorTileTaggingPass(); + 
dicp::LinalgExt::registerNPUVectorTileTransformPass(); + dicp::LinalgExt::registerDeLinalgizePass(); + dicp::LinalgExt::registerFuseLoopPass(); + dicp::LinalgExt::registerLoopUnrollStagePass(); + + mlir::dicp::registerTransformDialectExtension(registry); + mlir::linalg::registerTransformDialectExtension(registry); + + affine::registerValueBoundsOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerTilingInterfaceExternalModels(registry); + linalg::registerAllDialectInterfaceImplementations(registry); + + + scf::registerTransformDialectExtension(registry); registry.insert(); + linalg::LinalgDialect, LLVM::LLVMDialect, math::MathDialect, + memref::MemRefDialect, pdl::PDLDialect, scf::SCFDialect, + tensor::TensorDialect, transform::TransformDialect, + vector::VectorDialect, ub::UBDialect, triton::TritonDialect, + affine::AffineDialect, ttx::TritonTilingExtDialect, + mlir::hivm::HIVMDialect>(); } int main(int argc, char **argv) { diff --git a/triton_dicp_triton.cc b/triton_dicp_triton.cc index 979e0071..3d14ee69 100644 --- a/triton_dicp_triton.cc +++ b/triton_dicp_triton.cc @@ -8,30 +8,83 @@ #include "dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h" #include "dicp/Dialect/LinalgExt/Transforms/Passes.h" #include "dicp/Dialect/TritonExt/Transforms/Passes.h" - -#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" -#include "triton/Dialect/Triton/IR/Dialect.h" +#include "dicp/TransformOps/DicpTransformOps.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" 
+#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" +#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include 
"mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" +#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" +#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/Passes.h" + #include "llvm/IR/Constants.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + #include "passes.h" #include #include @@ -70,6 +123,26 @@ void init_triton_dicp_triton_pass_linked_npu(py::module &&m) { pm.addNestedPass( dicp::LinalgExt::createScalarTo1DTensorPass()); }); + m.def("add_npu_unroll_pipeline", [](mlir::PassManager &pm) { + pm.addNestedPass( + 
dicp::LinalgExt::createNPUUnroolPipelinePass()); + }); + m.def( + "add_npu_vector_tile_tagging", [](mlir::PassManager &pm, int vectorTile) { + pm.addPass( + mlir::dicp::LinalgExt::createNPUVectorTileTaggingPass(vectorTile)); + }); + ADD_PASS_WRAPPER_0("add_npu_vector_tile_transform", + dicp::LinalgExt::createNPUVectorTileTransformPass); + ADD_PASS_WRAPPER_0("add_de_linalgize", + dicp::LinalgExt::createDeLinalgizePass); + ADD_PASS_WRAPPER_0("add_fuse_loop", dicp::LinalgExt::createFuseLoopPass); + + m.def("add_loop_unroll_stage", [](mlir::PassManager &pm) { + pm.addNestedPass( + dicp::LinalgExt::createLoopUnrollStagePass()); + }); + m.def("add_linalg_to_linked", [](mlir::PassManager &pm, bool globalKernel, bool namedOps) { pm.addPass(mlir::dicp::linked::createLinalgToLinkedPass(globalKernel, @@ -88,6 +161,8 @@ void init_triton_dicp_triton(py::module &&m) { // load dialects m.def("load_dialects", [](MLIRContext &context) { + llvm::errs() << ">>> [DEBUG] load_dialects ENTERED\n"; + llvm::errs() << ">>> [DEBUG] MLIRContext ptr: " << &context << "\n"; DialectRegistry registry; registry.insert(); - dicp::trtion_ext::registerCanonicalizeTritonIRAscendPass(); - dicp::trtion_ext::registerCanonicalizeCmpiPass(); - - dicp::linked::registerLinalgToLinkedPass(); - dicp::linked::registerLinkedToHIVMPass(); - dicp::linked::registerTritonToLinalgNPUCoversionPass(); + mlir::dicp::registerTransformDialectExtension(registry); + mlir::linalg::registerTransformDialectExtension(registry); - dicp::LinalgExt::registerLinalgIfToSelectPass(); - dicp::LinalgExt::registerLinalgGenericToSCFPass(); - dicp::LinalgExt::registerScalarTo1DTensorPass(); - dicp::LinalgExt::registerNormalizeSliceOpsPass(); + affine::registerValueBoundsOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerTilingInterfaceExternalModels(registry); + 
linalg::registerAllDialectInterfaceImplementations(registry); + scf::registerTransformDialectExtension(registry); + func::registerAllExtensions(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects();