Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5c66a47
fix type alias definition
dzufferey Apr 1, 2025
b4e42d4
parsing of func.return
dzufferey Apr 1, 2025
0eb8414
parsing of custom float types
dzufferey Apr 1, 2025
5759093
parsing private functions
dzufferey Apr 1, 2025
4bf858c
allow "call" for "func.call"
dzufferey Apr 2, 2025
14bb730
Remove pretty dialect types and reduce ambiguities
dzufferey Apr 3, 2025
4abcb3b
increase priority of dialect ops
dzufferey Apr 7, 2025
8b99ecc
generalize FP format
dzufferey Apr 11, 2025
d25ea09
fixes to VectorType
dzufferey May 30, 2025
4ea5fab
dense array attribute
dzufferey Jun 6, 2025
0116735
upstream dialect syntax modifications
amanda849 Jul 3, 2025
a23216b
Update mlir.lark
amanda849 Jul 25, 2025
5fcf5af
Update scf.py
amanda849 Jul 25, 2025
9697ee2
Update mlir.lark
amanda849 Jul 25, 2025
5ddf87e
Update mlir.lark
amanda849 Jul 25, 2025
11bcc94
Update scf.py
amanda849 Jul 25, 2025
7a7545f
Merge pull request #1 from dzufferey/amanda849-patch-1
amanda849 Jul 25, 2025
e1801a6
Update astnodes.py
amanda849 Jul 25, 2025
182a825
Merge pull request #2 from dzufferey/amanda849-patch-1
dzufferey Jul 28, 2025
0a44399
fix dense element attribute
dzufferey Jul 30, 2025
a721222
dump handles None
dzufferey Jul 30, 2025
2291772
Allow attribute alias in opaque_dialect_item_contents
amanda849 Jul 31, 2025
f50058f
scf.while with empty assignment
dzufferey Aug 12, 2025
408c85a
fix tensor literal
dzufferey Aug 13, 2025
768dbc9
fix mangling of ids
dzufferey Oct 15, 2025
f044225
fix custom float parsing
dzufferey Oct 15, 2025
ea1c0c6
Update parser_transformer.py
amanda849 Dec 13, 2025
9ddf615
Update mlir.lark
amanda849 Dec 13, 2025
1fa474e
Fix Type Alias parsing
amanda849 Jan 29, 2026
e5544ad
Update mlir.lark
amanda849 Jan 29, 2026
5979695
update setup for lark
dzufferey Apr 16, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 95 additions & 17 deletions mlir/astnodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" Classes containing MLIR AST node types, fields, and conversion back to
MLIR. """

from enum import Enum, auto
from enum import Enum
from typing import Any, List, Union, Optional
from lark import Token
from lark.tree import Tree
Expand Down Expand Up @@ -126,12 +126,82 @@ class FloatTypeEnum(Enum):

@dataclass
class FloatType(Type):
pass


@dataclass
class StandardFloatType(FloatType):
type: FloatTypeEnum

def dump(self, indent: int = 0) -> str:
return self.type.name


@dataclass
class CustomFloatType(FloatType):
width: int
exponent: int
mantissa: int
bias: int = 0
signed: bool = True
zero: bool = True
infinities: bool = True
nans: bool = True

@classmethod
def from_lark(cls, args: list):
bias: int = 0
signed: bool = True
zero: bool = True
infinities: bool = True
nans: bool = True
i = 0
while i < len(args):
a = args[i]
if isinstance(a, Token):
if a.value == "f":
i += 1
width = int(args[i])
elif a.value == "E":
i += 1
exponent = int(args[i])
elif a.value == "M":
i += 1
mantissa = int(args[i])
elif a.value == "B":
i += 1
bias = int(args[i])
elif a.value == "F":
infinities = False
elif a.value == "N":
nans = False
elif a.value == "U":
signed = False
zero = False
elif a.value == "Z":
zero = True
else:
raise ValueError(f"Unknow format specification: {a}")
else:
raise ValueError(f"Bad argument type: {a}")
i += 1
return cls(width, exponent, mantissa, bias, signed, zero, infinities, nans)


def dump(self, indent: int = 0) -> str:
suffix = ""
if not self.infinities:
suffix += "F"
if not self.nans:
suffix += "N"
if not self.signed:
suffix += "U"
if self.zero:
suffix += "Z"
bias = "" if self.bias == 0 else f"B{self.bias}"
return f"f{self.width}E{self.exponent}M{self.mantissa}{bias}{suffix}"


@dataclass
class TensorFloatType(Type):
def dump(self, indent: int = 0) -> str:
Expand Down Expand Up @@ -185,8 +255,8 @@ def dump(self, indent: int = 0) -> str:

@dataclass
class VectorType(Type):
dimensions: int
element_type: Union[IntegerType, FloatType]
dimensions: List[int]
element_type: Union[IntegerType, FloatType, IndexType]

def dump(self, indent: int = 0) -> str:
return 'vector<%s>' % ('x'.join(
Expand Down Expand Up @@ -275,16 +345,6 @@ class OpaqueDialectType(Type):
def dump(self, indent: int = 0) -> str:
return '!%s<"%s">' % (self.dialect, self.contents)

@dataclass
class PrettyDialectType(Type):
dialect: str
type: str
body: List[str]

def dump(self, indent: int = 0) -> str:
return '!%s.%s<%s>' % (self.dialect, self.type, ', '.join(
dump_or_value(item, indent) for item in self.body))


@dataclass
class FunctionType(Type):
Expand Down Expand Up @@ -351,18 +411,28 @@ def dump(self, indent: int = 0) -> str:
return '{%s}' % dump_or_value(self.value, indent)


@dataclass
class DenseArrayAttr(Attribute):
type: IntegerType | FloatType
value: List[bool | int | float]

def dump(self, indent: int = 0) -> str:
return 'array<%s: %s>' % (self.type.dump(indent),
dump_or_value(self.value, indent))


@dataclass
class ElementsAttr(Attribute):
pass


@dataclass
class DenseElementsAttr(ElementsAttr):
attribute: Attribute
attribute: Optional[Attribute]
type: Union[TensorType, VectorType]

def dump(self, indent: int = 0) -> str:
return 'dense<%s> : %s' % (self.attribute.dump(indent),
return 'dense<%s> : %s' % (dump_or_value(self.attribute, indent),
self.type.dump(indent))


Expand Down Expand Up @@ -619,7 +689,7 @@ class GenericModule(ModuleType):
args: List["NamedArgument"]
region: "Region"
attributes: Optional[AttributeDict]
type: List[Type]
type: Type | List[Type]
location: Optional[Location] = None

def dump(self, indent=0) -> str:
Expand All @@ -634,14 +704,18 @@ def dump(self, indent=0) -> str:
result += ')'
if self.attributes:
result += ' ' + dump_or_value(self.attributes, indent)
result += ' : ' + self.type.dump(indent)
if isinstance(self.type, list):
result += ' : ' + ', '.join(t.dump(indent) for t in self.type)
else:
result += ' : ' + self.type.dump(indent)
if self.location:
result += ' ' + self.location.dump(indent)
return result


@dataclass
class Function(Node):
visibility: Optional[str]
name: SymbolRefId
args: Optional[List["NamedArgument"]]
result_types: Optional[List[Type]]
Expand All @@ -651,6 +725,8 @@ class Function(Node):

def dump(self, indent=0) -> str:
result = 'func.func'
if self.visibility:
result += ' %s' % self.visibility
result += ' %s' % self.name.dump(indent)
arg_list = self.args if self.args else []
result += '(%s)' % ', '.join(
Expand Down Expand Up @@ -999,6 +1075,8 @@ def _dump_ast_or_value(value: Any, python=True, indent: int = 0) -> str:
'%s%s%s' %
(_dump_ast_or_value(k, python), sep, _dump_ast_or_value(v, python))
for k, v in value.items())
if value is None:
return ""
return str(value)


Expand Down
8 changes: 4 additions & 4 deletions mlir/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def __init__(self):

name_gen = UniqueNameGenerator(forced_prefix="_pymlir_")

F16 = mast.FloatType(type=mast.FloatTypeEnum.f16)
F32 = mast.FloatType(type=mast.FloatTypeEnum.f32)
F64 = mast.FloatType(type=mast.FloatTypeEnum.f64)
F16 = mast.StandardFloatType(type=mast.FloatTypeEnum.f16)
F32 = mast.StandardFloatType(type=mast.FloatTypeEnum.f32)
F64 = mast.StandardFloatType(type=mast.FloatTypeEnum.f64)
INT32 = mast.IntegerType(32)
INT64 = mast.IntegerType(64)
INDEX = mast.IndexType()
Expand Down Expand Up @@ -143,7 +143,7 @@ def function(self, name: Optional[str] = None) -> mast.Function:
if name is None:
name = self.name_gen("fn")

op = mast.Function(mast.SymbolRefId(value=name), [], [], None,
op = mast.Function(None, mast.SymbolRefId(value=name), [], [], None,
mast.Region([]))

self._insert_op_in_block([], op)
Expand Down
122 changes: 62 additions & 60 deletions mlir/dialects/func.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,62 @@

import inspect
import sys
from typing import List, Tuple, Optional, Union
from dataclasses import dataclass

import mlir.astnodes as mast
from mlir.dialect import Dialect, DialectOp, is_op

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]

@dataclass
class CallIndirectOperation(DialectOp):
func: mast.SymbolRefId
func_type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {func_type.function_type}',
'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {func_type.function_type}']


@dataclass
class CallOperation(DialectOp):
func: mast.SymbolRefId
func_type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call {func.symbol_ref_id} () : {func_type.function_type}',
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {func_type.function_type}']

@dataclass
class ConstantOperation(DialectOp):
value: mast.SymbolRefId
val_type: mast.Type
_syntax_ = ['func.constant {value.symbol_ref_id} : {val_typw.type}']

# Note: The 'func.func' operation is defined as 'function' in mlir.lark.

@dataclass
class ReturnOperation(DialectOp):
values: Optional[List[SsaUse]] = None
types: Optional[List[mast.Type]] = None
_syntax_ = ['return',
'return {values.ssa_use_list} : {types.type_list_no_parens}']

def dump(self, indent: int = 0) -> str:
output = 'return'
if self.values:
output += ' ' + ', '.join([v.dump(indent) for v in self.values])
if self.types:
output += ' : ' + ', '.join([t.dump(indent) for t in self.types])

return output



# Inspect current module to get all classes defined above
func = Dialect('func', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])

import inspect
import sys
from typing import List, Optional, Union
from dataclasses import dataclass

import mlir.astnodes as mast
from mlir.dialect import Dialect, DialectOp, is_op

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]

@dataclass
class CallIndirectOperation(DialectOp):
func: mast.SymbolRefId
func_type: mast.FunctionType
args: Optional[List[SsaUse]] = None
_syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {func_type.function_type}',
'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {func_type.function_type}']


@dataclass
class CallOperation(DialectOp):
func: mast.SymbolRefId
func_type: mast.FunctionType
args: Optional[List[SsaUse]] = None
_syntax_ = ['call {func.symbol_ref_id} () : {func_type.function_type}',
'func.call {func.symbol_ref_id} () : {func_type.function_type}',
'call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {func_type.function_type}',
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {func_type.function_type}']

@dataclass
class ConstantOperation(DialectOp):
value: mast.SymbolRefId
val_type: mast.Type
_syntax_ = ['func.constant {value.symbol_ref_id} : {val_typw.type}']

# Note: The 'func.func' operation is defined as 'function' in mlir.lark.

@dataclass
class ReturnOperation(DialectOp):
values: Optional[List[SsaUse]] = None
types: Optional[List[mast.Type]] = None
_syntax_ = ['return',
'func.return',
'return {values.ssa_use_list} : {types.type_list_no_parens}',
'func.return {values.ssa_use_list} : {types.type_list_no_parens}']

def dump(self, indent: int = 0) -> str:
output = 'return'
if self.values:
output += ' ' + ', '.join([v.dump(indent) for v in self.values])
if self.types:
output += ' : ' + ', '.join([t.dump(indent) for t in self.types])

return output



# Inspect current module to get all classes defined above
func = Dialect('func', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])
Loading