4 changes: 3 additions & 1 deletion .gitignore
@@ -101,4 +101,6 @@ launcher_cxx11abi*
# package
backend/triton-shared-opt-v3*
backend/dicp_opt
third_party/triton-shared-opt
third_party/triton-shared-opt

FA_FWD_PROF
8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -50,16 +50,24 @@ if (TRITON_BUILD_PYTHON_MODULE)
MLIRTransforms
MLIRSupport
MLIRBytecodeWriter
MLIRTransformDialect
${extension_libs}
MLIRRegisterAllPasses
MLIRLinalgTransforms
MLIRTensorTransforms
MLIRFuncAllExtensions

TritonToLinalgNPUCoversion

LinalgExtTransforms
TritonExtTransforms
LinalgExtAnalysis

LinalgToLinked
LinkedToHIVM
DiscreteMaskAccessConversion
TritonToUnstructure
DICPTransformOps
)
target_link_libraries(tritonDicpTriton PRIVATE Python3::Module pybind11::headers)
endif()
10 changes: 7 additions & 3 deletions backend/compiler.py
@@ -131,6 +131,7 @@ def __init__(self, target: str) -> None:
self.binary_ext = "mcfatbin"
elif self.driver.target == "ascend":
self.binary_ext = "npubin"
self.capability = target.arch
else:
raise RuntimeError(f"Target '{self.target_type}' is not supported.")

@@ -249,7 +250,7 @@ def add_stages(self, stages, options, language=None):
)
stages["npubin"] = (
lambda src, metadata: linalg_to_bin_enable_npu_compile(
src, metadata, options
src, metadata, options, self.capability
)
)
else:
@@ -264,17 +265,20 @@ def add_stages(self, stages, options, language=None):
)
stages["npubin"] = (
lambda src, metadata: linalg_to_bin_enable_npu_compile(
src, metadata, options
src, metadata, options, self.capability
)
)
else:
raise RuntimeError("backend not supported")

def load_dialects(self, ctx):
        # TODO: if additional backends are integrated into the common IR with customized passes, their Dialect interfaces must be registered here. A decoupled per-backend registration mechanism is preferred to keep this modular.
if self.driver.target == "mlu":
from triton._C.libtriton import mlu

mlu.load_dialects(ctx)
else:
from triton._C.libtriton import dicp_triton
dicp_triton.load_dialects(ctx)
return

@functools.lru_cache()
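The TODO in `load_dialects` above asks for a decoupled registration mechanism. A minimal sketch of one possible shape, assuming each backend registers its loader in a module-level table; the registry, decorator, and `target` plumbing here are hypothetical illustrations, not part of this PR:

```python
# Hypothetical decoupled dialect registration (not part of this PR).
_DIALECT_LOADERS = {}

def register_dialect_loader(target):
    """Associate a target name with its dialect-loading function."""
    def wrap(fn):
        _DIALECT_LOADERS[target] = fn
        return fn
    return wrap

@register_dialect_loader("mlu")
def _load_mlu(ctx):
    from triton._C.libtriton import mlu
    mlu.load_dialects(ctx)

@register_dialect_loader("ascend")
def _load_ascend(ctx):
    from triton._C.libtriton import dicp_triton
    dicp_triton.load_dialects(ctx)

def load_dialects(ctx, target):
    # Dispatch by table lookup instead of a growing if/else chain.
    _DIALECT_LOADERS[target](ctx)
```

With this shape, integrating a new backend means adding one decorated loader function rather than editing the if/else chain in `load_dialects`.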
22 changes: 16 additions & 6 deletions backend/npu.py
@@ -491,10 +491,20 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm)
dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True)
dicp_triton.passes.linked_npu.add_linked_to_hivm(pm)
# dicp_triton.passes.linked_npu.add_npu_vector_tile_tagging(pm,2)
# dicp_triton.passes.linked_npu.add_npu_vector_tile_transform(pm)
# dicp_triton.passes.linked_npu.add_fuse_loop(pm)
# dicp_triton.passes.linked_npu.add_de_linalgize(pm)
# dicp_triton.passes.linked_npu.add_loop_unroll_stage(pm)
# dicp_triton.passes.linked_npu.add_npu_unroll_pipeline(pm)
pm.run(mod)

    # TODO(zmz): rewrite the test_path contents; handled in Python for now. bishengir-compile will support this later, then this logic can be removed.
content = str(mod)
dicp_utils._dump_stage_ir(
content, metadata["hash"], "kernel.linkedir.mlir.2"
)
if replace_linked_ir is not None:
        content = Path(replace_linked_ir).read_text()
    # TODO(zmz): rewrite the test_path contents; handled in Python for now. bishengir-compile will support this later, then this logic can be removed.
    # Replace "*xfxxx" with "?xfxxx"
content = content.replace("*xf", "?xf")
content = content.replace("*xi", "?xi")
@@ -520,9 +530,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
content, metadata["hash"], "kernel.linkedir.mlir", cmd_list
)

if replace_linked_ir is not None:
print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
return Path(replace_linked_ir).read_text()

return content


@@ -683,7 +691,7 @@ def _parse_linalg_metadata(linalg: str, metadata: dict):
return linalg, metadata


def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt, capability):
linalg, metadata = _parse_linalg_metadata(linalg, metadata)
with tempfile.TemporaryDirectory() as tmpdir:
ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir")
@@ -706,6 +714,8 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
_compile_option_list += [
f"--enable-auto-multi-buffer={multibuffer}",
]
if capability:
_compile_option_list += [f"--target={capability}"]

if _is_ascend_sanitizer_enabled():
_compile_option_list += ["--enable-sanitizer=true"]
166 changes: 166 additions & 0 deletions compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h
@@ -0,0 +1,166 @@
#ifndef DICP_DIALECT_LINALGEXT_ANALYSIS_DIMANALYZER_H
#define DICP_DIALECT_LINALGEXT_ANALYSIS_DIMANALYZER_H

#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"

#include <numeric>
#include <queue>
#include <vector>

namespace mlir {
namespace dicp {

/// Classification of a dimension's role in the computation graph.
/// This helps determine if a dimension is safe to tile or parallelize.
enum class DimKind {
Unknown, // No specific property inferred yet.
Parallel, // Dimension implies independent iterations (safe to tile).
Reduction, // Dimension is collapsed/reduced (requires accumulation).
Broadcast, // Dimension is replicated (data invariant along this axis).
Complex // Dimension undergoes complex transformation (e.g., non-affine
// reshape).
};

std::string toString(DimKind k);

/// Disjoint Set Union (DSU) for tracking dimension equivalence and properties.
///
/// This class implements a Disjoint Set data structure (Union-Find)
/// specifically designed for Tensor/MemRef dimensions. It serves two main
/// purposes:
/// 1. **Equivalence Tracking**: Determines which dimensions across different
///    values represent the same logical axis (e.g., the 'N' dimension in a
///    Matmul propagating through element-wise adds).
/// 2. **Property Propagation**: Merges semantic properties (DimKind) when
///    dimensions are unified. For example, if a dimension is used as a
///    Reduction iterator in one operation, that property propagates to all
///    equivalent dimensions in the set.
class DimensionDisjointSet {
public:
explicit DimensionDisjointSet(size_t size = 0) { resize(size); }

/// Allocates `n` new dimension IDs in the set.
/// \return The ID of the first allocated dimension.
int64_t allocate(size_t n = 1);

/// Finds the representative (root) ID for the set containing dimension `i`.
/// Implements path compression for amortized constant time lookups.
int64_t find(int64_t i);

/// Merges the sets containing dimensions `i` and `j`.
/// This also merges the `DimKind` properties of both roots using
/// `mergeKinds`.
void unionSets(int64_t i, int64_t j);

/// Updates the DimKind property for the set containing dimension `i`.
/// The new kind is merged with the existing kind to ensure safety (e.g.,
/// Reduction is sticky).
void setKind(int64_t i, DimKind k);

/// Retrieves the DimKind property of the set containing dimension `i`.
DimKind getKind(int64_t i);

private:
/// Resizes the internal storage to accommodate `n` dimensions.
void resize(size_t n);

/// Defines the logic for combining two dimension kinds.
/// Hierarchy of "stickiness": Complex > Reduction > Broadcast/Parallel.
DimKind mergeKinds(DimKind a, DimKind b);

std::vector<int64_t> parent; // Parent pointers for DSU.
std::vector<DimKind> kind; // Properties associated with each root.
};
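A minimal Python sketch of the structure this class describes: union-find with path compression plus "sticky" merging of `DimKind`. The numeric stickiness order below (Complex > Reduction > Broadcast > Parallel > Unknown) is an assumption read off the `mergeKinds` comment, and the Python rendering is for illustration only; the PR's actual C++ implementation may differ:

```python
from enum import IntEnum

class DimKind(IntEnum):
    # Assumed stickiness order: the higher value wins when two sets merge.
    UNKNOWN = 0
    PARALLEL = 1
    BROADCAST = 2
    REDUCTION = 3
    COMPLEX = 4

class DimensionDisjointSet:
    def __init__(self) -> None:
        self.parent: list[int] = []    # parent pointers for the DSU
        self.kind: list[DimKind] = []  # property stored at each root

    def allocate(self, n: int = 1) -> int:
        """Allocate n fresh dimension IDs; return the first one."""
        first = len(self.parent)
        for i in range(n):
            self.parent.append(first + i)  # each new dim starts as its own root
            self.kind.append(DimKind.UNKNOWN)
        return first

    def find(self, i: int) -> int:
        """Root of i's set, with path halving for near-constant lookups."""
        while self.parent[i] != i:
            self.parent[i] = self.parent[self.parent[i]]
            i = self.parent[i]
        return i

    def union_sets(self, i: int, j: int) -> None:
        """Merge the sets of i and j; the stickier kind wins."""
        ri, rj = self.find(i), self.find(j)
        if ri == rj:
            return
        self.kind[ri] = max(self.kind[ri], self.kind[rj])
        self.parent[rj] = ri

    def set_kind(self, i: int, k: DimKind) -> None:
        r = self.find(i)
        self.kind[r] = max(self.kind[r], k)  # kinds only escalate, never relax

    def get_kind(self, i: int) -> DimKind:
        return self.kind[self.find(i)]
```

For example, unioning a dimension marked Parallel with one marked Reduction leaves the merged set Reduction, which is exactly the conservatism that makes the result safe to consult when picking tiling loops.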

/// DimAnalyzer:
/// Analyzes a specific execution stage (StageInfo) to determine tiling
/// strategies.
///
/// The analyzer constructs a constraint graph where nodes are tensor dimensions
/// and edges represent data flow relationships. It uses a Breadth-First Search
/// (BFS) approach to traverse operations and propagate dimension IDs.
///
/// Algorithm Overview:
/// 1. **Initialization**: Seeds the analysis with stage inputs (operands
/// defined outside the stage).
/// 2. **BFS Propagation**: Traverses the def-use chains. For each operation, it
/// uses specific handlers (e.g., processMatmulOp) to bind input dimensions to
/// output dimensions.
/// 3. **Anchor Heuristic**: Identifies the "Anchor" operation (typically the
/// final LinalgOp) to interpret the resulting loops.
/// 4. **Tiling Selection**: Checks the properties of the Anchor's loops in the
/// DSU to recommend outermost parallel loops for tiling.
class DimAnalyzer {
public:
explicit DimAnalyzer(const StageInfo &stage);

/// Analyzes the stage operations and returns indices of loops recommended for
/// tiling. The indices correspond to the loop nest of the "Anchor" operation.
SmallVector<int64_t> analyzeAndGetTilingDims();

private:
const StageInfo &stage_;
// Quick lookup for ops belonging to this stage.
DenseSet<Operation *> stageOps_;
DimensionDisjointSet dsu_;
// Maps SSA Value -> [Dim IDs]
DenseMap<Value, std::vector<int64_t>> valueDims_;

// BFS State passed to handlers to allow them to enqueue new values.
using BFSQueue = std::queue<Value>;
using VisitedSet = DenseSet<Value>;

/// Drives the traversal of the data flow graph.
void processBFS();

/// Dispatches the operation to the appropriate handler.
/// \return true if the operation was handled, false otherwise.
bool processOperation(Operation *op, Value current, BFSQueue &q,
VisitedSet &v);

/// Lazily retrieves or allocates unique IDs for the dimensions of a Value.
std::vector<int64_t> getOrAllocateDims(Value v);

/// Helper to strictly bind all dimensions of v1 to v2 (1-to-1 mapping).
/// Used for Elementwise, Copy, etc.
void bindDimensions(Value v1, Value v2);

// --- Op Handlers ---
// Each handler interprets the semantics of the op to union input/output
// dimensions correctly.

void processElementwise(Operation *op, Value current);
void processMatmulOp(linalg::MatmulOp op);
void processReduceOp(linalg::ReduceOp op);
void processTransposeOp(linalg::TransposeOp op);
void processBroadcastOp(linalg::BroadcastOp op);
void processLinalgOpGeneric(linalg::LinalgOp op);
void processReshapeOp(Operation *op);
void processConcatOp(tensor::ConcatOp op);
void processPadOp(tensor::PadOp op);
void processExtractSliceOp(tensor::ExtractSliceOp op);
void processInsertSliceOp(tensor::InsertSliceOp op);

// Handlers that may need to continue BFS propagation explicitly
void processMemrefCopyOp(memref::CopyOp op, Value current, BFSQueue &q,
VisitedSet &v);
void processMemrefCastOp(Operation *op);
void processBufferizationToTensor(bufferization::ToTensorOp op);
void processMaterializeOp(bufferization::MaterializeInDestinationOp op,
Value current, BFSQueue &q, VisitedSet &v);
};
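A schematic rendering of the BFS the class comment outlines, in Python for brevity. `stage_inputs`, `users_of`, and `handle_op` are hypothetical stand-ins for the MLIR def-use and dispatch machinery the real C++ pass would use:

```python
from collections import deque

def process_bfs(stage_inputs, users_of, handle_op):
    """Propagate dimension IDs along def-use chains, seeded with the
    operands that are defined outside the stage."""
    queue = deque(stage_inputs)
    visited = set()
    while queue:
        value = queue.popleft()
        if value in visited:
            continue
        visited.add(value)
        for op in users_of(value):
            # handle_op plays the role of processOperation: it dispatches
            # to an op-specific handler (matmul, reduce, transpose, ...)
            # that unions input dims with output dims, and returns the
            # results whose dims were just bound so the walk can continue.
            for result in handle_op(op, value):
                queue.append(result)
```

The handlers then differ only in their binding rule: elementwise ops bind dimensions 1-to-1, a transpose binds them through its permutation, and a reduce marks the collapsed dimension as Reduction rather than binding it to an output dimension.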

} // namespace dicp
} // namespace mlir

#endif // DICP_DIALECT_LINALGEXT_ANALYSIS_DIMANALYZER_H