Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions src/MLIR/Semantics/Binder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,47 @@ public static Module BindModule(ModuleSyntax syntax, DialectRegistry? dialectReg

private static Operation BindOperation(OperationSyntax syntax, DialectRegistry? dialectRegistry, List<AssemblyDiagnostic> diagnostics)
{
var genericBody = syntax.GenericBody;
var regions = new List<Region>();
foreach (var region in genericBody.Regions)
List<Region> regions;
List<NamedAttribute> attributes;
TypeReference? typeSignatureReference = null;
IReadOnlyList<ValueReference> operandValues;
IReadOnlyList<BlockReference> successorReferences;

if (syntax.TryGetGenericBody(out var genericBody))
{
regions.Add(BindRegion(region, dialectRegistry, diagnostics));
var regionList = new List<Region>();
foreach (var region in genericBody!.Regions)
{
regionList.Add(BindRegion(region, dialectRegistry, diagnostics));
}

regions = regionList;

var attributeList = new List<NamedAttribute>();
foreach (var attribute in genericBody.Attributes)
{
attributeList.Add(new NamedAttribute(attribute, BindAttributeValue(attribute.RawValue, attribute.NameToken, dialectRegistry, diagnostics)));
}

attributes = attributeList;

if (genericBody.RawTypeSignature != null)
{
var location = genericBody.TypeSignatureColonToken != null
? SourceLocation.FromToken(genericBody.TypeSignatureColonToken.Value)
: default;
typeSignatureReference = BindTypeReference(genericBody.RawTypeSignature, location, dialectRegistry, diagnostics);
}

operandValues = CreateValueReferences(genericBody.OperandList.Items);
successorReferences = CreateBlockReferences(genericBody.SuccessorList.Items);
}

var attributes = new List<NamedAttribute>();
foreach (var attribute in genericBody.Attributes)
else
{
attributes.Add(new NamedAttribute(attribute, BindAttributeValue(attribute.RawValue, attribute.NameToken, dialectRegistry, diagnostics)));
regions = new List<Region>();
attributes = new List<NamedAttribute>();
operandValues = new List<ValueReference>();
successorReferences = new List<BlockReference>();
}

var name = NormalizeOperationName(syntax.Name);
Expand All @@ -50,18 +80,7 @@ private static Operation BindOperation(OperationSyntax syntax, DialectRegistry?
dialectRegistry.TryGetOperation(name, out definition);
}

TypeReference? typeSignatureReference = null;
if (genericBody.RawTypeSignature != null)
{
var location = genericBody.TypeSignatureColonToken != null
? SourceLocation.FromToken(genericBody.TypeSignatureColonToken.Value)
: default;
typeSignatureReference = BindTypeReference(genericBody.RawTypeSignature, location, dialectRegistry, diagnostics);
}

var resultValues = CreateValueReferences(syntax.ResultTokens);
var operandValues = CreateValueReferences(genericBody.OperandList.Items);
var successorReferences = CreateBlockReferences(genericBody.SuccessorList.Items);
Operation operation;
if (definition != null)
{
Expand Down
3 changes: 2 additions & 1 deletion src/MLIR/Semantics/Operation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ public string DialectName
/// <summary>
/// Gets the raw type signature text, if present.
/// </summary>
public RawSyntaxText? TypeSignature => Syntax.RawTypeSignature;
public RawSyntaxText? TypeSignature =>
Syntax.TryGetGenericBody(out var genericBody) ? genericBody!.RawTypeSignature : null;

/// <summary>
/// Gets the source location of the operation name, if known.
Expand Down
58 changes: 0 additions & 58 deletions src/MLIR/Syntax/OperationSyntax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,54 +147,6 @@ public bool TryGetGenericBody(out GenericOperationBodySyntax? genericBody)
return Body.TryGetGenericBody(out genericBody);
}

/// <summary>
/// Gets the operation body as a generic MLIR body.
/// </summary>
public GenericOperationBodySyntax GenericBody => Body.GetGenericBody();

/// <summary>
/// Gets the delimited operand list.
/// </summary>
public DelimitedSyntaxList<SyntaxToken> OperandList => GenericBody.OperandList;

/// <summary>
/// Gets the delimited successor list.
/// </summary>
public DelimitedSyntaxList<SyntaxToken> SuccessorList => GenericBody.SuccessorList;

/// <summary>
/// Gets the regions nested under the operation.
/// </summary>
public IReadOnlyList<RegionSyntax> Regions => GenericBody.Regions;

/// <summary>
/// Gets the delimited attribute dictionary.
/// </summary>
public DelimitedSyntaxList<NamedAttributeSyntax> Attributes => GenericBody.Attributes;

/// <summary>
/// Gets the colon token that introduces the type signature, if present.
/// </summary>
public SyntaxToken? TypeSignatureColonToken => GenericBody.TypeSignatureColonToken;

/// <summary>
/// Gets the trailing type signature syntax, if present.
/// </summary>
public TypeSyntax? TypeSignatureSyntax => GenericBody.TypeSignatureSyntax;

/// <summary>
/// Attempts to get the trailing type signature as raw syntax text.
/// </summary>
public bool TryGetRawTypeSignature(out RawSyntaxText? rawTypeSignature)
{
return GenericBody.TryGetRawTypeSignature(out rawTypeSignature);
}

/// <summary>
/// Gets the trailing type signature as raw syntax text.
/// </summary>
public RawSyntaxText? RawTypeSignature => GenericBody.RawTypeSignature;

/// <summary>
/// Gets the SSA results produced by the operation.
/// </summary>
Expand All @@ -205,16 +157,6 @@ public bool TryGetRawTypeSignature(out RawSyntaxText? rawTypeSignature)
/// </summary>
public string Name => NameToken.Text;

/// <summary>
/// Gets the SSA operands passed to the operation.
/// </summary>
public IReadOnlyList<string> Operands => GetTexts(OperandList.Items);

/// <summary>
/// Gets the successor block labels referenced by the operation.
/// </summary>
public IReadOnlyList<string> Successors => GetTexts(SuccessorList.Items);

/// <summary>
/// Gets a value indicating whether the operation uses a custom assembly body.
/// </summary>
Expand Down
18 changes: 12 additions & 6 deletions src/MLIR/Transforms/AssemblySyntaxBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
namespace MLIR.Transforms;

using System;
using System.Collections.Generic;
using MLIR.Semantics;
using MLIR.Syntax;
Expand Down Expand Up @@ -54,20 +55,25 @@ public OperationSyntax RewriteOperation(Operation operation, OperationBodySyntax

public GenericOperationBodySyntax BuildGenericBody(Operation operation)
{
var genericBody = operation.Syntax.GenericBody;
if (!operation.Syntax.TryGetGenericBody(out var genericBody))
{
throw new InvalidOperationException(
"Cannot build a generic body for an operation that does not provide a generic body projection.");
}

var regions = new List<RegionSyntax>(operation.Regions.Count);
foreach (var region in operation.Regions)
{
regions.Add(BuildRegion(region));
}

return new GenericOperationBodySyntax(
genericBody.OperandList,
genericBody.SuccessorList,
genericBody!.OperandList,
genericBody!.SuccessorList,
regions,
genericBody.Attributes,
genericBody.TypeSignatureColonToken,
genericBody.TypeSignatureSyntax);
genericBody!.Attributes,
genericBody!.TypeSignatureColonToken,
genericBody!.TypeSignatureSyntax);
}

public RegionSyntax BuildRegion(Region region)
Expand Down
8 changes: 6 additions & 2 deletions src/MLIR/Transforms/GenericSyntaxBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ public ModuleSyntax BuildModule(ModuleSyntax module)

private OperationSyntax BuildOperation(OperationSyntax operation)
{
var genericBody = operation.GenericBody;
var regions = new List<RegionSyntax>(genericBody.Regions.Count);
if (!operation.TryGetGenericBody(out var genericBody))
{
return operation;
}

var regions = new List<RegionSyntax>(genericBody!.Regions.Count);
foreach (var region in genericBody.Regions)
{
regions.Add(BuildRegion(region));
Expand Down
14 changes: 8 additions & 6 deletions tests/MLIR.Tests/ConstructionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ public void ExposesPreservedTokenTriviaInTheCst()
Assert.Equal("// leading comment\n", operation.ResultTokens[0].LeadingTrivia);
Assert.Equal(",", operation.ResultCommaTokens[0].Text);
Assert.Equal(" ", operation.ResultTokens[1].LeadingTrivia);
Assert.Equal(" ", operation.SuccessorList.OpenToken!.Value.LeadingTrivia);
Assert.Equal(" ", operation.SuccessorList[0].LeadingTrivia);
Assert.Equal(" ", operation.SuccessorList.CloseToken!.Value.LeadingTrivia);
var body = (GenericOperationBodySyntax)operation.Body;
Assert.Equal(" ", body.SuccessorList.OpenToken!.Value.LeadingTrivia);
Assert.Equal(" ", body.SuccessorList[0].LeadingTrivia);
Assert.Equal(" ", body.SuccessorList.CloseToken!.Value.LeadingTrivia);
Assert.Equal("%0", operation.Results[0]);
Assert.Equal("\"test.op\"", operation.Name);
Assert.Equal("%lhs", operation.Operands[0]);
Assert.Equal("^bb1", operation.Successors[0]);
Assert.Equal("%lhs", body.OperandList.Items[0].Text);
Assert.Equal("^bb1", body.SuccessorList.Items[0].Text);
}

[Fact]
Expand Down Expand Up @@ -194,7 +195,8 @@ [new SyntaxToken("%0")],

Assert.Equal("%0 = arith.constant 0 : i32", text);
Assert.True(module.Operations[0].HasCustomAssemblyBody);
Assert.Equal("0", module.Operations[0].Attributes[0].RawValue.Text);
Assert.True(module.Operations[0].TryGetGenericBody(out var customOpBody));
Assert.Equal("0", customOpBody!.Attributes[0].RawValue.Text);
}

[Fact]
Expand Down
22 changes: 12 additions & 10 deletions tests/MLIR.Tests/ParsingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ public void ParsesSuccessorsAndRegions()
var module = Parser.ParseModule(source);

Assert.Single(module.Operations);
Assert.Single(module.Operations[0].Regions);
Assert.Equal("^bb2", module.Operations[0].Regions[0].Blocks[0].Operations[0].Successors[0]);
var op0Body = (GenericOperationBodySyntax)module.Operations[0].Body;
Assert.Single(op0Body.Regions);
Assert.Equal("^bb2", ((GenericOperationBodySyntax)op0Body.Regions[0].Blocks[0].Operations[0].Body).SuccessorList.Items[0].Text);
}

[Fact]
Expand All @@ -111,7 +112,7 @@ public void PreservesStructuredTypeSignatureText()

var module = Parser.ParseModule(source);

Assert.Equal("(memref<2x?xf32, #map>) -> memref<*xf32>", module.Operations[0].RawTypeSignature!.Text);
Assert.Equal("(memref<2x?xf32, #map>) -> memref<*xf32>", ((GenericOperationBodySyntax)module.Operations[0].Body).RawTypeSignature!.Text);
}

