diff --git a/src/MLIR.Generators/Emitters/AssemblyFormatEmitter.cs b/src/MLIR.Generators/Emitters/AssemblyFormatEmitter.cs index f93f27f..7bbee77 100644 --- a/src/MLIR.Generators/Emitters/AssemblyFormatEmitter.cs +++ b/src/MLIR.Generators/Emitters/AssemblyFormatEmitter.cs @@ -57,7 +57,11 @@ public static void Emit(StringBuilder builder, OperationModel operation, Operati builder.AppendLine(); builder.AppendLine(" public Operation Bind(OperationSyntax syntax, OperationDefinition definition, Binder binder)"); builder.AppendLine(" {"); - builder.AppendLine(" var body = (" + className + "BodySyntax)syntax.Body;"); + builder.AppendLine(" if (syntax.Body is not " + className + "BodySyntax body)"); + builder.AppendLine(" {"); + builder.AppendLine(" binder.Report(new AssemblyDiagnostic(syntax.Location, \"Expected a " + className + "BodySyntax but found \" + syntax.Body.GetType().Name + \".\"));"); + builder.AppendLine(" return new UninterpretedOperation(syntax, definition.Name);"); + builder.AppendLine(" }"); builder.AppendLine(" if (syntax.ResultTokens.Count != " + operation.Results.Count.ToString(CultureInfo.InvariantCulture) + ")"); builder.AppendLine(" {"); builder.AppendLine(" binder.Report(new AssemblyDiagnostic(syntax.Location, \"Expected exactly " + operation.Results.Count.ToString(CultureInfo.InvariantCulture) + " result(s) but found \" + syntax.ResultTokens.Count + \".\"));"); diff --git a/tests/DialectTests/DialectIntegrationTests.cs b/tests/DialectTests/DialectIntegrationTests.cs index ccee4ad..81ca380 100644 --- a/tests/DialectTests/DialectIntegrationTests.cs +++ b/tests/DialectTests/DialectIntegrationTests.cs @@ -3,6 +3,8 @@ namespace DialectTests; using MLIR; using MLIR.Dialects; using MLIR.Miniarith; +using MLIR.Semantics; +using MLIR.Syntax; using Xunit; public sealed class DialectIntegrationTests @@ -72,4 +74,125 @@ public void GeneratedDialectCanReadAndWriteBoundModules() Assert.Equal(source, module.ToText()); } + + [Fact] + public void GeneratedAssemblyFormatBindProducesTypedAddIOpFromCustomBodySyntax() + { + // Arrange: construct a custom body matching "$lhs `,` $rhs attr-dict `:` type($result)" + var body = new MiniArith_AddIOpBodySyntax( + new SyntaxToken("%lhs"), + new SyntaxToken(","), + new SyntaxToken("%rhs"), + new DelimitedSyntaxList(null, [], [], null), + new SyntaxToken(":"), + new RawTypeSyntax(new RawSyntaxText("i32"))); + + var syntax = new OperationSyntax( + resultTokens: [new SyntaxToken("%result")], + resultCommaTokens: [], + equalsToken: new SyntaxToken("="), + nameToken: new SyntaxToken("miniarith.addi"), + body: body); + + var registry = new DialectRegistry(); + registry.RegisterDialect(MiniarithDialectRegistration.Create()); + + // Act: bind invokes MiniArith_AddIOpAssemblyFormat.Bind because body is not GenericOperationBodySyntax + var module = Binder.BindModule(new ModuleSyntax([syntax]), registry); + + // Assert + Assert.Empty(module.AssemblyDiagnostics); + var operation = Assert.IsType(Assert.Single(module.Operations)); + Assert.Equal("miniarith.addi", operation.Name); + Assert.Equal("%lhs", operation.Lhs.Name); + Assert.Equal("%rhs", operation.Rhs.Name); + Assert.Equal("%result", operation.ResultValue.Name); + Assert.NotNull(operation.TypeSignatureReference); + } + + [Fact] + public void GeneratedAssemblyFormatBindProducesTypedConstantOpFromCustomBodySyntax() + { + // Arrange: construct a custom body matching "$value attr-dict" + var body = new MiniArith_ConstantOpBodySyntax( + new RawAttributeValueSyntax(new RawSyntaxText("42")), + new DelimitedSyntaxList(null, [], [], null)); + + var syntax = new OperationSyntax( + resultTokens: [new SyntaxToken("%result")], + resultCommaTokens: [], + equalsToken: new SyntaxToken("="), + nameToken: new SyntaxToken("miniarith.constant"), + body: body); + + var registry = new DialectRegistry(); + registry.RegisterDialect(MiniarithDialectRegistration.Create()); + + // Act: bind invokes MiniArith_ConstantOpAssemblyFormat.Bind because body is not GenericOperationBodySyntax + var module = Binder.BindModule(new ModuleSyntax([syntax]), registry); + + // Assert + Assert.Empty(module.AssemblyDiagnostics); + var operation = Assert.IsType(Assert.Single(module.Operations)); + Assert.Equal("miniarith.constant", operation.Name); + Assert.Equal("%result", operation.ResultValue.Name); + Assert.Equal("value", operation.Value.Name); + Assert.Null(operation.TypeSignatureReference); + } + + [Fact] + public void GeneratedAssemblyFormatBindReportsDiagnosticForWrongBodyType() + { + // Arrange: pass a MiniArith_ConstantOpBodySyntax to the addi operation (wrong type) + var body = new MiniArith_ConstantOpBodySyntax( + new RawAttributeValueSyntax(new RawSyntaxText("42")), + new DelimitedSyntaxList(null, [], [], null)); + + var syntax = new OperationSyntax( + resultTokens: [new SyntaxToken("%result")], + resultCommaTokens: [], + equalsToken: new SyntaxToken("="), + nameToken: new SyntaxToken("miniarith.addi"), + body: body); + + var registry = new DialectRegistry(); + registry.RegisterDialect(MiniarithDialectRegistration.Create()); + + // Act + var module = Binder.BindModule(new ModuleSyntax([syntax]), registry); + + // Assert: one diagnostic reported, operation is uninterpreted + Assert.Single(module.AssemblyDiagnostics); + Assert.IsType(Assert.Single(module.Operations)); + } + + [Fact] + public void GeneratedAssemblyFormatBindReportsDiagnosticForWrongResultCount() + { + // Arrange: zero result tokens when the operation expects exactly one + var body = new MiniArith_AddIOpBodySyntax( + new SyntaxToken("%lhs"), + new SyntaxToken(","), + new SyntaxToken("%rhs"), + new DelimitedSyntaxList(null, [], [], null), + new SyntaxToken(":"), + new RawTypeSyntax(new RawSyntaxText("i32"))); + + var syntax = new OperationSyntax( + resultTokens: [], + resultCommaTokens: [], + equalsToken: null, + nameToken: new SyntaxToken("miniarith.addi"), + body: body); + + var registry = new DialectRegistry(); + registry.RegisterDialect(MiniarithDialectRegistration.Create()); + + // Act + var module = Binder.BindModule(new ModuleSyntax([syntax]), registry); + + // Assert: one diagnostic reported, operation is uninterpreted + Assert.Single(module.AssemblyDiagnostics); + Assert.IsType(Assert.Single(module.Operations)); + } } diff --git a/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs b/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs index c396846..845e478 100644 --- a/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs +++ b/tests/MLIR.Generators.Tests/DialectGeneratorTests.cs @@ -101,6 +101,36 @@ public void GeneratesOperationBodySyntaxClassForDeclarativeAssemblyFormat() Assert.Contains("public override void WriteTo(Text.SyntaxWriter writer, int indentLevel)", registrationSource); } + [Fact] + public void GeneratedBindMethodUsesPatternMatchInsteadOfHardCastForBodyType() + { + const string source = + "class MiniArith_Op traits = []> :\n" + + " Op;\n" + + "\n" + + "def MiniArith_Dialect : Dialect {\n" + + " let name = \"miniarith\";\n" + + " let cppNamespace = \"::mlir::miniarith\";\n" + + "};\n" + + "\n" + + "def MiniArith_AddIOp : MiniArith_Op<\"addi\", [Pure, Commutative]> {\n" + + " let summary = \"integer addition\";\n" + + " let arguments = (ins I32:$lhs, I32:$rhs);\n" + + " let results = (outs I32:$result);\n" + + " let assemblyFormat = \"$lhs `,` $rhs attr-dict `:` type($result)\";\n" + + "};"; + + var generatedSources = GeneratorTestHelpers.RunGenerator( + new DialectGenerator(), + ("miniarith.td", source)); + var registrationSource = Assert.Single(generatedSources.Where(static result => result.HintName == "MiniarithDialectRegistration.g.cs")).SourceText.ToString(); + + // The Bind method must use a safe pattern-match rather than a hard cast so that + // a wrong body type yields a diagnostic instead of an InvalidCastException. + Assert.Contains("if (syntax.Body is not MiniArith_AddIOpBodySyntax body)", registrationSource); + Assert.DoesNotContain("(MiniArith_AddIOpBodySyntax)syntax.Body", registrationSource); + } + [Fact] public void BodySyntaxClassIsNotGeneratedForOperationsWithoutDeclarativeAssemblyFormat() {