-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathExpressionSyntaxRewriter.EnumMethodExpansion.cs
More file actions
155 lines (132 loc) · 6.11 KB
/
ExpressionSyntaxRewriter.EnumMethodExpansion.cs
File metadata and controls
155 lines (132 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace EntityFrameworkCore.Projectables.Generator;
public partial class ExpressionSyntaxRewriter
{
    /// <summary>
    /// Attempts to replace a method call whose receiver is an enum (or nullable enum) with a chain of
    /// conditional expressions that compares the receiver against every constant member of the enum
    /// and invokes the method on the matching member.
    /// </summary>
    /// <remarks>
    /// <para>
    /// For a receiver <c>r</c> of enum type <c>E</c> whose constant members are <c>A</c> and <c>B</c>
    /// (in the order returned by <c>GetMembers()</c>), the produced expression has the shape
    /// <c>(r == E.A ? C.M(E.A, args) : r == E.B ? C.M(E.B, args) : fallback)</c>, where <c>C.M</c> is the
    /// static form of the method built by <see cref="CreateMethodCallOnEnumValue"/>.
    /// </para>
    /// <para>
    /// <c>fallback</c> is the <c>null</c> literal when the method's return type is a reference type,
    /// is annotated as nullable, or is <see cref="System.Nullable{T}"/>; otherwise it is
    /// <c>default(ReturnType)</c>. When the receiver is a nullable enum, the chain is additionally wrapped
    /// as <c>r == null ? fallback : chain</c>.
    /// </para>
    /// </remarks>
    /// <param name="node">The invocation being expanded; its argument list is forwarded to every generated call.</param>
    /// <param name="memberAccess">The member access of the invocation; its <c>Expression</c> is treated as the receiver.</param>
    /// <param name="methodSymbol">The symbol of the invoked method. If it is a reduced extension method, the method it was reduced from is the one invoked in the expansion.</param>
    /// <param name="expandedExpression">On success, the parenthesized conditional chain; otherwise <c>null</c>.</param>
    /// <returns>
    /// <c>true</c> when the call was expanded; <c>false</c> when the receiver's type is neither an enum nor a
    /// nullable enum, or when the enum exposes no fields with a constant value.
    /// </returns>
    private bool TryExpandEnumMethodCall(InvocationExpressionSyntax node, MemberAccessExpressionSyntax memberAccess, IMethodSymbol methodSymbol, out ExpressionSyntax? expandedExpression)
    {
        expandedExpression = null;

        // The receiver is the expression the method is invoked on (the enum instance or variable).
        var receiverExpression = memberAccess.Expression;
        var receiverType = _semanticModel.GetTypeInfo(receiverExpression).Type;

        // Resolve the underlying enum type, unwrapping Nullable<TEnum> when present.
        // Nullable<T> is recognised through its special type (the same way the return type is checked
        // further down) rather than by name, so an unrelated generic type that merely happens to be
        // called "Nullable" is not mistaken for it.
        ITypeSymbol enumType;
        bool isNullable;
        if (receiverType is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T } nullableType &&
            nullableType.TypeArguments.Length == 1 &&
            nullableType.TypeArguments[0].TypeKind == TypeKind.Enum)
        {
            enumType = nullableType.TypeArguments[0];
            isNullable = true;
        }
        else if (receiverType is { TypeKind: TypeKind.Enum })
        {
            enumType = receiverType;
            isNullable = false;
        }
        else
        {
            // Not an enum type.
            return false;
        }

        // Only fields carrying a constant value are considered enum members.
        var enumMembers = enumType.GetMembers()
            .OfType<IFieldSymbol>()
            .Where(f => f.HasConstantValue)
            .ToList();

        if (enumMembers.Count == 0)
        {
            return false;
        }

        // Visit the receiver expression so it is transformed like any other expression (e.g. @this.MyProperty).
        var visitedReceiver = (ExpressionSyntax)Visit(receiverExpression);

        // For a reduced extension method, call the original static definition instead.
        var originalMethod = methodSymbol.ReducedFrom ?? methodSymbol;

        // The return type decides which fallback value terminates the conditional chain.
        var returnType = methodSymbol.ReturnType;

        ExpressionSyntax defaultExpression;
        if (returnType.IsReferenceType ||
            returnType.NullableAnnotation == NullableAnnotation.Annotated ||
            returnType is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T })
        {
            defaultExpression = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
        }
        else
        {
            // Use default(T) for value types.
            defaultExpression = SyntaxFactory.DefaultExpression(
                SyntaxFactory.ParseTypeName(returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)));
        }

        // The fully qualified enum type syntax is identical for every member, so parse it only once.
        var enumTypeSyntax = SyntaxFactory.ParseTypeName(enumType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));

        // Build the chain from the innermost conditional outwards: walking the members back to front
        // leaves the first member as the outermost (first evaluated) condition.
        var currentExpression = defaultExpression;
        for (var index = enumMembers.Count - 1; index >= 0; index--)
        {
            // EnumType.Value
            var enumValueAccess = SyntaxFactory.MemberAccessExpression(
                SyntaxKind.SimpleMemberAccessExpression,
                enumTypeSyntax,
                SyntaxFactory.IdentifierName(enumMembers[index].Name));

            // ContainingType.Method(EnumType.Value, ...)
            var methodCall = CreateMethodCallOnEnumValue(originalMethod, enumValueAccess, node.ArgumentList);

            // receiver == EnumType.Value
            var condition = SyntaxFactory.BinaryExpression(
                SyntaxKind.EqualsExpression,
                visitedReceiver,
                enumValueAccess);

            // condition ? methodCall : previousExpression
            currentExpression = SyntaxFactory.ConditionalExpression(
                condition,
                methodCall,
                currentExpression);
        }

        // A nullable receiver gets an explicit null check in front of the chain.
        if (isNullable)
        {
            var nullCheck = SyntaxFactory.BinaryExpression(
                SyntaxKind.EqualsExpression,
                visitedReceiver,
                SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression));

            currentExpression = SyntaxFactory.ConditionalExpression(
                nullCheck,
                defaultExpression,
                currentExpression);
        }

        // Parenthesize so the conditional chain stays a single operand wherever it is substituted.
        expandedExpression = SyntaxFactory.ParenthesizedExpression(currentExpression);
        return true;
    }

    /// <summary>
    /// Builds the invocation <c>ContainingType.MethodName(enumValue, originalArgs...)</c> for a single enum value.
    /// </summary>
    /// <param name="methodSymbol">The method to invoke; its containing type is emitted fully qualified.</param>
    /// <param name="enumValueExpression">The expression passed as the first argument (the enum value).</param>
    /// <param name="originalArguments">The argument list of the original call; each argument is visited by this rewriter and appended after the enum value.</param>
    /// <returns>The invocation expression calling the method in its static form.</returns>
    private ExpressionSyntax CreateMethodCallOnEnumValue(IMethodSymbol methodSymbol, ExpressionSyntax enumValueExpression, ArgumentListSyntax originalArguments)
    {
        // Fully qualified name of the type that declares the method.
        var containingTypeName = methodSymbol.ContainingType.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat);

        // ContainingType.MethodName
        var methodAccess = SyntaxFactory.MemberAccessExpression(
            SyntaxKind.SimpleMemberAccessExpression,
            SyntaxFactory.ParseName(containingTypeName),
            SyntaxFactory.IdentifierName(methodSymbol.Name));

        // The enum value goes first (the "this" parameter of an extension method),
        // followed by the visited arguments of the original call.
        var arguments = SyntaxFactory.SeparatedList<ArgumentSyntax>();
        arguments = arguments.Add(SyntaxFactory.Argument(enumValueExpression));

        foreach (var arg in originalArguments.Arguments)
        {
            arguments = arguments.Add((ArgumentSyntax)Visit(arg));
        }

        return SyntaxFactory.InvocationExpression(
            methodAccess,
            SyntaxFactory.ArgumentList(arguments));
    }
}