[Fact]
Expand Down Expand Up @@ -164,7 +165,7 @@ public void RoundTripsLargerInputWithMultipleBlocks()
var text = Printer.Print(module);

Assert.Equal(2, module.Operations.Count);
Assert.Equal(3, module.Operations[1].Regions[0].Blocks.Count);
Assert.Equal(3, ((GenericOperationBodySyntax)module.Operations[1].Body).Regions[0].Blocks.Count);
Assert.Equal(source, text);
}

Expand Down Expand Up @@ -211,9 +212,10 @@ public void ParsesEmptyAttributeDictionary()

var module = Parser.ParseModule(source);
var operation = module.Operations[0];
var opBody = (GenericOperationBodySyntax)operation.Body;

Assert.Empty(operation.Regions);
Assert.Empty(operation.Attributes);
Assert.Empty(opBody.Regions);
Assert.Empty(opBody.Attributes);
Assert.Equal(source, Printer.Print(module));
}

Expand All @@ -228,7 +230,7 @@ public void ParsesUnlabeledEntryBlockBeforeExplicitLabeledBlock()
"} : () -> ()";

var module = Parser.ParseModule(source);
var blocks = module.Operations[0].Regions[0].Blocks;
var blocks = ((GenericOperationBodySyntax)module.Operations[0].Body).Regions[0].Blocks;

