From df1af35f9a981f878c5ad675edabefe9a83f2639 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 04:01:28 +0000 Subject: [PATCH 1/2] Initial plan From 7e28c50c5d1b99b30ba996ee92c32693ebb1c83c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 29 Mar 2026 04:47:36 +0000 Subject: [PATCH 2/2] Teach DialectSourceEmitter to synthesize OperationBodySyntax subclasses for operations with assembly formats Agent-Logs-Url: https://github.com/jonathanvdc/MLIR.NET/sessions/8c4f0d04-0e04-4697-9f1c-86e836772c76 Co-authored-by: jonathanvdc <9839946+jonathanvdc@users.noreply.github.com> --- src/MLIR.Generators/DialectSourceEmitter.cs | 311 ++++++++++++++++++ src/MLIR/Syntax/NamedAttributeSyntax.cs | 7 +- .../DialectGeneratorTests.cs | 83 +++++ 3 files changed, 400 insertions(+), 1 deletion(-) diff --git a/src/MLIR.Generators/DialectSourceEmitter.cs b/src/MLIR.Generators/DialectSourceEmitter.cs index 0619164..a5be8dc 100644 --- a/src/MLIR.Generators/DialectSourceEmitter.cs +++ b/src/MLIR.Generators/DialectSourceEmitter.cs @@ -1,8 +1,11 @@ namespace MLIR.Generators; +using System.Collections.Generic; using System.Globalization; using System.Text; using MLIR.ODS.Model; +using MLIR.ODS.Model.AssemblyFormat; +using MLIR.Text; internal static class DialectSourceEmitter { @@ -29,6 +32,12 @@ public static string GenerateDialectSource(DialectModel dialect) AppendOperationClass(builder, operation); builder.AppendLine(); + if (operation.AssemblyFormat != null) + { + AppendOperationBodySyntaxClass(builder, operation); + builder.AppendLine(); + } + if (operation.HasCustomAssemblyFormat) { AppendAssemblyFormatClass(builder, operation); @@ -169,6 +178,308 @@ private static void AppendOperationClass(StringBuilder builder, OperationModel o builder.AppendLine("}"); } + private static void AppendOperationBodySyntaxClass(StringBuilder builder, OperationModel operation) + { + var className = DialectGeneratorNaming.GetOperationClassName(operation); + var assemblyFormat = operation.AssemblyFormat!; + var fields = ComputeBodySyntaxFields(operation, assemblyFormat); + + builder.AppendLine("public sealed class " + className + "BodySyntax : OperationBodySyntax"); + builder.AppendLine("{"); + + // Constructor + builder.Append(" public " + className + "BodySyntax("); + for (var i = 0; i < fields.Count; i++) + { + if (i > 0) + { + builder.Append(", "); + } + + builder.Append(fields[i].CsType + " " + LowerFirst(fields[i].Name)); + } + + builder.AppendLine(")"); + builder.AppendLine(" {"); + foreach (var field in fields) + { + builder.AppendLine(" " + field.Name + " = " + LowerFirst(field.Name) + ";"); + } + + builder.AppendLine(" }"); + + if (fields.Count > 0) + { + builder.AppendLine(); + foreach (var field in fields) + { + builder.AppendLine(" public " + field.CsType + " " + field.Name + " { get; }"); + } + } + + builder.AppendLine(); + builder.AppendLine(" public override bool TryGetGenericBody(out GenericOperationBodySyntax? genericBody)"); + builder.AppendLine(" {"); + builder.AppendLine(" genericBody = null;"); + builder.AppendLine(" return false;"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" public override void WriteTo(Text.SyntaxWriter writer, int indentLevel, System.Action writeRegion)"); + builder.AppendLine(" {"); + foreach (var field in fields) + { + builder.Append(field.WriteToCode); + } + + builder.AppendLine(" }"); + builder.AppendLine("}"); + } + + private static IReadOnlyList ComputeBodySyntaxFields(OperationModel operation, AssemblyFormatModel assemblyFormat) + { + var fields = new List(); + var usedNames = new HashSet(StringComparer.Ordinal); + + foreach (var element in assemblyFormat.Elements) + { + AppendBodySyntaxFields(fields, usedNames, element, operation); + } + + return fields; + } + + private static void AppendBodySyntaxFields(List fields, HashSet usedNames, Element element, OperationModel operation) + { + switch (element) + { + case LiteralChunk literal: + foreach (var lit in literal.Value) + { + switch (lit) + { + case PunctuationLiteral punc: + { + var name = MakeUnique(GetPunctuationFieldName(punc.TokenKind), usedNames); + fields.Add(new BodySyntaxField(name, "SyntaxToken", + " writer.WriteToken(" + name + ", string.Empty);\n")); + break; + } + + case KeywordLiteral kw: + { + var name = MakeUnique(DialectGeneratorNaming.ToPascalCase(kw.Spelling) + "Keyword", usedNames); + fields.Add(new BodySyntaxField(name, "SyntaxToken", + " writer.WriteToken(" + name + ", \" \");\n")); + break; + } + + // WhitespaceLiteral, NewlineLiteral, EmptyLiteral → no field; spacing is in stored trivia + } + } + + break; + + case VariableChunk variable: + { + var pascalName = DialectGeneratorNaming.ToPascalCase(variable.Name); + if (ContainsName(operation.Attributes, variable.Name)) + { + var name = MakeUnique(pascalName, usedNames); + fields.Add(new BodySyntaxField(name, "AttributeValueSyntax", + " " + name + ".WriteTo(writer, \" \");\n")); + } + else + { + // Operand, result variable, or unknown → SyntaxToken + var name = MakeUnique(pascalName, usedNames); + fields.Add(new BodySyntaxField(name, "SyntaxToken", + " writer.WriteToken(" + name + ", \" \");\n")); + } + + break; + } + + case AttrDictDirectiveChunk _: + case AttrDictWithKeywordDirectiveChunk _: + { + var name = MakeUnique("AttrDict", usedNames); + fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList", + GenerateDelimitedNamedAttributeWriteTo(name))); + break; + } + + case PropDictDirectiveChunk _: + { + var name = MakeUnique("PropDict", usedNames); + fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList", + GenerateDelimitedNamedAttributeWriteTo(name))); + break; + } + + case RegionsDirectiveChunk _: + { + var name = MakeUnique("Regions", usedNames); + fields.Add(new BodySyntaxField(name, "IReadOnlyList", + " foreach (var region in " + name + ")\n" + + " {\n" + + " writeRegion(writer, region, indentLevel);\n" + + " }\n")); + break; + } + + case TypeDirectiveChunk typeDir: + { + var baseName = typeDir.Operand is VariableOperand varOp + ? DialectGeneratorNaming.ToPascalCase(varOp.Name) + "Type" + : "Type"; + var name = MakeUnique(baseName, usedNames); + fields.Add(new BodySyntaxField(name, "TypeSyntax", + " " + name + ".WriteTo(writer, \" \");\n")); + break; + } + + case SuccessorsDirectiveChunk _: + { + var name = MakeUnique("Successors", usedNames); + fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList", + GenerateDelimitedTokenWriteTo(name))); + break; + } + + case OperandsDirectiveChunk _: + { + var name = MakeUnique("Operands", usedNames); + fields.Add(new BodySyntaxField(name, "DelimitedSyntaxList", + GenerateDelimitedTokenWriteTo(name))); + break; + } + + // OptionalGroup, OilistDirectiveChunk, CustomDirectiveChunk, FunctionalTypeDirectiveChunk, + // QualifiedDirectiveChunk, RefDirectiveChunk, ResultsDirectiveChunk → not stored in this CST class + } + } + + private static string GenerateDelimitedNamedAttributeWriteTo(string fieldName) + { + return + " if (" + fieldName + ".OpenToken != null)\n" + + " {\n" + + " writer.WriteToken(" + fieldName + ".OpenToken.Value, \" \");\n" + + " for (var i = 0; i < " + fieldName + ".Count; i++)\n" + + " {\n" + + " if (i > 0)\n" + + " {\n" + + " writer.WriteToken(" + fieldName + ".SeparatorTokens[i - 1], string.Empty);\n" + + " }\n" + + " " + fieldName + "[i].WriteTo(writer, i > 0 ? \" \" : string.Empty);\n" + + " }\n" + + " writer.WriteToken(" + fieldName + ".CloseToken!.Value, string.Empty);\n" + + " }\n"; + } + + private static string GenerateDelimitedTokenWriteTo(string fieldName) + { + return + " if (" + fieldName + ".OpenToken != null)\n" + + " {\n" + + " writer.WriteToken(" + fieldName + ".OpenToken.Value, \" \");\n" + + " for (var i = 0; i < " + fieldName + ".Count; i++)\n" + + " {\n" + + " if (i > 0)\n" + + " {\n" + + " writer.WriteToken(" + fieldName + ".SeparatorTokens[i - 1], string.Empty);\n" + + " }\n" + + " writer.WriteToken(" + fieldName + "[i], i > 0 ? \" \" : string.Empty);\n" + + " }\n" + + " writer.WriteToken(" + fieldName + ".CloseToken!.Value, string.Empty);\n" + + " }\n"; + } + + private static string GetPunctuationFieldName(TokenKind tokenKind) + { + return tokenKind switch + { + TokenKind.Comma => "CommaToken", + TokenKind.LParen => "LParenToken", + TokenKind.RParen => "RParenToken", + TokenKind.LBracket => "LBracketToken", + TokenKind.RBracket => "RBracketToken", + TokenKind.LBrace => "LBraceToken", + TokenKind.RBrace => "RBraceToken", + TokenKind.Arrow => "ArrowToken", + TokenKind.Colon => "ColonToken", + TokenKind.Equal => "EqualToken", + TokenKind.LessThan => "LessThanToken", + TokenKind.GreaterThan => "GreaterThanToken", + TokenKind.Question => "QuestionToken", + TokenKind.Star => "StarToken", + TokenKind.Plus => "PlusToken", + TokenKind.Minus => "MinusToken", + TokenKind.Dot => "DotToken", + TokenKind.At => "AtToken", + TokenKind.Hash => "HashToken", + _ => "Token", + }; + } + + private static string MakeUnique(string baseName, HashSet used) + { + if (used.Add(baseName)) + { + return baseName; + } + + for (var i = 2; ; i++) + { + var candidate = baseName + i.ToString(CultureInfo.InvariantCulture); + if (used.Add(candidate)) + { + return candidate; + } + } + } + + private static string LowerFirst(string name) + { + if (name.Length == 0) + { + return name; + } + + return char.ToLowerInvariant(name[0]) + name.Substring(1); + } + + private static bool ContainsName(IReadOnlyList names, string name) + { + foreach (var n in names) + { + if (string.Equals(n, name, StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + + private sealed class BodySyntaxField + { + public BodySyntaxField(string name, string csType, string writeToCode) + { + Name = name; + CsType = csType; + WriteToCode = writeToCode; + } + + public string Name { get; } + public string CsType { get; } + + /// + /// C# code (indented for the WriteTo body, ending with a newline) that writes this field. + /// + public string WriteToCode { get; } + } + private static void AppendAssemblyFormatClass(StringBuilder builder, OperationModel operation) { var className = DialectGeneratorNaming.GetOperationClassName(operation); diff --git a/src/MLIR/Syntax/NamedAttributeSyntax.cs b/src/MLIR/Syntax/NamedAttributeSyntax.cs index 63cc7a4..d86b0a7 100644 --- a/src/MLIR/Syntax/NamedAttributeSyntax.cs +++ b/src/MLIR/Syntax/NamedAttributeSyntax.cs @@ -56,7 +56,12 @@ public bool TryGetRawValue(out RawSyntaxText? rawValue) /// public RawSyntaxText RawValue => ValueSyntax.GetRawText(); - internal void WriteTo(SyntaxWriter writer, string defaultLeadingTrivia) + /// + /// Writes the named attribute to the supplied syntax writer. + /// + /// The syntax writer to write to. + /// The fallback leading trivia for the name token. + public void WriteTo(SyntaxWriter writer, string defaultLeadingTrivia) { writer.WriteToken(NameToken, defaultLeadingTrivia); writer.WriteToken(EqualsToken, " "); diff --git a/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs b/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs index a408a11..d0de9d0 100644 --- a/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs +++ b/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs @@ -49,6 +49,89 @@ public void GeneratesDialectRegistrationTypedNodesAndCustomAssemblyStubs() Assert.Contains(".WithAssemblyFormat(new MiniArith_AddIOpAssemblyFormat())", registrationSource); } + [Fact] + public void GeneratesOperationBodySyntaxClassForDeclarativeAssemblyFormat() + { + const string source = + "class MiniArith_Op traits = []> :\n" + + " Op;\n" + + "\n" + + "def MiniArith_Dialect : Dialect {\n" + + " let name = \"miniarith\";\n" + + " let cppNamespace = \"::mlir::miniarith\";\n" + + "};\n" + + "\n" + + "def MiniArith_ConstantOp : MiniArith_Op<\"constant\", [Pure]> {\n" + + " let summary = \"integer constant\";\n" + + " let arguments = (ins I32Attr:$value);\n" + + " let results = (outs I32:$result);\n" + + " let assemblyFormat = \"$value attr-dict\";\n" + + "};\n" + + "\n" + + "def MiniArith_AddIOp : MiniArith_Op<\"addi\", [Pure, Commutative]> {\n" + + " let summary = \"integer addition\";\n" + + " let arguments = (ins I32:$lhs, I32:$rhs);\n" + + " let results = (outs I32:$result);\n" + + " let assemblyFormat = \"$lhs `,` $rhs attr-dict `:` type($result)\";\n" + + "};"; + + var generatedSources = GeneratorTestHelpers.RunGenerator( + new DialectGenerator(), + ("miniarith.td", source)); + var registrationSource = Assert.Single(generatedSources.Where(static result => result.HintName == "MiniarithDialectRegistration.g.cs")).SourceText.ToString(); + + // BodySyntax classes are generated for operations with declarative assembly formats. + Assert.Contains("public sealed class MiniArith_ConstantOpBodySyntax : OperationBodySyntax", registrationSource); + Assert.Contains("public sealed class MiniArith_AddIOpBodySyntax : OperationBodySyntax", registrationSource); + + // MiniArith_ConstantOp: $value (attribute) and attr-dict. + Assert.Contains("public AttributeValueSyntax Value { get; }", registrationSource); + + // MiniArith_AddIOp: $lhs, `,`, $rhs, attr-dict, `:`, type($result). + Assert.Contains("public SyntaxToken Lhs { get; }", registrationSource); + Assert.Contains("public SyntaxToken CommaToken { get; }", registrationSource); + Assert.Contains("public SyntaxToken Rhs { get; }", registrationSource); + Assert.Contains("public SyntaxToken ColonToken { get; }", registrationSource); + Assert.Contains("public TypeSyntax ResultType { get; }", registrationSource); + + // Both classes share AttrDict. + Assert.Contains("public DelimitedSyntaxList AttrDict { get; }", registrationSource); + + // TryGetGenericBody is a stub that always returns false. + Assert.Contains("public override bool TryGetGenericBody(out GenericOperationBodySyntax? genericBody)", registrationSource); + Assert.Contains("genericBody = null;", registrationSource); + Assert.Contains("return false;", registrationSource); + + // WriteTo is implemented. + Assert.Contains("public override void WriteTo(Text.SyntaxWriter writer, int indentLevel, System.Action writeRegion)", registrationSource); + } + + [Fact] + public void BodySyntaxClassIsNotGeneratedForOperationsWithoutDeclarativeAssemblyFormat() + { + const string source = + "class MiniArith_Op traits = []> :\n" + + " Op;\n" + + "\n" + + "def MiniArith_Dialect : Dialect {\n" + + " let name = \"miniarith\";\n" + + " let cppNamespace = \"::mlir::miniarith\";\n" + + "};\n" + + "\n" + + "def MiniArith_ConstantOp : MiniArith_Op<\"constant\", [Pure]> {\n" + + " let arguments = (ins I32Attr:$value);\n" + + " let results = (outs I32:$result);\n" + + "};"; + + var generatedSources = GeneratorTestHelpers.RunGenerator( + new DialectGenerator(), + ("miniarith.td", source)); + var registrationSource = Assert.Single(generatedSources.Where(static result => result.HintName == "MiniarithDialectRegistration.g.cs")).SourceText.ToString(); + + // No BodySyntax class when there is no declarative assembly format. + Assert.DoesNotContain("BodySyntax", registrationSource); + } + [Fact] public void GeneratesXmlDocCommentsFromSummaryAndDescription() {