Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 311 additions & 0 deletions src/MLIR.Generators/DialectSourceEmitter.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
namespace MLIR.Generators;

using System.Collections.Generic;
using System.Globalization;
using System.Text;
using MLIR.ODS.Model;
using MLIR.ODS.Model.AssemblyFormat;
using MLIR.Text;

internal static class DialectSourceEmitter
{
Expand All @@ -29,6 +32,12 @@ public static string GenerateDialectSource(DialectModel dialect)
AppendOperationClass(builder, operation);
builder.AppendLine();

if (operation.AssemblyFormat != null)
{
AppendOperationBodySyntaxClass(builder, operation);
builder.AppendLine();
}

if (operation.HasCustomAssemblyFormat)
{
AppendAssemblyFormatClass(builder, operation);
Expand Down Expand Up @@ -169,6 +178,308 @@ private static void AppendOperationClass(StringBuilder builder, OperationModel o
builder.AppendLine("}");
}

private static void AppendOperationBodySyntaxClass(StringBuilder builder, OperationModel operation)
{
var className = DialectGeneratorNaming.GetOperationClassName(operation);
var assemblyFormat = operation.AssemblyFormat!;
var fields = ComputeBodySyntaxFields(operation, assemblyFormat);

builder.AppendLine("public sealed class " + className + "BodySyntax : OperationBodySyntax");
builder.AppendLine("{");

// Constructor
builder.Append(" public " + className + "BodySyntax(");
for (var i = 0; i < fields.Count; i++)
{
if (i > 0)
{
builder.Append(", ");
}

builder.Append(fields[i].CsType + " " + LowerFirst(fields[i].Name));
}

builder.AppendLine(")");
builder.AppendLine(" {");
foreach (var field in fields)
{
builder.AppendLine(" " + field.Name + " = " + LowerFirst(field.Name) + ";");
}

builder.AppendLine(" }");

if (fields.Count > 0)
{
builder.AppendLine();
foreach (var field in fields)
{
builder.AppendLine(" public " + field.CsType + " " + field.Name + " { get; }");
}
}

builder.AppendLine();
builder.AppendLine(" public override bool TryGetGenericBody(out GenericOperationBodySyntax? genericBody)");
builder.AppendLine(" {");
builder.AppendLine(" genericBody = null;");
builder.AppendLine(" return false;");
builder.AppendLine(" }");
builder.AppendLine();
builder.AppendLine(" public override void WriteTo(Text.SyntaxWriter writer, int indentLevel, System.Action<Text.SyntaxWriter, RegionSyntax, int> writeRegion)");
builder.AppendLine(" {");
foreach (var field in fields)
{
builder.Append(field.WriteToCode);
}

builder.AppendLine(" }");
builder.AppendLine("}");
}

private static IReadOnlyList<BodySyntaxField> ComputeBodySyntaxFields(OperationModel operation, AssemblyFormatModel assemblyFormat)
{
var fields = new List<BodySyntaxField>();
var usedNames = new HashSet<string>(StringComparer.Ordinal);

foreach (var element in assemblyFormat.Elements)
{
AppendBodySyntaxFields(fields, usedNames, element, operation);
}

return fields;
}

private static void AppendBodySyntaxFields(List<BodySyntaxField> fields, HashSet<string> usedNames, Element element, OperationModel operation)
{
switch (element)
{
case LiteralChunk literal:
foreach (var lit in literal.Value)
{
switch (lit)
{
case PunctuationLiteral punc:
{
var name = MakeUnique(GetPunctuationFieldName(punc.TokenKind), usedNames);
fields.Add(new BodySyntaxField(name, "SyntaxToken",
" writer.WriteToken(" + name + ", string.Empty);\n"));
break;
}

case KeywordLiteral kw:
{
var name = MakeUnique(DialectGeneratorNaming.ToPascalCase(kw.Spelling) + "Keyword", usedNames);
fields.Add(new BodySyntaxField(name, "SyntaxToken",
" writer.WriteToken(" + name + ", \" \");\n"));
break;
}

// WhitespaceLiteral, NewlineLiteral, EmptyLiteral → no field; spacing is in stored trivia
}
}

break;

case VariableChunk variable:
{
var pascalName = DialectGeneratorNaming.ToPascalCase(variable.Name);
if (ContainsName(operation.Attributes, variable.Name))
{
var name = MakeUnique(pascalName, usedNames);
fields.Add(new BodySyntaxField(name, "AttributeValueSyntax",
" " + name + ".WriteTo(writer, \" \");\n"));
}
else
{
// Operand, result variable, or unknown → SyntaxToken
var name = MakeUnique(pascalName, usedNames);
fields.Add(new BodySyntaxField(name, "SyntaxToken",
" writer.WriteToken(" + name + ", \" \");\n"));
}

break;
}

case AttrDictDirectiveChunk _:
case AttrDictWithKeywordDirectiveChunk _:
{
var name = MakeUnique("AttrDict", usedNames);
fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList<NamedAttributeSyntax>",
GenerateDelimitedNamedAttributeWriteTo(name)));
break;
}

case PropDictDirectiveChunk _:
{
var name = MakeUnique("PropDict", usedNames);
fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList<NamedAttributeSyntax>",
GenerateDelimitedNamedAttributeWriteTo(name)));
break;
}

case RegionsDirectiveChunk _:
{
var name = MakeUnique("Regions", usedNames);
fields.Add(new BodySyntaxField(name, "IReadOnlyList<RegionSyntax>",
" foreach (var region in " + name + ")\n" +
" {\n" +
" writeRegion(writer, region, indentLevel);\n" +
" }\n"));
break;
}