Assert.Equal(2, blocks.Count);
Assert.Equal("^entry", blocks[0].Label);
Expand Down Expand Up @@ -267,7 +269,7 @@ public void PreservesRawSyntaxWithNestedDelimiters()
"\"test.op\"(%arg0) {layout = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : (tensor<2x2xi32>) -> tensor<2x2xi32>";

var module = Parser.ParseModule(source);
var attribute = module.Operations[0].Attributes[0];
var attribute = ((GenericOperationBodySyntax)module.Operations[0].Body).Attributes[0];

Assert.Equal("dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>", attribute.RawValue.Text);
Assert.Equal(source, Printer.Print(module));
Expand Down Expand Up @@ -318,8 +320,8 @@ public void RewritesNestedCustomAssemblySyntaxToGenericSyntaxRecursively()

var genericModule = GenericSyntaxBuilder.BuildModule(module);

Assert.True(module.Operations[0].Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody);
Assert.False(genericModule.Operations[0].Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody);
Assert.True(((GenericOperationBodySyntax)module.Operations[0].Body).Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody);
Assert.False(((GenericOperationBodySyntax)genericModule.Operations[0].Body).Regions[0].Blocks[0].Operations[0].HasCustomAssemblyBody);
Assert.Equal(
"\"scf.if\"(%cond) {\n" +
" %0 = arith.constant() {value = 0} : i32\n" +
Expand Down
11 changes: 6 additions & 5 deletions tests/MLIR.Tests/SemanticTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -739,11 +739,12 @@ public void ParserCanUseRegisteredCustomAssemblyFormats()
Assert.Single(module.Operations);
Assert.Equal("arith.constant", module.Operations[0].Name);
Assert.True(module.Operations[0].HasCustomAssemblyBody);
Assert.Empty(module.Operations[0].Operands);
Assert.Single(module.Operations[0].Attributes);
Assert.Equal("value", module.Operations[0].Attributes[0].Name);
Assert.Equal("0", module.Operations[0].Attributes[0].RawValue.Text);
Assert.Equal("i32", module.Operations[0].RawTypeSignature!.Text);
Assert.True(module.Operations[0].TryGetGenericBody(out var constantBody));
Assert.Empty(constantBody!.OperandList.Items);
Assert.Single(constantBody.Attributes);
Assert.Equal("value", constantBody.Attributes[0].Name);
Assert.Equal("0", constantBody.Attributes[0].RawValue.Text);
Assert.Equal("i32", constantBody.RawTypeSignature!.Text);
}

[Fact]
Expand Down
Loading