diff --git a/mlir/dialects/__init__.py b/mlir/dialects/__init__.py index eefd918..9056c81 100644 --- a/mlir/dialects/__init__.py +++ b/mlir/dialects/__init__.py @@ -3,6 +3,7 @@ from .scf import scf as scf_dialect from .linalg import linalg from .func import func as func_dialect +from .memref import memref as memref_dialect -STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect] +STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect, memref_dialect] diff --git a/mlir/dialects/memref.py b/mlir/dialects/memref.py new file mode 100644 index 0000000..1131133 --- /dev/null +++ b/mlir/dialects/memref.py @@ -0,0 +1,31 @@ +""" Implementation of the Memref dialect. """ + +import inspect +import sys +from typing import List, Tuple, Optional, Union +from dataclasses import dataclass + +import mlir.astnodes as mast +from mlir.dialect import Dialect, DialectOp, is_op + +Literal = Union[mast.StringLiteral, float, int, bool] +SsaUse = Union[mast.SsaId, Literal] + +@dataclass +class LoadOperation(DialectOp): + arg: SsaUse + index: List[SsaUse] + type: mast.MemRefType + _syntax_ = 'memref.load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + +@dataclass +class StoreOperation(DialectOp): + addr: SsaUse + ref: SsaUse + index: List[SsaUse] + type: mast.MemRefType + _syntax_ = 'memref.store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' + +# Inspect current module to get all classes defined above +memref = Dialect('memref', ops=[m[1] for m in inspect.getmembers( + sys.modules[__name__], lambda obj: is_op(obj, __name__))]) diff --git a/mlir/dialects/standard.py b/mlir/dialects/standard.py index 6eb6168..8050310 100644 --- a/mlir/dialects/standard.py +++ b/mlir/dialects/standard.py @@ -96,14 +96,6 @@ class ExtractElementOperation(DialectOp): _syntax_ = 'extract_element {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}' -@dataclass -class LoadOperation(DialectOp): - arg: SsaUse - index: List[SsaUse] - type: mast.MemRefType - _syntax_ = 'load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' - - @dataclass class SplatOperation(DialectOp): arg: SsaUse @@ -111,15 +103,6 @@ class SplatOperation(DialectOp): _syntax_ = 'splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type) -@dataclass -class StoreOperation(DialectOp): - addr: SsaUse - ref: SsaUse - index: List[SsaUse] - type: mast.MemRefType - _syntax_ = 'store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}' - - @dataclass class TensorLoadOperation(DialectOp): arg: SsaUse