case TypeDirectiveChunk typeDir:
{
var baseName = typeDir.Operand is VariableOperand varOp
? DialectGeneratorNaming.ToPascalCase(varOp.Name) + "Type"
: "Type";
var name = MakeUnique(baseName, usedNames);
fields.Add(new BodySyntaxField(name, "TypeSyntax",
" " + name + ".WriteTo(writer, \" \");\n"));
break;
}

case SuccessorsDirectiveChunk _:
{
var name = MakeUnique("Successors", usedNames);
fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList<SyntaxToken>",
GenerateDelimitedTokenWriteTo(name)));
break;
}

case OperandsDirectiveChunk _:
{
var name = MakeUnique("Operands", usedNames);
fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList<SyntaxToken>",
GenerateDelimitedTokenWriteTo(name)));
break;
}

// OptionalGroup, OilistDirectiveChunk, CustomDirectiveChunk, FunctionalTypeDirectiveChunk,
// QualifiedDirectiveChunk, RefDirectiveChunk, ResultsDirectiveChunk → not stored in this CST class
}
}

private static string GenerateDelimitedNamedAttributeWriteTo(string fieldName)
{
return
" if (" + fieldName + ".OpenToken != null)\n" +
" {\n" +
" writer.WriteToken(" + fieldName + ".OpenToken.Value, \" \");\n" +
" for (var i = 0; i < " + fieldName + ".Count; i++)\n" +
" {\n" +
" if (i > 0)\n" +
" {\n" +
" writer.WriteToken(" + fieldName + ".SeparatorTokens[i - 1], string.Empty);\n" +
" }\n" +
" " + fieldName + "[i].WriteTo(writer, i > 0 ? \" \" : string.Empty);\n" +
" }\n" +
" writer.WriteToken(" + fieldName + ".CloseToken!.Value, string.Empty);\n" +
" }\n";
}

private static string GenerateDelimitedTokenWriteTo(string fieldName)
{
return
" if (" + fieldName + ".OpenToken != null)\n" +
" {\n" +
" writer.WriteToken(" + fieldName + ".OpenToken.Value, \" \");\n" +
" for (var i = 0; i < " + fieldName + ".Count; i++)\n" +
" {\n" +
" if (i > 0)\n" +
" {\n" +
" writer.WriteToken(" + fieldName + ".SeparatorTokens[i - 1], string.Empty);\n" +
" }\n" +
" writer.WriteToken(" + fieldName + "[i], i > 0 ? \" \" : string.Empty);\n" +
" }\n" +
" writer.WriteToken(" + fieldName + ".CloseToken!.Value, string.Empty);\n" +
" }\n";
}

private static string GetPunctuationFieldName(TokenKind tokenKind)
{
return tokenKind switch
{
TokenKind.Comma => "CommaToken",
TokenKind.LParen => "LParenToken",
TokenKind.RParen => "RParenToken",
TokenKind.LBracket => "LBracketToken",
TokenKind.RBracket => "RBracketToken",
TokenKind.LBrace => "LBraceToken",
TokenKind.RBrace => "RBraceToken",
TokenKind.Arrow => "ArrowToken",
TokenKind.Colon => "ColonToken",
TokenKind.Equal => "EqualToken",
TokenKind.LessThan => "LessThanToken",
TokenKind.GreaterThan => "GreaterThanToken",
TokenKind.Question => "QuestionToken",
TokenKind.Star => "StarToken",
TokenKind.Plus => "PlusToken",
TokenKind.Minus => "MinusToken",
TokenKind.Dot => "DotToken",
TokenKind.At => "AtToken",
TokenKind.Hash => "HashToken",
_ => "Token",
};
}

private static string MakeUnique(string baseName, HashSet<string> used)
{
if (used.Add(baseName))
{
return baseName;
}

for (var i = 2; ; i++)
{
var candidate = baseName + i.ToString(CultureInfo.InvariantCulture);
if (used.Add(candidate))
{
return candidate;
}
}
}

private static string LowerFirst(string name)
{
if (name.Length == 0)
{
return name;
}

return char.ToLowerInvariant(name[0]) + name.Substring(1);
}

private static bool ContainsName(IReadOnlyList<string> names, string name)
{
foreach (var n in names)
{
if (string.Equals(n, name, StringComparison.Ordinal))
{
return true;
}
}

return false;
}

private sealed class BodySyntaxField
{
public BodySyntaxField(string name, string csType, string writeToCode)
{
Name = name;
CsType = csType;
WriteToCode = writeToCode;
}

public string Name { get; }
public string CsType { get; }

/// <summary>
/// C# code (indented for the WriteTo body, ending with a newline) that writes this field.
/// </summary>
public string WriteToCode { get; }
}

private static void AppendAssemblyFormatClass(StringBuilder builder, OperationModel operation)
{
var className = DialectGeneratorNaming.GetOperationClassName(operation);
Expand Down
7 changes: 6 additions & 1 deletion src/MLIR/Syntax/NamedAttributeSyntax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ public bool TryGetRawValue(out RawSyntaxText? rawValue)
/// </summary>
public RawSyntaxText RawValue => ValueSyntax.GetRawText();

internal void WriteTo(SyntaxWriter writer, string defaultLeadingTrivia)
/// <summary>
/// Writes the named attribute to the supplied syntax writer.
/// </summary>
/// <param name="writer">The syntax writer to write to.</param>
/// <param name="defaultLeadingTrivia">The fallback leading trivia for the name token.</param>
public void WriteTo(SyntaxWriter writer, string defaultLeadingTrivia)
{
writer.WriteToken(NameToken, defaultLeadingTrivia);
writer.WriteToken(EqualsToken, " ");
Expand Down
Loading
Loading