Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .scf import scf as scf_dialect
from .linalg import linalg
from .func import func as func_dialect
from .memref import memref as memref_dialect


STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect]
STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect, memref_dialect]
31 changes: 31 additions & 0 deletions mlir/dialects/memref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
""" Implementation of the Memref dialect. """

import inspect
import sys
from typing import List, Tuple, Optional, Union
from dataclasses import dataclass

import mlir.astnodes as mast
from mlir.dialect import Dialect, DialectOp, is_op

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]

@dataclass
class LoadOperation(DialectOp):
arg: SsaUse
index: List[SsaUse]
type: mast.MemRefType
_syntax_ = 'memref.load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}'

@dataclass
class StoreOperation(DialectOp):
addr: SsaUse
ref: SsaUse
index: List[SsaUse]
type: mast.MemRefType
_syntax_ = 'memref.store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}'

# Inspect current module to get all classes defined above
memref = Dialect('memref', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])
17 changes: 0 additions & 17 deletions mlir/dialects/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,30 +96,13 @@ class ExtractElementOperation(DialectOp):
_syntax_ = 'extract_element {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}'


@dataclass
class LoadOperation(DialectOp):
arg: SsaUse
index: List[SsaUse]
type: mast.MemRefType
_syntax_ = 'load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}'


@dataclass
class SplatOperation(DialectOp):
arg: SsaUse
type: Union[mast.VectorType, mast.TensorType]
_syntax_ = 'splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type)


@dataclass
class StoreOperation(DialectOp):
addr: SsaUse
ref: SsaUse
index: List[SsaUse]
type: mast.MemRefType
_syntax_ = 'store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}'


@dataclass
class TensorLoadOperation(DialectOp):
arg: SsaUse
Expand Down
Loading