diff --git a/src/MLIR/Semantics/Binder.cs b/src/MLIR/Semantics/Binder.cs index f09fffb..adf02ab 100644 --- a/src/MLIR/Semantics/Binder.cs +++ b/src/MLIR/Semantics/Binder.cs @@ -30,17 +30,47 @@ public static Module BindModule(ModuleSyntax syntax, DialectRegistry? dialectReg private static Operation BindOperation(OperationSyntax syntax, DialectRegistry? dialectRegistry, List diagnostics) { - var genericBody = syntax.GenericBody; - var regions = new List(); - foreach (var region in genericBody.Regions) + List regions; + List attributes; + TypeReference? typeSignatureReference = null; + IReadOnlyList operandValues; + IReadOnlyList successorReferences; + + if (syntax.TryGetGenericBody(out var genericBody)) { - regions.Add(BindRegion(region, dialectRegistry, diagnostics)); + var regionList = new List(); + foreach (var region in genericBody!.Regions) + { + regionList.Add(BindRegion(region, dialectRegistry, diagnostics)); + } + + regions = regionList; + + var attributeList = new List(); + foreach (var attribute in genericBody.Attributes) + { + attributeList.Add(new NamedAttribute(attribute, BindAttributeValue(attribute.RawValue, attribute.NameToken, dialectRegistry, diagnostics))); + } + + attributes = attributeList; + + if (genericBody.RawTypeSignature != null) + { + var location = genericBody.TypeSignatureColonToken != null + ? SourceLocation.FromToken(genericBody.TypeSignatureColonToken.Value) + : default; + typeSignatureReference = BindTypeReference(genericBody.RawTypeSignature, location, dialectRegistry, diagnostics); + } + + operandValues = CreateValueReferences(genericBody.OperandList.Items); + successorReferences = CreateBlockReferences(genericBody.SuccessorList.Items); } - - var attributes = new List(); - foreach (var attribute in genericBody.Attributes) + else { - attributes.Add(new NamedAttribute(attribute, BindAttributeValue(attribute.RawValue, attribute.NameToken, dialectRegistry, diagnostics))); + regions = new List(); + attributes = new List(); + operandValues = new List(); + successorReferences = new List(); } var name = NormalizeOperationName(syntax.Name); @@ -50,18 +80,7 @@ private static Operation BindOperation(OperationSyntax syntax, DialectRegistry? dialectRegistry.TryGetOperation(name, out definition); } - TypeReference? typeSignatureReference = null; - if (genericBody.RawTypeSignature != null) - { - var location = genericBody.TypeSignatureColonToken != null - ? SourceLocation.FromToken(genericBody.TypeSignatureColonToken.Value) - : default; - typeSignatureReference = BindTypeReference(genericBody.RawTypeSignature, location, dialectRegistry, diagnostics); - } - var resultValues = CreateValueReferences(syntax.ResultTokens); - var operandValues = CreateValueReferences(genericBody.OperandList.Items); - var successorReferences = CreateBlockReferences(genericBody.SuccessorList.Items); Operation operation; if (definition != null) { diff --git a/src/MLIR/Semantics/Operation.cs b/src/MLIR/Semantics/Operation.cs index 1a745a9..e8d57c7 100644 --- a/src/MLIR/Semantics/Operation.cs +++ b/src/MLIR/Semantics/Operation.cs @@ -108,7 +108,8 @@ public string DialectName /// /// Gets the raw type signature text, if present. /// - public RawSyntaxText? TypeSignature => Syntax.RawTypeSignature; + public RawSyntaxText? TypeSignature => + Syntax.TryGetGenericBody(out var genericBody) ? genericBody!.RawTypeSignature : null; /// /// Gets the source location of the operation name, if known. diff --git a/src/MLIR/Syntax/OperationSyntax.cs b/src/MLIR/Syntax/OperationSyntax.cs index dde5bf0..e1dd597 100644 --- a/src/MLIR/Syntax/OperationSyntax.cs +++ b/src/MLIR/Syntax/OperationSyntax.cs @@ -147,54 +147,6 @@ public bool TryGetGenericBody(out GenericOperationBodySyntax? genericBody) return Body.TryGetGenericBody(out genericBody); } - /// - /// Gets the operation body as a generic MLIR body. - /// - public GenericOperationBodySyntax GenericBody => Body.GetGenericBody(); - - /// - /// Gets the delimited operand list. - /// - public DelimitedSyntaxList OperandList => GenericBody.OperandList; - - /// - /// Gets the delimited successor list. - /// - public DelimitedSyntaxList SuccessorList => GenericBody.SuccessorList; - - /// - /// Gets the regions nested under the operation. - /// - public IReadOnlyList Regions => GenericBody.Regions; - - /// - /// Gets the delimited attribute dictionary. - /// - public DelimitedSyntaxList Attributes => GenericBody.Attributes; - - /// - /// Gets the colon token that introduces the type signature, if present. - /// - public SyntaxToken? TypeSignatureColonToken => GenericBody.TypeSignatureColonToken; - - /// - /// Gets the trailing type signature syntax, if present. - /// - public TypeSyntax? TypeSignatureSyntax => GenericBody.TypeSignatureSyntax; - - /// - /// Attempts to get the trailing type signature as raw syntax text. - /// - public bool TryGetRawTypeSignature(out RawSyntaxText? rawTypeSignature) - { - return GenericBody.TryGetRawTypeSignature(out rawTypeSignature); - } - - /// - /// Gets the trailing type signature as raw syntax text. - /// - public RawSyntaxText? RawTypeSignature => GenericBody.RawTypeSignature; - /// /// Gets the SSA results produced by the operation. /// @@ -205,16 +157,6 @@ public bool TryGetRawTypeSignature(out RawSyntaxText? rawTypeSignature) /// public string Name => NameToken.Text; - /// - /// Gets the SSA operands passed to the operation. - /// - public IReadOnlyList Operands => GetTexts(OperandList.Items); - - /// - /// Gets the successor block labels referenced by the operation. - /// - public IReadOnlyList Successors => GetTexts(SuccessorList.Items); - /// /// Gets a value indicating whether the operation uses a custom assembly body. /// diff --git a/src/MLIR/Transforms/AssemblySyntaxBuilder.cs b/src/MLIR/Transforms/AssemblySyntaxBuilder.cs index 3eff115..98f5926 100644 --- a/src/MLIR/Transforms/AssemblySyntaxBuilder.cs +++ b/src/MLIR/Transforms/AssemblySyntaxBuilder.cs @@ -1,5 +1,6 @@ namespace MLIR.Transforms; +using System; using System.Collections.Generic; using MLIR.Semantics; using MLIR.Syntax; @@ -54,7 +55,12 @@ public OperationSyntax RewriteOperation(Operation operation, OperationBodySyntax public GenericOperationBodySyntax BuildGenericBody(Operation operation) { - var genericBody = operation.Syntax.GenericBody; + if (!operation.Syntax.TryGetGenericBody(out var genericBody)) + { + throw new InvalidOperationException( + "Cannot build a generic body for an operation that does not provide a generic body projection."); + } + var regions = new List(operation.Regions.Count); foreach (var region in operation.Regions) { @@ -62,12 +68,12 @@ public GenericOperationBodySyntax BuildGenericBody(Operation operation) } return new GenericOperationBodySyntax( - genericBody.OperandList, - genericBody.SuccessorList, + genericBody!.OperandList, + genericBody!.SuccessorList, regions, - genericBody.Attributes, - genericBody.TypeSignatureColonToken, - genericBody.TypeSignatureSyntax); + genericBody!.Attributes, + genericBody!.TypeSignatureColonToken, + genericBody!.TypeSignatureSyntax); } public RegionSyntax BuildRegion(Region region) diff --git a/src/MLIR/Transforms/GenericSyntaxBuilder.cs b/src/MLIR/Transforms/GenericSyntaxBuilder.cs index 261acd6..21c6dfc 100644 --- a/src/MLIR/Transforms/GenericSyntaxBuilder.cs +++ b/src/MLIR/Transforms/GenericSyntaxBuilder.cs @@ -34,8 +34,12 @@ public ModuleSyntax BuildModule(ModuleSyntax module) private OperationSyntax BuildOperation(OperationSyntax operation) { - var genericBody = operation.GenericBody; - var regions = new List(genericBody.Regions.Count); + if (!operation.TryGetGenericBody(out var genericBody)) + { + return operation; + } + + var regions = new List(genericBody!.Regions.Count); foreach (var region in genericBody.Regions) { regions.Add(BuildRegion(region)); diff --git a/tests/MLIR.Tests/ConstructionTests.cs b/tests/MLIR.Tests/ConstructionTests.cs index 0b3d254..081bbee 100644 --- a/tests/MLIR.Tests/ConstructionTests.cs +++ b/tests/MLIR.Tests/ConstructionTests.cs @@ -61,13 +61,14 @@ public void ExposesPreservedTokenTriviaInTheCst() Assert.Equal("// leading comment\n", operation.ResultTokens[0].LeadingTrivia); Assert.Equal(",", operation.ResultCommaTokens[0].Text); Assert.Equal(" ", operation.ResultTokens[1].LeadingTrivia); - Assert.Equal(" ", operation.SuccessorList.OpenToken!.Value.LeadingTrivia); - Assert.Equal(" ", operation.SuccessorList[0].LeadingTrivia); - Assert.Equal(" ", operation.SuccessorList.CloseToken!.Value.LeadingTrivia); + var body = (GenericOperationBodySyntax)operation.Body; + Assert.Equal(" ", body.SuccessorList.OpenToken!.Value.LeadingTrivia); + Assert.Equal(" ", body.SuccessorList[0].LeadingTrivia); + Assert.Equal(" ", body.SuccessorList.CloseToken!.Value.LeadingTrivia); Assert.Equal("%0", operation.Results[0]); Assert.Equal("\"test.op\"", operation.Name); - Assert.Equal("%lhs", operation.Operands[0]); - Assert.Equal("^bb1", operation.Successors[0]); + Assert.Equal("%lhs", body.OperandList.Items[0].Text); + Assert.Equal("^bb1", body.SuccessorList.Items[0].Text); } [Fact] @@ -194,7 +195,8 @@ [new SyntaxToken("%0")], Assert.Equal("%0 = arith.constant 0 : i32", text); Assert.True(module.Operations[0].HasCustomAssemblyBody); - Assert.Equal("0", module.Operations[0].Attributes[0].RawValue.Text); + Assert.True(module.Operations[0].TryGetGenericBody(out var customOpBody)); + Assert.Equal("0", customOpBody!.Attributes[0].RawValue.Text); } [Fact] diff --git a/tests/MLIR.Tests/ParsingTests.cs b/tests/MLIR.Tests/ParsingTests.cs index 2e2f3a6..bb8b7bc 100644 --- a/tests/MLIR.Tests/ParsingTests.cs +++ b/tests/MLIR.Tests/ParsingTests.cs @@ -100,8 +100,9 @@ public void ParsesSuccessorsAndRegions() var module = Parser.ParseModule(source); Assert.Single(module.Operations); - Assert.Single(module.Operations[0].Regions); - Assert.Equal("^bb2", module.Operations[0].Regions[0].Blocks[0].Operations[0].Successors[0]); + var op0Body = (GenericOperationBodySyntax)module.Operations[0].Body; + Assert.Single(op0Body.Regions); + Assert.Equal("^bb2", ((GenericOperationBodySyntax)op0Body.Regions[0].Blocks[0].Operations[0].Body).SuccessorList.Items[0].Text); } [Fact] @@ -111,7 +112,7 @@ public void PreservesStructuredTypeSignatureText() var module = Parser.ParseModule(source); - Assert.Equal("(memref<2x?xf32, #map>) -> memref<*xf32>", module.Operations[0].RawTypeSignature!.Text); + Assert.Equal("(memref<2x?xf32, #map>) -> memref<*xf32>", ((GenericOperationBodySyntax)module.Operations[0].Body).RawTypeSignature!.Text); } [Fact] @@ -164,7 +165,7 @@ public void RoundTripsLargerInputWithMultipleBlocks() var text = Printer.Print(module); Assert.Equal(2, module.Operations.Count); - Assert.Equal(3, module.Operations[1].Regions[0].Blocks.Count); + Assert.Equal(3, ((GenericOperationBodySyntax)module.Operations[1].Body).Regions[0].Blocks.Count); Assert.Equal(source, text); } @@ -211,9 +212,10 @@ public void ParsesEmptyAttributeDictionary() var module = Parser.ParseModule(source); var operation = module.Operations[0]; + var opBody = (GenericOperationBodySyntax)operation.Body; - Assert.Empty(operation.Regions); - Assert.Empty(operation.Attributes); + Assert.Empty(opBody.Regions); + Assert.Empty(opBody.Attributes); Assert.Equal(source, Printer.Print(module)); } @@ -228,7 +230,7 @@ public void ParsesUnlabeledEntryBlockBeforeExplicitLabeledBlock() "} : () -> ()"; var module = Parser.ParseModule(source); - var blocks = module.Operations[0].Regions[0].Blocks; + var blocks = ((GenericOperationBodySyntax)module.Operations[0].Body).Regions[0].Blocks; Assert.Equal(2, blocks.Count); Assert.Equal("^entry", blocks[0].Label); @@ -267,7 +269,7 @@ public void PreservesRawSyntaxWithNestedDelimiters() "\"test.op\"(%arg0) {layout = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : (tensor<2x2xi32>) -> tensor<2x2xi32>"; var module = Parser.ParseModule(source); - var attribute = module.Operations[0].Attributes[0]; + var attribute = ((GenericOperationBodySyntax)module.Operations[0].Body).Attributes[0]; Assert.Equal("dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>", attribute.RawValue.Text); Assert.Equal(source, Printer.Print(module)); @@ -318,8 +320,8 @@ public void RewritesNestedCustomAssemblySyntaxToGenericSyntaxRecursively() var genericModule = GenericSyntaxBuilder.BuildModule(module); - Assert.True(module.Operations[0].Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody); - Assert.False(genericModule.Operations[0].Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody); + Assert.True(((GenericOperationBodySyntax)module.Operations[0].Body).Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody); + Assert.False(((GenericOperationBodySyntax)genericModule.Operations[0].Body).Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody); Assert.Equal( "\"scf.if\"(%cond) {\n" + " %0 = arith.constant() {value = 0} : i32\n" + diff --git a/tests/MLIR.Tests/SemanticTests.cs b/tests/MLIR.Tests/SemanticTests.cs index 44805e8..bff3349 100644 --- a/tests/MLIR.Tests/SemanticTests.cs +++ b/tests/MLIR.Tests/SemanticTests.cs @@ -739,11 +739,12 @@ public void ParserCanUseRegisteredCustomAssemblyFormats() Assert.Single(module.Operations); Assert.Equal("arith.constant", module.Operations[0].Name); Assert.True(module.Operations[0].HasCustomAssemblyBody); - Assert.Empty(module.Operations[0].Operands); - Assert.Single(module.Operations[0].Attributes); - Assert.Equal("value", module.Operations[0].Attributes[0].Name); - Assert.Equal("0", module.Operations[0].Attributes[0].RawValue.Text); - Assert.Equal("i32", module.Operations[0].RawTypeSignature!.Text); + Assert.True(module.Operations[0].TryGetGenericBody(out var constantBody)); + Assert.Empty(constantBody!.OperandList.Items); + Assert.Single(constantBody.Attributes); + Assert.Equal("value", constantBody.Attributes[0].Name); + Assert.Equal("0", constantBody.Attributes[0].RawValue.Text); + Assert.Equal("i32", constantBody.RawTypeSignature!.Text); } [Fact]