From 7d1aa00ebaf72f7a5be229270314f12a9d6e382c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:24:03 +0000 Subject: [PATCH 01/12] Initial plan From ed52f7704ae5d5748cdb257754a2c510c14905df Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:55:04 +0000 Subject: [PATCH 02/12] Implement AOT-compatible static projection registry Co-authored-by: PhenX <42170+PhenX@users.noreply.github.com> --- .../ExpressionResolverBenchmark.cs | 40 +++ .../Helpers/TestEntity.cs | 3 + .../PlainOverhead.cs | 1 + .../ProjectableExtensionMethods.cs | 1 + .../ProjectableMethods.cs | 1 + .../ProjectableProperties.cs | 1 + .../ResolverOverhead.cs | 72 +++++ .../IsExternalInit.cs | 6 + .../ProjectableRegistryEntry.cs | 21 ++ .../ProjectionExpressionGenerator.cs | 254 ++++++++++++++++++ .../Services/ProjectableExpressionReplacer.cs | 24 +- .../Services/ProjectionExpressionResolver.cs | 37 ++- .../ProjectionExpressionGeneratorTestsBase.cs | 51 +++- 13 files changed, 491 insertions(+), 21 deletions(-) create mode 100644 benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs create mode 100644 benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs create mode 100644 src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs create mode 100644 src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs new file mode 100644 index 0000000..e2803de --- /dev/null +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs @@ -0,0 +1,40 @@ +using System.Linq.Expressions; +using System.Reflection; +using BenchmarkDotNet.Attributes; +using EntityFrameworkCore.Projectables.Benchmarks.Helpers; +using EntityFrameworkCore.Projectables.Services; + +namespace EntityFrameworkCore.Projectables.Benchmarks +{ + /// + /// Micro-benchmarks in + /// isolation (no EF Core overhead) to directly compare the registry lookup path against + /// the previous per-call reflection chain. + /// + [MemoryDiagnoser] + public class ExpressionResolverBenchmark + { + private static readonly MemberInfo _propertyMember = + typeof(TestEntity).GetProperty(nameof(TestEntity.IdPlus1))!; + + private static readonly MemberInfo _methodMember = + typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlus1Method))!; + + private static readonly MemberInfo _methodWithParamMember = + typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlusDelta), new[] { typeof(int) })!; + + private readonly ProjectionExpressionResolver _resolver = new(); + + [Benchmark(Baseline = true)] + public LambdaExpression? ResolveProperty() + => _resolver.FindGeneratedExpression(_propertyMember); + + [Benchmark] + public LambdaExpression? ResolveMethod() + => _resolver.FindGeneratedExpression(_methodMember); + + [Benchmark] + public LambdaExpression? ResolveMethodWithParam() + => _resolver.FindGeneratedExpression(_methodWithParamMember); + } +} diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs index 4bc741c..68a04d8 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs @@ -15,5 +15,8 @@ public class TestEntity [Projectable] public int IdPlus1Method() => Id + 1; + + [Projectable] + public int IdPlusDelta(int delta) => Id + delta; } } diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs index d064b7d..b9cda85 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs @@ -9,6 +9,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class PlainOverhead { [Benchmark(Baseline = true)] diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs index fdb0f9f..467d470 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs @@ -9,6 +9,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableExtensionMethods { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs index 785b52c..0618627 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs @@ -8,6 +8,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableMethods { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs index 6f2fa57..3a2e2be 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs @@ -8,6 +8,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableProperties { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs new file mode 100644 index 0000000..3876641 --- /dev/null +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs @@ -0,0 +1,72 @@ +using System.Linq; +using BenchmarkDotNet.Attributes; +using EntityFrameworkCore.Projectables.Benchmarks.Helpers; +using Microsoft.EntityFrameworkCore; + +namespace EntityFrameworkCore.Projectables.Benchmarks +{ + /// + /// Measures the per-DbContext cold-start cost of resolver lookup by creating a new + /// on every iteration. The previous benchmarks reuse a single + /// DbContext for 10 000 iterations, so the resolver cache is warm after the first query — + /// these benchmarks expose the cost of the very first query per context. + /// + [MemoryDiagnoser] + public class ResolverOverhead + { + const int Iterations = 1000; + + /// Baseline: no projectables, new DbContext per query. + [Benchmark(Baseline = true)] + public void WithoutProjectables_FreshDbContext() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(false); + dbContext.Entities.Select(x => x.Id + 1).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable property. + /// After the registry is in place this should approach baseline overhead. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_Property() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlus1).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable method. + /// After the registry is in place this should approach baseline overhead. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_Method() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlus1Method()).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable method that takes a parameter, + /// exercising parameter-type disambiguation in the registry key. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_MethodWithParam() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlusDelta(5)).ToQueryString(); + } + } + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs new file mode 100644 index 0000000..bd4930e --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs @@ -0,0 +1,6 @@ +// Polyfill for C# 9 record types when targeting netstandard2.0 or netstandard2.1 +// The compiler requires this type to exist in order to use init-only setters (used by records). +namespace System.Runtime.CompilerServices +{ + internal sealed class IsExternalInit { } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs new file mode 100644 index 0000000..4e27f8a --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs @@ -0,0 +1,21 @@ +using System.Collections.Immutable; + +namespace EntityFrameworkCore.Projectables.Generator +{ + /// + /// Incremental-pipeline-safe representation of a single projectable member. + /// Contains only primitive types and ImmutableArray<string> so that value equality + /// works correctly across incremental generation steps. + /// + internal sealed record ProjectableRegistryEntry( + string DeclaringTypeFullName, + string MemberKind, + string MemberLookupName, + string GeneratedClassFullName, + bool IsGenericClass, + int ClassTypeParamCount, + bool IsGenericMethod, + int MethodTypeParamCount, + ImmutableArray ParameterTypeNames + ); +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 440caa5..8f0c9bd 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -3,6 +3,9 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; using System.Text; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -47,6 +50,18 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate the source using the compilation and enums context.RegisterImplementationSourceOutput(compilationAndMemberPairs, static (spc, source) => Execute(source.Item1, source.Item2, spc)); + + // Build the projection registry: collect all entries and emit a single registry file + IncrementalValuesProvider registryEntries = + compilationAndMemberPairs.Select( + static (pair, _) => ExtractRegistryEntry(pair.Item1, pair.Item2)); + + IncrementalValueProvider> allEntries = + registryEntries.Collect(); + + context.RegisterImplementationSourceOutput( + allEntries, + static (spc, entries) => EmitRegistry(entries, spc)); } static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) @@ -226,5 +241,244 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip return lambdaTypeArguments; } } + +#nullable restore + + /// + /// Extracts a from a member declaration. + /// Returns null when the member does not have [Projectable], is an extension member, + /// or cannot be represented in the registry (e.g. a generic class member or generic method). + /// + static ProjectableRegistryEntry? ExtractRegistryEntry(MemberDeclarationSyntax member, Compilation compilation) + { + var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); + var memberSymbol = semanticModel.GetDeclaredSymbol(member); + + if (memberSymbol is null) + return null; + + // Verify [Projectable] attribute + var projectableAttributeTypeSymbol = compilation.GetTypeByMetadataName("EntityFrameworkCore.Projectables.ProjectableAttribute"); + var projectableAttribute = memberSymbol.GetAttributes() + .FirstOrDefault(x => x.AttributeClass?.Name == "ProjectableAttribute"); + + if (projectableAttribute is null || + !SymbolEqualityComparer.Default.Equals(projectableAttribute.AttributeClass, projectableAttributeTypeSymbol)) + return null; + + // Skip C# 14 extension type members — they require special handling (fall back to reflection) + if (memberSymbol.ContainingType is { IsExtension: true }) + return null; + + var containingType = memberSymbol.ContainingType; + bool isGenericClass = containingType.TypeParameters.Length > 0; + + // Determine member kind and lookup name + string memberKind; + string memberLookupName; + ImmutableArray parameterTypeNames = ImmutableArray.Empty; + int methodTypeParamCount = 0; + bool isGenericMethod = false; + + if (memberSymbol is IMethodSymbol methodSymbol) + { + isGenericMethod = methodSymbol.TypeParameters.Length > 0; + methodTypeParamCount = methodSymbol.TypeParameters.Length; + + if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor) + { + memberKind = "Constructor"; + memberLookupName = "_ctor"; + } + else + { + memberKind = "Method"; + memberLookupName = memberSymbol.Name; + } + + parameterTypeNames = methodSymbol.Parameters + .Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + .ToImmutableArray(); + } + else + { + memberKind = "Property"; + memberLookupName = memberSymbol.Name; + } + + // Build the generated class name using the same logic as Execute + string? classNamespace = containingType.ContainingNamespace.IsGlobalNamespace + ? null + : containingType.ContainingNamespace.ToDisplayString(); + + var nestedTypePath = GetRegistryNestedTypePath(containingType); + + var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName( + classNamespace, + nestedTypePath, + memberLookupName, + parameterTypeNames.IsEmpty ? null : (IEnumerable)parameterTypeNames); + + var generatedClassFullName = "EntityFrameworkCore.Projectables.Generated." + generatedClassName; + + var declaringTypeFullName = containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + return new ProjectableRegistryEntry( + DeclaringTypeFullName: declaringTypeFullName, + MemberKind: memberKind, + MemberLookupName: memberLookupName, + GeneratedClassFullName: generatedClassFullName, + IsGenericClass: isGenericClass, + ClassTypeParamCount: containingType.TypeParameters.Length, + IsGenericMethod: isGenericMethod, + MethodTypeParamCount: methodTypeParamCount, + ParameterTypeNames: parameterTypeNames); + } + + static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol) + { + if (typeSymbol.ContainingType is not null) + { + foreach (var name in GetRegistryNestedTypePath(typeSymbol.ContainingType)) + yield return name; + } + yield return typeSymbol.Name; + } + + /// + /// Emits the ProjectionRegistry.g.cs file that aggregates all projectable members + /// into a single static dictionary keyed by . + /// + static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) + { + var validEntries = entries + .Where(e => e is not null) + .Select(e => e!) + .ToList(); + + if (validEntries.Count == 0) + return; + + var sb = new StringBuilder(); + sb.AppendLine("// "); + sb.AppendLine("#nullable disable"); + sb.AppendLine("using System;"); + sb.AppendLine("using System.Collections.Generic;"); + sb.AppendLine("using System.Linq.Expressions;"); + sb.AppendLine("using System.Reflection;"); + sb.AppendLine(); + sb.AppendLine("namespace EntityFrameworkCore.Projectables.Generated"); + sb.AppendLine("{"); + sb.AppendLine(" [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]"); + sb.AppendLine(" internal static class ProjectionRegistry"); + sb.AppendLine(" {"); + sb.AppendLine(" // Keyed by RuntimeMethodHandle.Value (a stable nint pointer for the method/getter/ctor)."); + sb.AppendLine(" // Populated once at type initialization; shared across the entire AppDomain lifetime."); + sb.AppendLine(" private static readonly Dictionary _map = Build();"); + sb.AppendLine(); + sb.AppendLine(" /// "); + sb.AppendLine(" /// Returns the pre-built LambdaExpression for the given [Projectable] member,"); + sb.AppendLine(" /// or null if the member is not registered (e.g. open-generic members)."); + sb.AppendLine(" /// "); + sb.AppendLine(" public static LambdaExpression TryGet(MemberInfo member)"); + sb.AppendLine(" {"); + sb.AppendLine(" var handle = GetHandle(member);"); + sb.AppendLine(" return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" private static nint? GetHandle(MemberInfo member) => member switch"); + sb.AppendLine(" {"); + sb.AppendLine(" MethodInfo m => m.MethodHandle.Value,"); + sb.AppendLine(" PropertyInfo p => p.GetMethod?.MethodHandle.Value,"); + sb.AppendLine(" ConstructorInfo c => c.MethodHandle.Value,"); + sb.AppendLine(" _ => null"); + sb.AppendLine(" };"); + sb.AppendLine(); + sb.AppendLine(" private static Dictionary Build()"); + sb.AppendLine(" {"); + sb.AppendLine(" var map = new Dictionary();"); + + foreach (var entry in validEntries) + { + EmitRegistryEntry(sb, entry); + } + + sb.AppendLine(" return map;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine("}"); + + context.AddSource("ProjectionRegistry.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + } + + static void EmitRegistryEntry(StringBuilder sb, ProjectableRegistryEntry entry) + { + if (entry.IsGenericClass) + { + sb.AppendLine($" // TODO: generic class — {entry.GeneratedClassFullName} (falls back to reflection)"); + return; + } + + if (entry.IsGenericMethod) + { + sb.AppendLine($" // TODO: generic method — {entry.GeneratedClassFullName} (falls back to reflection)"); + return; + } + + const string flags = "global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Static"; + + sb.AppendLine(" {"); + sb.AppendLine($" var t = typeof({entry.DeclaringTypeFullName});"); + + switch (entry.MemberKind) + { + case "Property": + sb.AppendLine($" var m = t.GetProperty(\"{entry.MemberLookupName}\", {flags})?.GetMethod;"); + break; + + case "Method": + { + var typeArray = BuildTypeArray(entry.ParameterTypeNames); + sb.AppendLine($" var m = t.GetMethod(\"{entry.MemberLookupName}\", {flags}, null, {typeArray}, null);"); + break; + } + + case "Constructor": + { + var typeArray = BuildTypeArray(entry.ParameterTypeNames); + sb.AppendLine($" var m = t.GetConstructor({flags}, null, {typeArray}, null);"); + break; + } + + default: + sb.AppendLine(" }"); + return; + } + + sb.AppendLine(" if (m is not null)"); + sb.AppendLine(" {"); + sb.AppendLine($" var exprType = t.Assembly.GetType(\"{entry.GeneratedClassFullName}\");"); + sb.AppendLine(" var exprMethod = exprType?.GetMethod(\"Expression\", global::System.Reflection.BindingFlags.Static | global::System.Reflection.BindingFlags.NonPublic);"); + sb.AppendLine(" if (exprMethod is not null)"); + sb.AppendLine(" map[m.MethodHandle.Value] = (global::System.Linq.Expressions.LambdaExpression)exprMethod.Invoke(null, null)!;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + } + + static string BuildTypeArray(ImmutableArray parameterTypeNames) + { + if (parameterTypeNames.IsEmpty) + return "global::System.Type.EmptyTypes"; + + var sb = new StringBuilder("new global::System.Type[] { "); + for (int i = 0; i < parameterTypeNames.Length; i++) + { + if (i > 0) sb.Append(", "); + sb.Append($"typeof({parameterTypeNames[i]})"); + } + sb.Append(" }"); + return sb.ToString(); + } + } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 16e547a..e342a6e 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -21,27 +21,19 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor private readonly bool _trackingByDefault; private IEntityType? _entityType; - private readonly MethodInfo _select; - private readonly MethodInfo _where; + // Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain) + private static readonly MethodInfo _select = + ((MethodCallExpression)((Expression, IQueryable>>) + (q => q.Select(x => x))).Body).Method.GetGenericMethodDefinition(); + + private static readonly MethodInfo _where = + ((MethodCallExpression)((Expression, IQueryable>>) + (q => q.Where(x => true))).Body).Method.GetGenericMethodDefinition(); public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false) { _trackingByDefault = trackByDefault; _resolver = projectionExpressionResolver; - _select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Select)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ); - _where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Where)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ); } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 7e26d69..c0186f9 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -8,6 +9,26 @@ namespace EntityFrameworkCore.Projectables.Services { public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver { + // Cache per-assembly registry delegate: Assembly → Func + // After first lookup, subsequent calls do a fast lock-free read. + private static readonly ConcurrentDictionary?> _assemblyRegistries = new(); + + /// + /// Looks up the generated ProjectionRegistry class in an assembly (once, then caches it). + /// Returns a delegate that calls TryGet(MemberInfo) on the registry, or null if the registry + /// is not present in that assembly (e.g. if the source generator was not run against it). + /// + private static Func? GetAssemblyRegistry(Assembly assembly) => + _assemblyRegistries.GetOrAdd(assembly, static asm => + { + var registryType = asm.GetType("EntityFrameworkCore.Projectables.Generated.ProjectionRegistry"); + if (registryType is null) return null; + var tryGetMethod = registryType.GetMethod("TryGet", BindingFlags.Static | BindingFlags.Public); + if (tryGetMethod is null) return null; + return (Func)Delegate.CreateDelegate( + typeof(Func), tryGetMethod); + }); + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) { var projectableAttribute = projectableMemberInfo.GetCustomAttribute() @@ -62,7 +83,21 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); - + + // Fast path: check the per-assembly static registry (generated by source generator). + // The first call per assembly does a reflection lookup to find the registry class and + // caches it as a delegate; subsequent calls use the cached delegate for an O(1) dictionary lookup. + var registry = GetAssemblyRegistry(declaringType.Assembly); + if (registry is not null) + { + var registeredExpr = registry(projectableMemberInfo); + if (registeredExpr is not null) + return registeredExpr; + } + + // Slow path: reflection fallback for open-generic class members and generic methods + // that are not yet in the registry. + // Keep track of the original declaring type's generic arguments for later use var originalDeclaringType = declaringType; diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs index 9a791e2..c442005 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs @@ -1,5 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -17,6 +18,47 @@ protected ProjectionExpressionGeneratorTestsBase(ITestOutputHelper testOutputHel _testOutputHelper = testOutputHelper; } + /// + /// Wraps and exposes + /// as a filtered view that excludes the generated ProjectionRegistry.g.cs file. + /// This keeps all existing tests working without modification after the registry was added. + /// + protected sealed class TestGeneratorRunResult + { + private readonly GeneratorDriverRunResult _inner; + + public TestGeneratorRunResult(GeneratorDriverRunResult inner) + { + _inner = inner; + } + + /// + /// Diagnostics from the generator run. + /// + public ImmutableArray Diagnostics => _inner.Diagnostics; + + /// + /// Generated trees excluding ProjectionRegistry.g.cs. + /// Existing tests use this and should continue to work without modification. + /// + public ImmutableArray GeneratedTrees => + _inner.GeneratedTrees + .Where(t => !t.FilePath.EndsWith("ProjectionRegistry.g.cs")) + .ToImmutableArray(); + + /// + /// All generated trees including ProjectionRegistry.g.cs. + /// Use this in new tests that need to verify the registry. + /// + public ImmutableArray AllGeneratedTrees => _inner.GeneratedTrees; + + /// + /// The generated ProjectionRegistry.g.cs tree, or null if it was not generated. + /// + public SyntaxTree? RegistryTree => + _inner.GeneratedTrees.FirstOrDefault(t => t.FilePath.EndsWith("ProjectionRegistry.g.cs")); + } + protected Compilation CreateCompilation([StringSyntax("csharp")] string source) { var references = Basic.Reference.Assemblies. @@ -58,7 +100,7 @@ protected Compilation CreateCompilation([StringSyntax("csharp")] string source) return compilation; } - protected GeneratorDriverRunResult RunGenerator(Compilation compilation) + protected TestGeneratorRunResult RunGenerator(Compilation compilation) { _testOutputHelper.WriteLine("Running generator and updating compilation..."); @@ -67,7 +109,8 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) .Create(subject) .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); if (result.Diagnostics.IsEmpty) { @@ -83,7 +126,7 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) } } - foreach (var newSyntaxTree in result.GeneratedTrees) + foreach (var newSyntaxTree in result.AllGeneratedTrees) { _testOutputHelper.WriteLine($"Produced syntax tree with path produced: {newSyntaxTree.FilePath}"); _testOutputHelper.WriteLine(newSyntaxTree.GetText().ToString()); @@ -91,7 +134,7 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) // Verify that the generated code compiles without errors var hasGeneratorErrors = result.Diagnostics.Any(d => d.Severity == DiagnosticSeverity.Error); - if (!hasGeneratorErrors && result.GeneratedTrees.Length > 0) + if (!hasGeneratorErrors && result.AllGeneratedTrees.Length > 0) { _testOutputHelper.WriteLine("Checking that generated code compiles..."); From c00bee603ea100bf55975b8338b3898a2170d976 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 21:33:55 +0000 Subject: [PATCH 03/12] Use SyntaxFactory for registry emission; expose reflection path; add benchmark comparison Co-authored-by: PhenX <42170+PhenX@users.noreply.github.com> --- .../ExpressionResolverBenchmark.cs | 26 +- .../ProjectionExpressionGenerator.cs | 380 +++++++++++++----- .../Services/ProjectionExpressionResolver.cs | 277 +++++++------ 3 files changed, 442 insertions(+), 241 deletions(-) diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs index e2803de..3ab54d7 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs @@ -8,8 +8,8 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { /// /// Micro-benchmarks in - /// isolation (no EF Core overhead) to directly compare the registry lookup path against - /// the previous per-call reflection chain. + /// isolation (no EF Core overhead) to directly compare the static registry path against + /// the reflection-based path (). /// [MemoryDiagnoser] public class ExpressionResolverBenchmark @@ -25,16 +25,32 @@ public class ExpressionResolverBenchmark private readonly ProjectionExpressionResolver _resolver = new(); + // ── Registry (source-generated) path ───────────────────────────────── + [Benchmark(Baseline = true)] - public LambdaExpression? ResolveProperty() + public LambdaExpression? ResolveProperty_Registry() => _resolver.FindGeneratedExpression(_propertyMember); [Benchmark] - public LambdaExpression? ResolveMethod() + public LambdaExpression? ResolveMethod_Registry() => _resolver.FindGeneratedExpression(_methodMember); [Benchmark] - public LambdaExpression? ResolveMethodWithParam() + public LambdaExpression? ResolveMethodWithParam_Registry() => _resolver.FindGeneratedExpression(_methodWithParamMember); + + // ── Reflection path ─────────────────────────────────────────────────── + + [Benchmark] + public LambdaExpression? ResolveProperty_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_propertyMember); + + [Benchmark] + public LambdaExpression? ResolveMethod_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodMember); + + [Benchmark] + public LambdaExpression? ResolveMethodWithParam_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodWithParamMember); } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 8f0c9bd..e4580af 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -348,6 +348,7 @@ static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol /// /// Emits the ProjectionRegistry.g.cs file that aggregates all projectable members /// into a single static dictionary keyed by . + /// Uses SyntaxFactory for the class/method/field structure, consistent with . /// static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) { @@ -359,125 +360,298 @@ static void EmitRegistry(ImmutableArray entries, Sour if (validEntries.Count == 0) return; - var sb = new StringBuilder(); - sb.AppendLine("// "); - sb.AppendLine("#nullable disable"); - sb.AppendLine("using System;"); - sb.AppendLine("using System.Collections.Generic;"); - sb.AppendLine("using System.Linq.Expressions;"); - sb.AppendLine("using System.Reflection;"); - sb.AppendLine(); - sb.AppendLine("namespace EntityFrameworkCore.Projectables.Generated"); - sb.AppendLine("{"); - sb.AppendLine(" [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]"); - sb.AppendLine(" internal static class ProjectionRegistry"); - sb.AppendLine(" {"); - sb.AppendLine(" // Keyed by RuntimeMethodHandle.Value (a stable nint pointer for the method/getter/ctor)."); - sb.AppendLine(" // Populated once at type initialization; shared across the entire AppDomain lifetime."); - sb.AppendLine(" private static readonly Dictionary _map = Build();"); - sb.AppendLine(); - sb.AppendLine(" /// "); - sb.AppendLine(" /// Returns the pre-built LambdaExpression for the given [Projectable] member,"); - sb.AppendLine(" /// or null if the member is not registered (e.g. open-generic members)."); - sb.AppendLine(" /// "); - sb.AppendLine(" public static LambdaExpression TryGet(MemberInfo member)"); - sb.AppendLine(" {"); - sb.AppendLine(" var handle = GetHandle(member);"); - sb.AppendLine(" return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null;"); - sb.AppendLine(" }"); - sb.AppendLine(); - sb.AppendLine(" private static nint? GetHandle(MemberInfo member) => member switch"); - sb.AppendLine(" {"); - sb.AppendLine(" MethodInfo m => m.MethodHandle.Value,"); - sb.AppendLine(" PropertyInfo p => p.GetMethod?.MethodHandle.Value,"); - sb.AppendLine(" ConstructorInfo c => c.MethodHandle.Value,"); - sb.AppendLine(" _ => null"); - sb.AppendLine(" };"); - sb.AppendLine(); - sb.AppendLine(" private static Dictionary Build()"); - sb.AppendLine(" {"); - sb.AppendLine(" var map = new Dictionary();"); + // Build the Build() method body: one block per valid (non-generic) entry + var buildStatements = new List + { + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("map") + .WithInitializer(EqualsValueClause( + ObjectCreationExpression( + ParseTypeName("Dictionary")) + .WithArgumentList(ArgumentList()))))), + }; foreach (var entry in validEntries) { - EmitRegistryEntry(sb, entry); + var block = BuildRegistryEntryBlock(entry); + if (block is not null) + buildStatements.Add(block); } - sb.AppendLine(" return map;"); - sb.AppendLine(" }"); - sb.AppendLine(" }"); - sb.AppendLine("}"); + buildStatements.Add(ReturnStatement(IdentifierName("map"))); - context.AddSource("ProjectionRegistry.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + var classSyntax = ClassDeclaration("ProjectionRegistry") + .WithModifiers(TokenList( + Token(SyntaxKind.InternalKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddAttributeLists(AttributeList().AddAttributes(_editorBrowsableAttribute)) + .AddMembers( + // private static readonly Dictionary _map = Build(); + FieldDeclaration( + VariableDeclaration(ParseTypeName("Dictionary")) + .AddVariables( + VariableDeclarator("_map") + .WithInitializer(EqualsValueClause( + InvocationExpression(IdentifierName("Build")) + .WithArgumentList(ArgumentList()))))) + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.ReadOnlyKeyword))), + + // public static LambdaExpression TryGet(MemberInfo member) + MethodDeclaration(ParseTypeName("LambdaExpression"), "TryGet") + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("member")) + .WithType(ParseTypeName("MemberInfo"))) + .WithBody(Block( + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("handle") + .WithInitializer(EqualsValueClause( + InvocationExpression(IdentifierName("GetHandle")) + .AddArgumentListArguments( + Argument(IdentifierName("member"))))))), + ReturnStatement( + ParseExpression( + "handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null")))), + + // private static nint? GetHandle(MemberInfo member) => member switch { ... }; + MethodDeclaration(ParseTypeName("nint?"), "GetHandle") + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("member")) + .WithType(ParseTypeName("MemberInfo"))) + .WithExpressionBody(ArrowExpressionClause( + SwitchExpression(IdentifierName("member")) + .WithArms(SeparatedList( + new SyntaxNodeOrToken[] + { + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("MethodInfo"), + SingleVariableDesignation(Identifier("m"))), + ParseExpression("m.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("PropertyInfo"), + SingleVariableDesignation(Identifier("p"))), + ParseExpression("p.GetMethod?.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("ConstructorInfo"), + SingleVariableDesignation(Identifier("c"))), + ParseExpression("c.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DiscardPattern(), + LiteralExpression(SyntaxKind.NullLiteralExpression)) + })))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)), + + // private static Dictionary Build() + MethodDeclaration(ParseTypeName("Dictionary"), "Build") + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword))) + .WithBody(Block(buildStatements))); + + var compilationUnit = CompilationUnit() + .AddUsings( + UsingDirective(ParseName("System")), + UsingDirective(ParseName("System.Collections.Generic")), + UsingDirective(ParseName("System.Linq.Expressions")), + UsingDirective(ParseName("System.Reflection"))) + .AddMembers( + NamespaceDeclaration(ParseName("EntityFrameworkCore.Projectables.Generated")) + .AddMembers(classSyntax)) + .WithLeadingTrivia(TriviaList( + Comment("// "), + Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))); + + context.AddSource("ProjectionRegistry.g.cs", + SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); } - static void EmitRegistryEntry(StringBuilder sb, ProjectableRegistryEntry entry) + /// + /// Builds the registration block for a single projectable entry inside Build(). + /// Returns for generic class/method entries (they fall back to reflection). + /// + static BlockSyntax? BuildRegistryEntryBlock(ProjectableRegistryEntry entry) { - if (entry.IsGenericClass) - { - sb.AppendLine($" // TODO: generic class — {entry.GeneratedClassFullName} (falls back to reflection)"); - return; - } - - if (entry.IsGenericMethod) - { - sb.AppendLine($" // TODO: generic method — {entry.GeneratedClassFullName} (falls back to reflection)"); - return; - } - - const string flags = "global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Static"; - - sb.AppendLine(" {"); - sb.AppendLine($" var t = typeof({entry.DeclaringTypeFullName});"); + if (entry.IsGenericClass || entry.IsGenericMethod) + return null; - switch (entry.MemberKind) + var bindingFlagsExpr = ParseExpression( + "global::System.Reflection.BindingFlags.Public | " + + "global::System.Reflection.BindingFlags.NonPublic | " + + "global::System.Reflection.BindingFlags.Instance | " + + "global::System.Reflection.BindingFlags.Static"); + + // var t = typeof(DeclaringType); + var tDecl = LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("t") + .WithInitializer(EqualsValueClause( + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)))))); + + // var m = t.GetProperty(...) / t.GetMethod(...) / t.GetConstructor(...); + StatementSyntax? mDecl = entry.MemberKind switch { - case "Property": - sb.AppendLine($" var m = t.GetProperty(\"{entry.MemberLookupName}\", {flags})?.GetMethod;"); - break; - - case "Method": - { - var typeArray = BuildTypeArray(entry.ParameterTypeNames); - sb.AppendLine($" var m = t.GetMethod(\"{entry.MemberLookupName}\", {flags}, null, {typeArray}, null);"); - break; - } - - case "Constructor": - { - var typeArray = BuildTypeArray(entry.ParameterTypeNames); - sb.AppendLine($" var m = t.GetConstructor({flags}, null, {typeArray}, null);"); - break; - } - - default: - sb.AppendLine(" }"); - return; - } + "Property" => LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("m") + .WithInitializer(EqualsValueClause( + ConditionalAccessExpression( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("t"), + IdentifierName("GetProperty"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.MemberLookupName))), + Argument(bindingFlagsExpr)), + MemberBindingExpression(IdentifierName("GetMethod"))))))), + + "Method" => LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("m") + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("t"), + IdentifierName("GetMethod"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.MemberLookupName))), + Argument(bindingFlagsExpr), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))))))), + + "Constructor" => LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("m") + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("t"), + IdentifierName("GetConstructor"))) + .AddArgumentListArguments( + Argument(bindingFlagsExpr), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))))))), + + _ => null + }; + + if (mDecl is null) + return null; - sb.AppendLine(" if (m is not null)"); - sb.AppendLine(" {"); - sb.AppendLine($" var exprType = t.Assembly.GetType(\"{entry.GeneratedClassFullName}\");"); - sb.AppendLine(" var exprMethod = exprType?.GetMethod(\"Expression\", global::System.Reflection.BindingFlags.Static | global::System.Reflection.BindingFlags.NonPublic);"); - sb.AppendLine(" if (exprMethod is not null)"); - sb.AppendLine(" map[m.MethodHandle.Value] = (global::System.Linq.Expressions.LambdaExpression)exprMethod.Invoke(null, null)!;"); - sb.AppendLine(" }"); - sb.AppendLine(" }"); + // var exprType = t.Assembly.GetType("GeneratedClassFullName"); + var exprTypeDecl = LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("exprType") + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("t"), + IdentifierName("Assembly")), + IdentifierName("GetType"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.GeneratedClassFullName)))))))); + + // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + var exprMethodDecl = LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("exprMethod") + .WithInitializer(EqualsValueClause( + ConditionalAccessExpression( + IdentifierName("exprType"), + InvocationExpression( + MemberBindingExpression(IdentifierName("GetMethod"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal("Expression"))), + Argument(ParseExpression( + "global::System.Reflection.BindingFlags.Static | " + + "global::System.Reflection.BindingFlags.NonPublic")))))))); + + // if (exprMethod != null) + // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + var ifExprMethod = IfStatement( + BinaryExpression(SyntaxKind.NotEqualsExpression, + IdentifierName("exprMethod"), + LiteralExpression(SyntaxKind.NullLiteralExpression)), + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + ElementAccessExpression(IdentifierName("map")) + .AddArgumentListArguments( + Argument( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("m"), + IdentifierName("MethodHandle")), + IdentifierName("Value")))), + CastExpression( + ParseTypeName("global::System.Linq.Expressions.LambdaExpression"), + PostfixUnaryExpression(SyntaxKind.SuppressNullableWarningExpression, + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("exprMethod"), + IdentifierName("Invoke"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)))))))); + + // if (m != null) { exprType decl; exprMethod decl; if (exprMethod != null) ... } + var ifM = IfStatement( + BinaryExpression(SyntaxKind.NotEqualsExpression, + IdentifierName("m"), + LiteralExpression(SyntaxKind.NullLiteralExpression)), + Block(exprTypeDecl, exprMethodDecl, ifExprMethod)); + + return Block(tDecl, mDecl, ifM); } - static string BuildTypeArray(ImmutableArray parameterTypeNames) + /// + /// Builds the typeof(...)-array expression used for reflection method/constructor lookup. + /// Returns global::System.Type.EmptyTypes when there are no parameters. + /// + static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parameterTypeNames) { if (parameterTypeNames.IsEmpty) - return "global::System.Type.EmptyTypes"; - - var sb = new StringBuilder("new global::System.Type[] { "); - for (int i = 0; i < parameterTypeNames.Length; i++) - { - if (i > 0) sb.Append(", "); - sb.Append($"typeof({parameterTypeNames[i]})"); - } - sb.Append(" }"); - return sb.ToString(); + return ParseExpression("global::System.Type.EmptyTypes"); + + var typeofExprs = parameterTypeNames + .Select(name => (ExpressionSyntax)TypeOfExpression(ParseTypeName(name))) + .ToArray(); + + return ArrayCreationExpression( + ArrayType(ParseTypeName("global::System.Type")) + .AddRankSpecifiers(ArrayRankSpecifier())) + .WithInitializer( + InitializerExpression(SyntaxKind.ArrayInitializerExpression, + SeparatedList(typeofExprs))); } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index c0186f9..89e0ac9 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -97,165 +97,176 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo // Slow path: reflection fallback for open-generic class members and generic methods // that are not yet in the registry. + return FindGeneratedExpressionViaReflection(projectableMemberInfo); + } + } - // Keep track of the original declaring type's generic arguments for later use - var originalDeclaringType = declaringType; - - // For generic types, use the generic type definition to match the generated name - // which is based on the open generic type - if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) - { - declaringType = declaringType.GetGenericTypeDefinition(); - } - - // Get parameter types for method overload disambiguation - // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat - // which uses C# keywords for primitive types (int, string, etc.) - string[]? parameterTypeNames = null; - string memberLookupName = projectableMemberInfo.Name; - if (projectableMemberInfo is MethodInfo method) - { - // For generic methods, use the generic definition to get parameter types - // This ensures type parameters like TEntity are used instead of concrete types - var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; - - parameterTypeNames = methodToInspect.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); - } - else if (projectableMemberInfo is ConstructorInfo ctor) + /// + /// Resolves the for a [Projectable] member using the + /// reflection-based slow path only, bypassing the static registry. + /// Useful for benchmarking and for members not yet in the registry (e.g. open-generic types). + /// + public static LambdaExpression? FindGeneratedExpressionViaReflection(MemberInfo projectableMemberInfo) + { + var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + + // Keep track of the original declaring type's generic arguments for later use + var originalDeclaringType = declaringType; + + // For generic types, use the generic type definition to match the generated name + // which is based on the open generic type + if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) + { + declaringType = declaringType.GetGenericTypeDefinition(); + } + + // Get parameter types for method overload disambiguation + // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat + // which uses C# keywords for primitive types (int, string, etc.) + string[]? parameterTypeNames = null; + string memberLookupName = projectableMemberInfo.Name; + if (projectableMemberInfo is MethodInfo method) + { + // For generic methods, use the generic definition to get parameter types + // This ensures type parameters like TEntity are used instead of concrete types + var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; + + parameterTypeNames = methodToInspect.GetParameters() + .Select(p => GetFullTypeName(p.ParameterType)) + .ToArray(); + } + else if (projectableMemberInfo is ConstructorInfo ctor) + { + // Constructors are stored under the synthetic name "_ctor" + memberLookupName = "_ctor"; + parameterTypeNames = ctor.GetParameters() + .Select(p => GetFullTypeName(p.ParameterType)) + .ToArray(); + } + + var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); + + var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + + if (expressionFactoryType is not null) + { + if (expressionFactoryType.IsGenericTypeDefinition) { - // Constructors are stored under the synthetic name "_ctor" - memberLookupName = "_ctor"; - parameterTypeNames = ctor.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); + expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); } - - var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); - var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + + var methodGenericArguments = projectableMemberInfo switch { + MethodInfo methodInfo => methodInfo.GetGenericArguments(), + _ => null + }; - if (expressionFactoryType is not null) + if (expressionFactoryMethod is not null) { - if (expressionFactoryType.IsGenericTypeDefinition) + if (methodGenericArguments is { Length: > 0 }) { - expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); + expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); } - var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - - var methodGenericArguments = projectableMemberInfo switch { - MethodInfo methodInfo => methodInfo.GetGenericArguments(), - _ => null - }; + return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); + } + } - if (expressionFactoryMethod is not null) - { - if (methodGenericArguments is { Length: > 0 }) - { - expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); - } + return null; + } - return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); - } - } + private static string GetFullTypeName(Type type) + { + // Handle generic type parameters (e.g., T, TEntity) + if (type.IsGenericParameter) + { + return type.Name; + } - return null; + // Handle nullable value types (e.g., int? -> int?) + var underlyingType = Nullable.GetUnderlyingType(type); + if (underlyingType != null) + { + return $"{GetFullTypeName(underlyingType)}?"; } - - static string GetFullTypeName(Type type) + + // Handle array types + if (type.IsArray) { - // Handle generic type parameters (e.g., T, TEntity) - if (type.IsGenericParameter) + var elementType = type.GetElementType(); + if (elementType == null) { + // Fallback for edge cases where GetElementType() might return null return type.Name; } - - // Handle nullable value types (e.g., int? -> int?) - var underlyingType = Nullable.GetUnderlyingType(type); - if (underlyingType != null) + + var rank = type.GetArrayRank(); + var elementTypeName = GetFullTypeName(elementType); + + if (rank == 1) { - return $"{GetFullTypeName(underlyingType)}?"; + return $"{elementTypeName}[]"; } - - // Handle array types - if (type.IsArray) + else { - var elementType = type.GetElementType(); - if (elementType == null) - { - // Fallback for edge cases where GetElementType() might return null - return type.Name; - } - - var rank = type.GetArrayRank(); - var elementTypeName = GetFullTypeName(elementType); - - if (rank == 1) - { - return $"{elementTypeName}[]"; - } - else - { - var commas = new string(',', rank - 1); - return $"{elementTypeName}[{commas}]"; - } - } - - // Map primitive types to their C# keyword equivalents to match Roslyn's output - var typeKeyword = GetCSharpKeyword(type); - if (typeKeyword != null) - { - return typeKeyword; + var commas = new string(',', rank - 1); + return $"{elementTypeName}[{commas}]"; } - - // For generic types, construct the full name matching Roslyn's format - if (type.IsGenericType) - { - var genericTypeDef = type.GetGenericTypeDefinition(); - var genericArgs = type.GetGenericArguments(); - var baseName = genericTypeDef.FullName ?? genericTypeDef.Name; - - // Remove the `n suffix (e.g., `1, `2) - var backtickIndex = baseName.IndexOf('`'); - if (backtickIndex > 0) - { - baseName = baseName.Substring(0, backtickIndex); - } - - var args = string.Join(", ", genericArgs.Select(GetFullTypeName)); - return $"{baseName}<{args}>"; - } - - if (type.FullName != null) + } + + // Map primitive types to their C# keyword equivalents to match Roslyn's output + var typeKeyword = GetCSharpKeyword(type); + if (typeKeyword != null) + { + return typeKeyword; + } + + // For generic types, construct the full name matching Roslyn's format + if (type.IsGenericType) + { + var genericTypeDef = type.GetGenericTypeDefinition(); + var genericArgs = type.GetGenericArguments(); + var baseName = genericTypeDef.FullName ?? genericTypeDef.Name; + + // Remove the `n suffix (e.g., `1, `2) + var backtickIndex = baseName.IndexOf('`'); + if (backtickIndex > 0) { - // Replace + with . for nested types to match Roslyn's format - return type.FullName.Replace('+', '.'); + baseName = baseName.Substring(0, backtickIndex); } - - return type.Name; + + var args = string.Join(", ", genericArgs.Select(GetFullTypeName)); + return $"{baseName}<{args}>"; } - - static string? GetCSharpKeyword(Type type) + + if (type.FullName != null) { - if (type == typeof(bool)) return "bool"; - if (type == typeof(byte)) return "byte"; - if (type == typeof(sbyte)) return "sbyte"; - if (type == typeof(char)) return "char"; - if (type == typeof(decimal)) return "decimal"; - if (type == typeof(double)) return "double"; - if (type == typeof(float)) return "float"; - if (type == typeof(int)) return "int"; - if (type == typeof(uint)) return "uint"; - if (type == typeof(long)) return "long"; - if (type == typeof(ulong)) return "ulong"; - if (type == typeof(short)) return "short"; - if (type == typeof(ushort)) return "ushort"; - if (type == typeof(object)) return "object"; - if (type == typeof(string)) return "string"; - return null; + // Replace + with . for nested types to match Roslyn's format + return type.FullName.Replace('+', '.'); } + + return type.Name; + } + + private static string? GetCSharpKeyword(Type type) + { + if (type == typeof(bool)) return "bool"; + if (type == typeof(byte)) return "byte"; + if (type == typeof(sbyte)) return "sbyte"; + if (type == typeof(char)) return "char"; + if (type == typeof(decimal)) return "decimal"; + if (type == typeof(double)) return "double"; + if (type == typeof(float)) return "float"; + if (type == typeof(int)) return "int"; + if (type == typeof(uint)) return "uint"; + if (type == typeof(long)) return "long"; + if (type == typeof(ulong)) return "ulong"; + if (type == typeof(short)) return "short"; + if (type == typeof(ushort)) return "ushort"; + if (type == typeof(object)) return "object"; + if (type == typeof(string)) return "string"; + return null; } } } From 7e49dc7543a163863036a8a4642604182d6e25a1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 08:43:44 +0000 Subject: [PATCH 04/12] Compact registry generation: Register helper + RegistryTests Co-authored-by: PhenX <42170+PhenX@users.noreply.github.com> --- .../ProjectionExpressionGenerator.cs | 282 +++++++++--------- .../RegistryTests.cs | 191 ++++++++++++ 2 files changed, 326 insertions(+), 147 deletions(-) create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index e4580af..4df1521 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -349,20 +349,42 @@ static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol /// Emits the ProjectionRegistry.g.cs file that aggregates all projectable members /// into a single static dictionary keyed by . /// Uses SyntaxFactory for the class/method/field structure, consistent with . + /// The generated Build() method uses a shared Register helper to avoid repeating + /// the lookup boilerplate for every entry. /// static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) { - var validEntries = entries + // Build the per-entry Register(...) statements first so we can bail out early + // if every entry is generic (they all fall back to reflection, no registry needed). + var entryStatements = entries .Where(e => e is not null) - .Select(e => e!) + .Select(e => BuildRegistryEntryStatement(e!)) + .Where(s => s is not null) + .Select(s => s!) .ToList(); - if (validEntries.Count == 0) + if (entryStatements.Count == 0) return; - // Build the Build() method body: one block per valid (non-generic) entry + // Build() body: + // const BindingFlags allFlags = ...; + // var map = new Dictionary(); + // Register(map, typeof(T).GetXxx(...), "ClassName"); ← one line per entry + // return map; var buildStatements = new List { + // const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | ...; + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("BindingFlags")) + .AddVariables( + VariableDeclarator("allFlags") + .WithInitializer(EqualsValueClause( + ParseExpression( + "BindingFlags.Public | BindingFlags.NonPublic | " + + "BindingFlags.Instance | BindingFlags.Static"))))) + .WithModifiers(TokenList(Token(SyntaxKind.ConstKeyword))), + + // var map = new Dictionary(); LocalDeclarationStatement( VariableDeclaration(ParseTypeName("var")) .AddVariables( @@ -373,13 +395,7 @@ static void EmitRegistry(ImmutableArray entries, Sour .WithArgumentList(ArgumentList()))))), }; - foreach (var entry in validEntries) - { - var block = BuildRegistryEntryBlock(entry); - if (block is not null) - buildStatements.Add(block); - } - + buildStatements.AddRange(entryStatements); buildStatements.Add(ReturnStatement(IdentifierName("map"))); var classSyntax = ClassDeclaration("ProjectionRegistry") @@ -459,12 +475,15 @@ static void EmitRegistry(ImmutableArray entries, Sour })))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)), - // private static Dictionary Build() + // private static Dictionary Build() { ... } MethodDeclaration(ParseTypeName("Dictionary"), "Build") .WithModifiers(TokenList( Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword))) - .WithBody(Block(buildStatements))); + .WithBody(Block(buildStatements)), + + // private static void Register(Dictionary map, MethodBase? m, string exprClass) + BuildRegisterHelperMethod()); var compilationUnit = CompilationUnit() .AddUsings( @@ -484,155 +503,124 @@ static void EmitRegistry(ImmutableArray entries, Sour } /// - /// Builds the registration block for a single projectable entry inside Build(). + /// Builds a single compact Register(map, typeof(T).GetXxx(...), "ClassName") + /// statement for one projectable entry in Build(). /// Returns for generic class/method entries (they fall back to reflection). /// - static BlockSyntax? BuildRegistryEntryBlock(ProjectableRegistryEntry entry) + static StatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry) { if (entry.IsGenericClass || entry.IsGenericMethod) return null; - var bindingFlagsExpr = ParseExpression( - "global::System.Reflection.BindingFlags.Public | " + - "global::System.Reflection.BindingFlags.NonPublic | " + - "global::System.Reflection.BindingFlags.Instance | " + - "global::System.Reflection.BindingFlags.Static"); - - // var t = typeof(DeclaringType); - var tDecl = LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("t") - .WithInitializer(EqualsValueClause( - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)))))); - - // var m = t.GetProperty(...) / t.GetMethod(...) / t.GetConstructor(...); - StatementSyntax? mDecl = entry.MemberKind switch + // typeof(DeclaringType).GetProperty/Method/Constructor(name, allFlags, ...) + ExpressionSyntax? memberCallExpr = entry.MemberKind switch { - "Property" => LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("m") - .WithInitializer(EqualsValueClause( - ConditionalAccessExpression( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("t"), - IdentifierName("GetProperty"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal(entry.MemberLookupName))), - Argument(bindingFlagsExpr)), - MemberBindingExpression(IdentifierName("GetMethod"))))))), - - "Method" => LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("m") - .WithInitializer(EqualsValueClause( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("t"), - IdentifierName("GetMethod"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal(entry.MemberLookupName))), - Argument(bindingFlagsExpr), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), - Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))))))), - - "Constructor" => LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("m") - .WithInitializer(EqualsValueClause( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("t"), - IdentifierName("GetConstructor"))) - .AddArgumentListArguments( - Argument(bindingFlagsExpr), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), - Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))))))), + // typeof(T).GetProperty("Name", allFlags)?.GetMethod + "Property" => ConditionalAccessExpression( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetProperty"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.MemberLookupName))), + Argument(IdentifierName("allFlags"))), + MemberBindingExpression(IdentifierName("GetMethod"))), + + // typeof(T).GetMethod("Name", allFlags, null, new Type[] {...}, null) + "Method" => InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetMethod"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.MemberLookupName))), + Argument(IdentifierName("allFlags")), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), + + // typeof(T).GetConstructor(allFlags, null, new Type[] {...}, null) + "Constructor" => InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetConstructor"))) + .AddArgumentListArguments( + Argument(IdentifierName("allFlags")), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), _ => null }; - if (mDecl is null) + if (memberCallExpr is null) return null; - // var exprType = t.Assembly.GetType("GeneratedClassFullName"); - var exprTypeDecl = LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("exprType") - .WithInitializer(EqualsValueClause( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("t"), - IdentifierName("Assembly")), - IdentifierName("GetType"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal(entry.GeneratedClassFullName)))))))); - - // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - var exprMethodDecl = LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("exprMethod") - .WithInitializer(EqualsValueClause( - ConditionalAccessExpression( - IdentifierName("exprType"), - InvocationExpression( - MemberBindingExpression(IdentifierName("GetMethod"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal("Expression"))), - Argument(ParseExpression( - "global::System.Reflection.BindingFlags.Static | " + - "global::System.Reflection.BindingFlags.NonPublic")))))))); - - // if (exprMethod != null) - // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; - var ifExprMethod = IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName("exprMethod"), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - ElementAccessExpression(IdentifierName("map")) - .AddArgumentListArguments( - Argument( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("m"), - IdentifierName("MethodHandle")), - IdentifierName("Value")))), - CastExpression( - ParseTypeName("global::System.Linq.Expressions.LambdaExpression"), - PostfixUnaryExpression(SyntaxKind.SuppressNullableWarningExpression, - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("exprMethod"), - IdentifierName("Invoke"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)))))))); - - // if (m != null) { exprType decl; exprMethod decl; if (exprMethod != null) ... } - var ifM = IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName("m"), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - Block(exprTypeDecl, exprMethodDecl, ifExprMethod)); - - return Block(tDecl, mDecl, ifM); + // Register(map, , ""); + return ExpressionStatement( + InvocationExpression(IdentifierName("Register")) + .AddArgumentListArguments( + Argument(IdentifierName("map")), + Argument(memberCallExpr), + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.GeneratedClassFullName))))); } + /// + /// Builds the Register private static helper method that all per-entry calls delegate to. + /// It handles the null checks and the common reflection lookup pattern once, centrally. + /// + static MethodDeclarationSyntax BuildRegisterHelperMethod() => + // private static void Register(Dictionary map, MethodBase? m, string exprClass) + // { + // if (m is null) return; + // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + // if (exprMethod is not null) + // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + // } + MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Register") + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("map")) + .WithType(ParseTypeName("Dictionary")), + Parameter(Identifier("m")) + .WithType(ParseTypeName("MethodBase?")), + Parameter(Identifier("exprClass")) + .WithType(PredefinedType(Token(SyntaxKind.StringKeyword)))) + .WithBody(Block( + // if (m is null) return; + IfStatement( + IsPatternExpression( + IdentifierName("m"), + ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), + ReturnStatement()), + // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("exprType") + .WithInitializer(EqualsValueClause( + ParseExpression("m.DeclaringType?.Assembly.GetType(exprClass)"))))), + // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("exprMethod") + .WithInitializer(EqualsValueClause( + ParseExpression( + @"exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic)"))))), + // if (exprMethod is not null) + // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + IfStatement( + ParseExpression("exprMethod is not null"), + ExpressionStatement( + ParseExpression( + "map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!"))))); + /// /// Builds the typeof(...)-array expression used for reflection method/constructor lookup. /// Returns global::System.Type.EmptyTypes when there are no parameters. diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs new file mode 100644 index 0000000..27fe1f8 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs @@ -0,0 +1,191 @@ +using Xunit; +using Xunit.Abstractions; + +namespace EntityFrameworkCore.Projectables.Generator.Tests; + +public class RegistryTests : ProjectionExpressionGeneratorTestsBase +{ + public RegistryTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { } + + [Fact] + public void NoProjectables_NoRegistry() + { + var compilation = CreateCompilation(@"class C { }"); + var result = RunGenerator(compilation); + + Assert.Null(result.RegistryTree); + } + + [Fact] + public void SingleProperty_RegistryContainsEntry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + Assert.Contains("ProjectionRegistry", src); + // Uses the compact Register helper — not a repeated block + Assert.Contains("private static void Register(", src); + Assert.Contains("Register(map,", src); + Assert.Contains("GetProperty(\"IdPlus1\"", src); + Assert.Contains("Foo_C_IdPlus1", src); + } + + [Fact] + public void SingleMethod_RegistryContainsEntry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int AddDelta(int delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + Assert.Contains("GetMethod(\"AddDelta\"", src); + Assert.Contains("typeof(int)", src); + Assert.Contains("Foo_C_AddDelta_P0_int", src); + } + + [Fact] + public void MultipleProjectables_AllRegistered() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + [Projectable] + public int AddDelta(int delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Two separate Register(map, ...) calls — one per projectable + Assert.Contains("GetProperty(\"IdPlus1\"", src); + Assert.Contains("GetMethod(\"AddDelta\"", src); + // Each entry is a single line, not a repeated multi-line block + var registerCallCount = CountOccurrences(src, "Register(map,"); + Assert.Equal(2, registerCallCount); + } + + [Fact] + public void GenericClass_NotIncludedInRegistry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + // Generic class members fall back to reflection — no registry emitted + Assert.Null(result.RegistryTree); + } + + [Fact] + public void Registry_ConstBindingFlagsUsedInBuild() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Build() uses a single const BindingFlags instead of repeating the flags per entry + Assert.Contains("const BindingFlags allFlags", src); + Assert.Contains("allFlags", src); + } + + [Fact] + public void Registry_RegisterHelperUsesDeclaringTypeAssembly() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Register helper derives the assembly from m.DeclaringType (no typeof repeated per entry) + Assert.Contains("m.DeclaringType?.Assembly.GetType(exprClass)", src); + } + + [Fact] + public void MethodOverloads_BothRegistered() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int Add(int delta) => Id + delta; + [Projectable] + public long Add(long delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Both overloads registered by parameter-type disambiguation + Assert.Contains("typeof(int)", src); + Assert.Contains("typeof(long)", src); + var registerCallCount = CountOccurrences(src, "Register(map,"); + Assert.Equal(2, registerCallCount); + } + + private static int CountOccurrences(string text, string pattern) + { + int count = 0; + int index = 0; + while ((index = text.IndexOf(pattern, index, System.StringComparison.Ordinal)) >= 0) + { + count++; + index += pattern.Length; + } + return count; + } +} From 1ea15c4853cd8e21e7cb2713cda4dc0691005feb Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Wed, 4 Mar 2026 18:11:15 +0100 Subject: [PATCH 05/12] Code cleanup and fix nullable issues --- .../ProjectableRegistryEntry.cs | 4 +- .../ProjectionExpressionGenerator.cs | 177 ++++++++++-------- 2 files changed, 97 insertions(+), 84 deletions(-) diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs index 4e27f8a..463c886 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs @@ -7,15 +7,13 @@ namespace EntityFrameworkCore.Projectables.Generator /// Contains only primitive types and ImmutableArray<string> so that value equality /// works correctly across incremental generation steps. /// - internal sealed record ProjectableRegistryEntry( + sealed internal record ProjectableRegistryEntry( string DeclaringTypeFullName, string MemberKind, string MemberLookupName, string GeneratedClassFullName, bool IsGenericClass, - int ClassTypeParamCount, bool IsGenericMethod, - int MethodTypeParamCount, ImmutableArray ParameterTypeNames ); } diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 4df1521..e1de46d 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -3,9 +3,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; using System.Text; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -35,7 +33,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator public void Initialize(IncrementalGeneratorInitializationContext context) { // Do a simple filter for members - IncrementalValuesProvider memberDeclarations = context.SyntaxProvider + var memberDeclarations = context.SyntaxProvider .ForAttributeWithMetadataName( ProjectablesAttributeName, predicate: static (s, _) => s is MemberDeclarationSyntax, @@ -43,7 +41,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithComparer(new MemberDeclarationSyntaxEqualityComparer()); // Combine the selected enums with the `Compilation` - IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations + var compilationAndMemberPairs = memberDeclarations .Combine(context.CompilationProvider) .WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer()); @@ -52,11 +50,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) static (spc, source) => Execute(source.Item1, source.Item2, spc)); // Build the projection registry: collect all entries and emit a single registry file - IncrementalValuesProvider registryEntries = + var registryEntries = compilationAndMemberPairs.Select( static (pair, _) => ExtractRegistryEntry(pair.Item1, pair.Item2)); - IncrementalValueProvider> allEntries = + var allEntries = registryEntries.Collect(); context.RegisterImplementationSourceOutput( @@ -64,7 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) static (spc, entries) => EmitRegistry(entries, spc)); } - static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) + private static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) { var chain = CollectConstructorChain(ctor, compilation); @@ -105,7 +103,7 @@ void AddLine(string text) /// then its delegate's delegate, …). Stops when a delegated constructor has no source /// available in the compilation (e.g. a compiler-synthesised parameterless constructor). /// - static IReadOnlyList CollectConstructorChain( + private static List CollectConstructorChain( ConstructorDeclarationSyntax ctor, Compilation compilation) { var result = new List { ctor }; @@ -116,7 +114,9 @@ static IReadOnlyList CollectConstructorChain( { var semanticModel = compilation.GetSemanticModel(current.SyntaxTree); if (semanticModel.GetSymbolInfo(initializer).Symbol is not IMethodSymbol delegated) + { break; + } var delegatedSyntax = delegated.DeclaringSyntaxReferences .Select(r => r.GetSyntax()) @@ -124,7 +124,9 @@ static IReadOnlyList CollectConstructorChain( .FirstOrDefault(); if (delegatedSyntax is null || !visited.Add(delegatedSyntax)) + { break; + } result.Add(delegatedSyntax); current = delegatedSyntax; @@ -133,7 +135,7 @@ static IReadOnlyList CollectConstructorChain( return result; } - static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context) + private static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context) { var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context); @@ -161,40 +163,40 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou .WithLeadingTrivia(member is ConstructorDeclarationSyntax ctor ? BuildSourceDocComment(ctor, compilation) : TriviaList()) .AddMembers( MethodDeclaration( - GenericName( - Identifier("global::System.Linq.Expressions.Expression"), - TypeArgumentList( - SingletonSeparatedList( - (TypeSyntax)GenericName( - Identifier("global::System.Func"), - GetLambdaTypeArgumentListSyntax(projectable) + GenericName( + Identifier("global::System.Linq.Expressions.Expression"), + TypeArgumentList( + SingletonSeparatedList( + (TypeSyntax)GenericName( + Identifier("global::System.Func"), + GetLambdaTypeArgumentListSyntax(projectable) + ) ) ) - ) - ), - "Expression" - ) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .WithTypeParameterList(projectable.TypeParameterList) - .WithConstraintClauses(projectable.ConstraintClauses ?? List()) - .WithBody( - Block( - ReturnStatement( - ParenthesizedLambdaExpression( - projectable.ParametersList ?? ParameterList(), - null, - projectable.ExpressionBody + ), + "Expression" + ) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .WithTypeParameterList(projectable.TypeParameterList) + .WithConstraintClauses(projectable.ConstraintClauses ?? List()) + .WithBody( + Block( + ReturnStatement( + ParenthesizedLambdaExpression( + projectable.ParametersList ?? ParameterList(), + null, + projectable.ExpressionBody + ) ) ) ) - ) ); #nullable disable var compilationUnit = CompilationUnit(); - foreach (var usingDirective in projectable.UsingDirectives) + foreach (var usingDirective in projectable.UsingDirectives!) { compilationUnit = compilationUnit.AddUsings(usingDirective); } @@ -249,13 +251,15 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip /// Returns null when the member does not have [Projectable], is an extension member, /// or cannot be represented in the registry (e.g. a generic class member or generic method). /// - static ProjectableRegistryEntry? ExtractRegistryEntry(MemberDeclarationSyntax member, Compilation compilation) + private static ProjectableRegistryEntry? ExtractRegistryEntry(MemberDeclarationSyntax member, Compilation compilation) { var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); var memberSymbol = semanticModel.GetDeclaredSymbol(member); if (memberSymbol is null) + { return null; + } // Verify [Projectable] attribute var projectableAttributeTypeSymbol = compilation.GetTypeByMetadataName("EntityFrameworkCore.Projectables.ProjectableAttribute"); @@ -264,21 +268,25 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip if (projectableAttribute is null || !SymbolEqualityComparer.Default.Equals(projectableAttribute.AttributeClass, projectableAttributeTypeSymbol)) + { return null; + } // Skip C# 14 extension type members — they require special handling (fall back to reflection) if (memberSymbol.ContainingType is { IsExtension: true }) + { return null; + } var containingType = memberSymbol.ContainingType; - bool isGenericClass = containingType.TypeParameters.Length > 0; + var isGenericClass = containingType.TypeParameters.Length > 0; // Determine member kind and lookup name string memberKind; string memberLookupName; - ImmutableArray parameterTypeNames = ImmutableArray.Empty; - int methodTypeParamCount = 0; - bool isGenericMethod = false; + var parameterTypeNames = ImmutableArray.Empty; + var methodTypeParamCount = 0; + var isGenericMethod = false; if (memberSymbol is IMethodSymbol methodSymbol) { @@ -296,9 +304,9 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip memberLookupName = memberSymbol.Name; } - parameterTypeNames = methodSymbol.Parameters - .Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) - .ToImmutableArray(); + parameterTypeNames = [ + ..methodSymbol.Parameters.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + ]; } else { @@ -307,7 +315,7 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip } // Build the generated class name using the same logic as Execute - string? classNamespace = containingType.ContainingNamespace.IsGlobalNamespace + var classNamespace = containingType.ContainingNamespace.IsGlobalNamespace ? null : containingType.ContainingNamespace.ToDisplayString(); @@ -317,7 +325,7 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip classNamespace, nestedTypePath, memberLookupName, - parameterTypeNames.IsEmpty ? null : (IEnumerable)parameterTypeNames); + parameterTypeNames.IsEmpty ? null : parameterTypeNames); var generatedClassFullName = "EntityFrameworkCore.Projectables.Generated." + generatedClassName; @@ -329,18 +337,18 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip MemberLookupName: memberLookupName, GeneratedClassFullName: generatedClassFullName, IsGenericClass: isGenericClass, - ClassTypeParamCount: containingType.TypeParameters.Length, IsGenericMethod: isGenericMethod, - MethodTypeParamCount: methodTypeParamCount, ParameterTypeNames: parameterTypeNames); } - static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol) + private static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol) { if (typeSymbol.ContainingType is not null) { foreach (var name in GetRegistryNestedTypePath(typeSymbol.ContainingType)) + { yield return name; + } } yield return typeSymbol.Name; } @@ -352,7 +360,7 @@ static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol /// The generated Build() method uses a shared Register helper to avoid repeating /// the lookup boilerplate for every entry. /// - static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) + private static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) { // Build the per-entry Register(...) statements first so we can bail out early // if every entry is generic (they all fall back to reflection, no registry needed). @@ -364,7 +372,9 @@ static void EmitRegistry(ImmutableArray entries, Sour .ToList(); if (entryStatements.Count == 0) + { return; + } // Build() body: // const BindingFlags allFlags = ...; @@ -375,13 +385,13 @@ static void EmitRegistry(ImmutableArray entries, Sour { // const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | ...; LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("BindingFlags")) - .AddVariables( - VariableDeclarator("allFlags") - .WithInitializer(EqualsValueClause( - ParseExpression( - "BindingFlags.Public | BindingFlags.NonPublic | " + - "BindingFlags.Instance | BindingFlags.Static"))))) + VariableDeclaration(ParseTypeName("BindingFlags")) + .AddVariables( + VariableDeclarator("allFlags") + .WithInitializer(EqualsValueClause( + ParseExpression( + "BindingFlags.Public | BindingFlags.NonPublic | " + + "BindingFlags.Instance | BindingFlags.Static"))))) .WithModifiers(TokenList(Token(SyntaxKind.ConstKeyword))), // var map = new Dictionary(); @@ -391,7 +401,7 @@ static void EmitRegistry(ImmutableArray entries, Sour VariableDeclarator("map") .WithInitializer(EqualsValueClause( ObjectCreationExpression( - ParseTypeName("Dictionary")) + ParseTypeName("Dictionary")) .WithArgumentList(ArgumentList()))))), }; @@ -406,12 +416,12 @@ static void EmitRegistry(ImmutableArray entries, Sour .AddMembers( // private static readonly Dictionary _map = Build(); FieldDeclaration( - VariableDeclaration(ParseTypeName("Dictionary")) - .AddVariables( - VariableDeclarator("_map") - .WithInitializer(EqualsValueClause( - InvocationExpression(IdentifierName("Build")) - .WithArgumentList(ArgumentList()))))) + VariableDeclaration(ParseTypeName("Dictionary")) + .AddVariables( + VariableDeclarator("_map") + .WithInitializer(EqualsValueClause( + InvocationExpression(IdentifierName("Build")) + .WithArgumentList(ArgumentList()))))) .WithModifiers(TokenList( Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), @@ -482,7 +492,7 @@ static void EmitRegistry(ImmutableArray entries, Sour Token(SyntaxKind.StaticKeyword))) .WithBody(Block(buildStatements)), - // private static void Register(Dictionary map, MethodBase? m, string exprClass) + // private static void Register(Dictionary map, MethodBase m, string exprClass) BuildRegisterHelperMethod()); var compilationUnit = CompilationUnit() @@ -507,10 +517,12 @@ static void EmitRegistry(ImmutableArray entries, Sour /// statement for one projectable entry in Build(). /// Returns for generic class/method entries (they fall back to reflection). /// - static StatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry) + private static ExpressionStatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry) { if (entry.IsGenericClass || entry.IsGenericMethod) + { return null; + } // typeof(DeclaringType).GetProperty/Method/Constructor(name, allFlags, ...) ExpressionSyntax? memberCallExpr = entry.MemberKind switch @@ -518,9 +530,9 @@ static void EmitRegistry(ImmutableArray entries, Sour // typeof(T).GetProperty("Name", allFlags)?.GetMethod "Property" => ConditionalAccessExpression( InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), - IdentifierName("GetProperty"))) + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetProperty"))) .AddArgumentListArguments( Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(entry.MemberLookupName))), @@ -529,9 +541,9 @@ static void EmitRegistry(ImmutableArray entries, Sour // typeof(T).GetMethod("Name", allFlags, null, new Type[] {...}, null) "Method" => InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), - IdentifierName("GetMethod"))) + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetMethod"))) .AddArgumentListArguments( Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(entry.MemberLookupName))), @@ -542,9 +554,9 @@ static void EmitRegistry(ImmutableArray entries, Sour // typeof(T).GetConstructor(allFlags, null, new Type[] {...}, null) "Constructor" => InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), - IdentifierName("GetConstructor"))) + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetConstructor"))) .AddArgumentListArguments( Argument(IdentifierName("allFlags")), Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), @@ -555,7 +567,9 @@ static void EmitRegistry(ImmutableArray entries, Sour }; if (memberCallExpr is null) + { return null; + } // Register(map, , ""); return ExpressionStatement( @@ -571,8 +585,8 @@ static void EmitRegistry(ImmutableArray entries, Sour /// Builds the Register private static helper method that all per-entry calls delegate to. /// It handles the null checks and the common reflection lookup pattern once, centrally. /// - static MethodDeclarationSyntax BuildRegisterHelperMethod() => - // private static void Register(Dictionary map, MethodBase? m, string exprClass) + private static MethodDeclarationSyntax BuildRegisterHelperMethod() => + // private static void Register(Dictionary map, MethodBase m, string exprClass) // { // if (m is null) return; // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); @@ -588,7 +602,7 @@ static MethodDeclarationSyntax BuildRegisterHelperMethod() => Parameter(Identifier("map")) .WithType(ParseTypeName("Dictionary")), Parameter(Identifier("m")) - .WithType(ParseTypeName("MethodBase?")), + .WithType(ParseTypeName("MethodBase")), Parameter(Identifier("exprClass")) .WithType(PredefinedType(Token(SyntaxKind.StringKeyword)))) .WithBody(Block( @@ -625,22 +639,23 @@ static MethodDeclarationSyntax BuildRegisterHelperMethod() => /// Builds the typeof(...)-array expression used for reflection method/constructor lookup. /// Returns global::System.Type.EmptyTypes when there are no parameters. /// - static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parameterTypeNames) + private static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parameterTypeNames) { if (parameterTypeNames.IsEmpty) + { return ParseExpression("global::System.Type.EmptyTypes"); + } var typeofExprs = parameterTypeNames - .Select(name => (ExpressionSyntax)TypeOfExpression(ParseTypeName(name))) + .Select(ExpressionSyntax (name) => TypeOfExpression(ParseTypeName(name))) .ToArray(); return ArrayCreationExpression( - ArrayType(ParseTypeName("global::System.Type")) - .AddRankSpecifiers(ArrayRankSpecifier())) + ArrayType(ParseTypeName("global::System.Type")) + .AddRankSpecifiers(ArrayRankSpecifier())) .WithInitializer( InitializerExpression(SyntaxKind.ArrayInitializerExpression, - SeparatedList(typeofExprs))); + SeparatedList(typeofExprs))); } - } -} +} \ No newline at end of file From 0bcc11f17aad27b692f60835afea52c7a6a561d1 Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Wed, 4 Mar 2026 18:17:15 +0100 Subject: [PATCH 06/12] Code cleanup --- .../Services/ProjectionExpressionResolver.cs | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 89e0ac9..3e8ab67 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -22,11 +22,14 @@ public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver _assemblyRegistries.GetOrAdd(assembly, static asm => { var registryType = asm.GetType("EntityFrameworkCore.Projectables.Generated.ProjectionRegistry"); - if (registryType is null) return null; - var tryGetMethod = registryType.GetMethod("TryGet", BindingFlags.Static | BindingFlags.Public); - if (tryGetMethod is null) return null; - return (Func)Delegate.CreateDelegate( - typeof(Func), tryGetMethod); + var tryGetMethod = registryType?.GetMethod("TryGet", BindingFlags.Static | BindingFlags.Public); + + if (tryGetMethod is null) + { + return null; + } + + return (Func)Delegate.CreateDelegate(typeof(Func), tryGetMethod); }); public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) @@ -88,16 +91,12 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo // The first call per assembly does a reflection lookup to find the registry class and // caches it as a delegate; subsequent calls use the cached delegate for an O(1) dictionary lookup. var registry = GetAssemblyRegistry(declaringType.Assembly); - if (registry is not null) - { - var registeredExpr = registry(projectableMemberInfo); - if (registeredExpr is not null) - return registeredExpr; - } - - // Slow path: reflection fallback for open-generic class members and generic methods - // that are not yet in the registry. - return FindGeneratedExpressionViaReflection(projectableMemberInfo); + var registeredExpr = registry?.Invoke(projectableMemberInfo); + + return registeredExpr ?? + // Slow path: reflection fallback for open-generic class members and generic methods + // that are not yet in the registry. + FindGeneratedExpressionViaReflection(projectableMemberInfo); } } From e68b8d69e100cc536476b2499bb096e234a2ce3d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:05:32 +0000 Subject: [PATCH 07/12] Implement code review suggestions: sentinel delegate, EquatableImmutableArray, lambda fix, StringComparison.Ordinal Co-authored-by: PhenX <42170+PhenX@users.noreply.github.com> --- .../ProjectableRegistryEntry.cs | 48 +++++++++++++++++-- .../ProjectionExpressionGenerator.cs | 2 +- .../Services/ProjectionExpressionResolver.cs | 20 +++++--- .../ProjectionExpressionGeneratorTestsBase.cs | 4 +- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs index 463c886..37a45e0 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs @@ -1,11 +1,12 @@ using System.Collections.Immutable; +using System.Linq; namespace EntityFrameworkCore.Projectables.Generator { /// /// Incremental-pipeline-safe representation of a single projectable member. - /// Contains only primitive types and ImmutableArray<string> so that value equality - /// works correctly across incremental generation steps. + /// Contains only primitive types and an equatable wrapper around + /// so that structural value equality works correctly across incremental generation steps. /// sealed internal record ProjectableRegistryEntry( string DeclaringTypeFullName, @@ -14,6 +15,47 @@ sealed internal record ProjectableRegistryEntry( string GeneratedClassFullName, bool IsGenericClass, bool IsGenericMethod, - ImmutableArray ParameterTypeNames + EquatableImmutableArray ParameterTypeNames ); + + /// + /// A structural-equality wrapper around of strings. + /// uses reference equality by default, which breaks + /// Roslyn's incremental-source-generator caching when the same logical array is + /// produced by two different steps. This wrapper provides element-wise equality so + /// that incremental steps are correctly cached and skipped. + /// + internal readonly struct EquatableImmutableArray : System.IEquatable + { + public static readonly EquatableImmutableArray Empty = new(ImmutableArray.Empty); + + public readonly ImmutableArray Array; + + public EquatableImmutableArray(ImmutableArray array) + { + Array = array; + } + + public bool IsDefaultOrEmpty => Array.IsDefaultOrEmpty; + + public bool Equals(EquatableImmutableArray other) => + Array.SequenceEqual(other.Array); + + public override bool Equals(object? obj) => + obj is EquatableImmutableArray other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = 17; + foreach (var s in Array) + hash = hash * 31 + (s?.GetHashCode() ?? 0); + return hash; + } + } + + public static implicit operator ImmutableArray(EquatableImmutableArray e) => e.Array; + public static implicit operator EquatableImmutableArray(ImmutableArray a) => new(a); + } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index e1de46d..0c83141 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -647,7 +647,7 @@ private static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parame } var typeofExprs = parameterTypeNames - .Select(ExpressionSyntax (name) => TypeOfExpression(ParseTypeName(name))) + .Select(name => (ExpressionSyntax)TypeOfExpression(ParseTypeName(name))) .ToArray(); return ArrayCreationExpression( diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 3e8ab67..a4b8cd6 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -9,29 +9,35 @@ namespace EntityFrameworkCore.Projectables.Services { public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver { - // Cache per-assembly registry delegate: Assembly → Func - // After first lookup, subsequent calls do a fast lock-free read. - private static readonly ConcurrentDictionary?> _assemblyRegistries = new(); + // We never store null in the dictionary; assemblies without a registry use a sentinel delegate. + private static readonly Func _nullRegistry = static _ => null!; + private static readonly ConcurrentDictionary> _assemblyRegistries = new(); /// /// Looks up the generated ProjectionRegistry class in an assembly (once, then caches it). /// Returns a delegate that calls TryGet(MemberInfo) on the registry, or null if the registry /// is not present in that assembly (e.g. if the source generator was not run against it). /// - private static Func? GetAssemblyRegistry(Assembly assembly) => - _assemblyRegistries.GetOrAdd(assembly, static asm => + private static Func? GetAssemblyRegistry(Assembly assembly) + { + var registry = _assemblyRegistries.GetOrAdd(assembly, static asm => { var registryType = asm.GetType("EntityFrameworkCore.Projectables.Generated.ProjectionRegistry"); var tryGetMethod = registryType?.GetMethod("TryGet", BindingFlags.Static | BindingFlags.Public); - + if (tryGetMethod is null) { - return null; + // Use sentinel to indicate "no registry for this assembly" + return _nullRegistry; } return (Func)Delegate.CreateDelegate(typeof(Func), tryGetMethod); }); + // Translate sentinel back to null for callers, preserving existing behavior. + return ReferenceEquals(registry, _nullRegistry) ? null : registry; + } + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) { var projectableAttribute = projectableMemberInfo.GetCustomAttribute() diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs index c442005..acfc52a 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs @@ -43,7 +43,7 @@ public TestGeneratorRunResult(GeneratorDriverRunResult inner) /// public ImmutableArray GeneratedTrees => _inner.GeneratedTrees - .Where(t => !t.FilePath.EndsWith("ProjectionRegistry.g.cs")) + .Where(t => !t.FilePath.EndsWith("ProjectionRegistry.g.cs", StringComparison.Ordinal)) .ToImmutableArray(); /// @@ -56,7 +56,7 @@ public TestGeneratorRunResult(GeneratorDriverRunResult inner) /// The generated ProjectionRegistry.g.cs tree, or null if it was not generated. /// public SyntaxTree? RegistryTree => - _inner.GeneratedTrees.FirstOrDefault(t => t.FilePath.EndsWith("ProjectionRegistry.g.cs")); + _inner.GeneratedTrees.FirstOrDefault(t => t.FilePath.EndsWith("ProjectionRegistry.g.cs", StringComparison.Ordinal)); } protected Compilation CreateCompilation([StringSyntax("csharp")] string source) From eccb66aa8476c783c4ecbd264216373d06d48842 Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Sun, 8 Mar 2026 10:02:22 +0100 Subject: [PATCH 08/12] Fix incompatible code after merge --- .../ProjectionExpressionGenerator.cs | 14 +++++++------- .../ProjectionExpressionGeneratorTestsBase.cs | 15 +++++++++------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 37b3de9..bd23824 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -64,15 +64,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); // Build the projection registry: collect all entries and emit a single registry file - var registryEntries = - compilationAndMemberPairs.Select( - static (pair, _) => ExtractRegistryEntry(pair.Item1, pair.Item2)); - - var allEntries = - registryEntries.Collect(); + var registryEntries = compilationAndMemberPairs.Select( + static (pair, _) => { + var ((member, _), compilation) = pair; + + return ExtractRegistryEntry(member, compilation); + }); context.RegisterImplementationSourceOutput( - allEntries, + registryEntries.Collect(), static (spc, entries) => EmitRegistry(entries, spc)); } diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs index 124bae4..2e42c2e 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs @@ -128,7 +128,7 @@ protected TestGeneratorRunResult RunGenerator(Compilation compilation) /// returning both the driver and the run result. The driver can be passed to subsequent /// calls to to test incremental caching behavior. /// - protected (GeneratorDriver Driver, GeneratorDriverRunResult Result) CreateAndRunGenerator(Compilation compilation) + protected (GeneratorDriver Driver, TestGeneratorRunResult Result) CreateAndRunGenerator(Compilation compilation) { _testOutputHelper.WriteLine("Creating generator driver and running on initial compilation..."); @@ -136,7 +136,9 @@ protected TestGeneratorRunResult RunGenerator(Compilation compilation) GeneratorDriver driver = CSharpGeneratorDriver.Create(subject); driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); + LogGeneratorResult(result, outputCompilation); return (driver, result); @@ -146,19 +148,20 @@ protected TestGeneratorRunResult RunGenerator(Compilation compilation) /// Runs the generator using an existing driver (preserving incremental state from previous runs) /// on a new compilation, returning the updated driver and run result. /// - protected (GeneratorDriver Driver, GeneratorDriverRunResult Result) RunGeneratorWithDriver( + protected (GeneratorDriver Driver, TestGeneratorRunResult Result) RunGeneratorWithDriver( GeneratorDriver driver, Compilation compilation) { _testOutputHelper.WriteLine("Running generator with existing driver on updated compilation..."); driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); - LogGeneratorResult(result, outputCompilation); + + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); return (driver, result); } - private void LogGeneratorResult(GeneratorDriverRunResult result, Compilation outputCompilation) + private void LogGeneratorResult(TestGeneratorRunResult result, Compilation outputCompilation) { if (result.Diagnostics.IsEmpty) { From 2c3d3765786004d3f7d68b30fd9b24178cfe03d2 Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Sun, 8 Mar 2026 11:04:34 +0100 Subject: [PATCH 09/12] Optimize and fix tests --- .../GeneratorBenchmarks.cs | 40 ++-- .../ProjectionExpressionGenerator.cs | 198 +++++++++--------- ...ethodOverloads_BothRegistered.verified.txt | 45 ++++ ...pleProjectables_AllRegistered.verified.txt | 45 ++++ ..._ConstBindingFlagsUsedInBuild.verified.txt | 44 ++++ ...lperUsesDeclaringTypeAssembly.verified.txt | 44 ++++ ...eMethod_RegistryContainsEntry.verified.txt | 44 ++++ ...roperty_RegistryContainsEntry.verified.txt | 44 ++++ .../RegistryTests.cs | 85 ++------ 9 files changed, 403 insertions(+), 186 deletions(-) create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs index 6cc2298..7b7ebd1 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs @@ -11,7 +11,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks; [MemoryDiagnoser] public class GeneratorBenchmarks { - [Params(1, 100, 1000)] + [Params(1, 100)] public int ProjectableCount { get; set; } private Compilation _compilation = null!; @@ -84,25 +84,25 @@ public GeneratorDriver RunGenerator() .Create(new ProjectionExpressionGenerator()) .RunGeneratorsAndUpdateCompilation(_compilation, out _, out _); - /// - /// Cold run where the compilation has a trivial edit in a noise file. - /// Shows whether the cold path is sensitive to which files changed. - /// - [Benchmark] - public GeneratorDriver RunGenerator_NoiseChange() - => CSharpGeneratorDriver - .Create(new ProjectionExpressionGenerator()) - .RunGeneratorsAndUpdateCompilation(_noiseModifiedCompilation, out _, out _); - - /// - /// Cold run where the compilation has a trivial edit in a projectable - /// file (comment only — no member body change). - /// - [Benchmark] - public GeneratorDriver RunGenerator_ProjectableChange() - => CSharpGeneratorDriver - .Create(new ProjectionExpressionGenerator()) - .RunGeneratorsAndUpdateCompilation(_projectableModifiedCompilation, out _, out _); + // /// + // /// Cold run where the compilation has a trivial edit in a noise file. + // /// Shows whether the cold path is sensitive to which files changed. + // /// + // [Benchmark] + // public GeneratorDriver RunGenerator_NoiseChange() + // => CSharpGeneratorDriver + // .Create(new ProjectionExpressionGenerator()) + // .RunGeneratorsAndUpdateCompilation(_noiseModifiedCompilation, out _, out _); + // + // /// + // /// Cold run where the compilation has a trivial edit in a projectable + // /// file (comment only — no member body change). + // /// + // [Benchmark] + // public GeneratorDriver RunGenerator_ProjectableChange() + // => CSharpGeneratorDriver + // .Create(new ProjectionExpressionGenerator()) + // .RunGeneratorsAndUpdateCompilation(_projectableModifiedCompilation, out _, out _); // ------------------------------------------------------------------------- // Incremental benchmarks — the pre-warmed driver processes a single edit. diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index bd23824..0354bbd 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -14,7 +14,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator { private const string ProjectablesAttributeName = "EntityFrameworkCore.Projectables.ProjectableAttribute"; - static readonly AttributeSyntax _editorBrowsableAttribute = + private readonly static AttributeSyntax _editorBrowsableAttribute = Attribute( ParseName("global::System.ComponentModel.EditorBrowsable"), AttributeArgumentList( @@ -29,7 +29,10 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator ) ) ); - + + private static MethodDeclarationSyntax? _registerHelperMethod; + private static FieldDeclarationSyntax? _mapField; + private static MethodDeclarationSyntax? _tryGetMethod; public void Initialize(IncrementalGeneratorInitializationContext context) { // Extract only pure stable data from the attribute in the transform. @@ -65,10 +68,18 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Build the projection registry: collect all entries and emit a single registry file var registryEntries = compilationAndMemberPairs.Select( - static (pair, _) => { - var ((member, _), compilation) = pair; + static (source, cancellationToken) => { + var ((member, _), compilation) = source; + + var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); + var memberSymbol = semanticModel.GetDeclaredSymbol(member, cancellationToken); + + if (memberSymbol is null) + { + return null; + } - return ExtractRegistryEntry(member, compilation); + return ExtractRegistryEntry(memberSymbol); }); context.RegisterImplementationSourceOutput( @@ -274,27 +285,8 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip /// Returns null when the member does not have [Projectable], is an extension member, /// or cannot be represented in the registry (e.g. a generic class member or generic method). /// - private static ProjectableRegistryEntry? ExtractRegistryEntry(MemberDeclarationSyntax member, Compilation compilation) + private static ProjectableRegistryEntry? ExtractRegistryEntry(ISymbol memberSymbol) { - var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); - var memberSymbol = semanticModel.GetDeclaredSymbol(member); - - if (memberSymbol is null) - { - return null; - } - - // Verify [Projectable] attribute - var projectableAttributeTypeSymbol = compilation.GetTypeByMetadataName("EntityFrameworkCore.Projectables.ProjectableAttribute"); - var projectableAttribute = memberSymbol.GetAttributes() - .FirstOrDefault(x => x.AttributeClass?.Name == "ProjectableAttribute"); - - if (projectableAttribute is null || - !SymbolEqualityComparer.Default.Equals(projectableAttribute.AttributeClass, projectableAttributeTypeSymbol)) - { - return null; - } - // Skip C# 14 extension type members — they require special handling (fall back to reflection) if (memberSymbol.ContainingType is { IsExtension: true }) { @@ -308,13 +300,11 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip string memberKind; string memberLookupName; var parameterTypeNames = ImmutableArray.Empty; - var methodTypeParamCount = 0; var isGenericMethod = false; if (memberSymbol is IMethodSymbol methodSymbol) { isGenericMethod = methodSymbol.TypeParameters.Length > 0; - methodTypeParamCount = methodSymbol.TypeParameters.Length; if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor) { @@ -437,77 +427,6 @@ private static void EmitRegistry(ImmutableArray entri Token(SyntaxKind.StaticKeyword))) .AddAttributeLists(AttributeList().AddAttributes(_editorBrowsableAttribute)) .AddMembers( - // private static readonly Dictionary _map = Build(); - FieldDeclaration( - VariableDeclaration(ParseTypeName("Dictionary")) - .AddVariables( - VariableDeclarator("_map") - .WithInitializer(EqualsValueClause( - InvocationExpression(IdentifierName("Build")) - .WithArgumentList(ArgumentList()))))) - .WithModifiers(TokenList( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.StaticKeyword), - Token(SyntaxKind.ReadOnlyKeyword))), - - // public static LambdaExpression TryGet(MemberInfo member) - MethodDeclaration(ParseTypeName("LambdaExpression"), "TryGet") - .WithModifiers(TokenList( - Token(SyntaxKind.PublicKeyword), - Token(SyntaxKind.StaticKeyword))) - .AddParameterListParameters( - Parameter(Identifier("member")) - .WithType(ParseTypeName("MemberInfo"))) - .WithBody(Block( - LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("handle") - .WithInitializer(EqualsValueClause( - InvocationExpression(IdentifierName("GetHandle")) - .AddArgumentListArguments( - Argument(IdentifierName("member"))))))), - ReturnStatement( - ParseExpression( - "handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null")))), - - // private static nint? GetHandle(MemberInfo member) => member switch { ... }; - MethodDeclaration(ParseTypeName("nint?"), "GetHandle") - .WithModifiers(TokenList( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.StaticKeyword))) - .AddParameterListParameters( - Parameter(Identifier("member")) - .WithType(ParseTypeName("MemberInfo"))) - .WithExpressionBody(ArrowExpressionClause( - SwitchExpression(IdentifierName("member")) - .WithArms(SeparatedList( - new SyntaxNodeOrToken[] - { - SwitchExpressionArm( - DeclarationPattern( - ParseTypeName("MethodInfo"), - SingleVariableDesignation(Identifier("m"))), - ParseExpression("m.MethodHandle.Value")), - Token(SyntaxKind.CommaToken), - SwitchExpressionArm( - DeclarationPattern( - ParseTypeName("PropertyInfo"), - SingleVariableDesignation(Identifier("p"))), - ParseExpression("p.GetMethod?.MethodHandle.Value")), - Token(SyntaxKind.CommaToken), - SwitchExpressionArm( - DeclarationPattern( - ParseTypeName("ConstructorInfo"), - SingleVariableDesignation(Identifier("c"))), - ParseExpression("c.MethodHandle.Value")), - Token(SyntaxKind.CommaToken), - SwitchExpressionArm( - DiscardPattern(), - LiteralExpression(SyntaxKind.NullLiteralExpression)) - })))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)), - // private static Dictionary Build() { ... } MethodDeclaration(ParseTypeName("Dictionary"), "Build") .WithModifiers(TokenList( @@ -515,7 +434,9 @@ private static void EmitRegistry(ImmutableArray entri Token(SyntaxKind.StaticKeyword))) .WithBody(Block(buildStatements)), - // private static void Register(Dictionary map, MethodBase m, string exprClass) + // Cached members — built once and reused across incremental runs + BuildMapField(), + BuildTryGetMethod(), BuildRegisterHelperMethod()); var compilationUnit = CompilationUnit() @@ -604,11 +525,88 @@ private static void EmitRegistry(ImmutableArray entri Literal(entry.GeneratedClassFullName))))); } + + /// + /// Builds (and caches) the _map field declaration: + /// private static readonly Dictionary<nint, LambdaExpression> _map = Build(); + /// + private static FieldDeclarationSyntax BuildMapField() => _mapField ??= + FieldDeclaration( + VariableDeclaration(ParseTypeName("Dictionary")) + .AddVariables( + VariableDeclarator("_map") + .WithInitializer(EqualsValueClause( + InvocationExpression(IdentifierName("Build")) + .WithArgumentList(ArgumentList()))))) + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.ReadOnlyKeyword))); + + /// + /// Builds (and caches) the TryGet public static method declaration. + /// The GetHandle logic is inlined as a switch expression on member. + /// + private static MethodDeclarationSyntax BuildTryGetMethod() => _tryGetMethod ??= + // public static LambdaExpression TryGet(MemberInfo member) + // { + // var handle = member switch + // { + // MethodInfo m => (nint?)m.MethodHandle.Value, + // PropertyInfo p => p.GetMethod?.MethodHandle.Value, + // ConstructorInfo c => (nint?)c.MethodHandle.Value, + // _ => null + // }; + // return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + // } + MethodDeclaration(ParseTypeName("LambdaExpression"), "TryGet") + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("member")) + .WithType(ParseTypeName("MemberInfo"))) + .WithBody(Block( + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("handle") + .WithInitializer(EqualsValueClause( + SwitchExpression(IdentifierName("member")) + .WithArms(SeparatedList( + new SyntaxNodeOrToken[] + { + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("MethodInfo"), + SingleVariableDesignation(Identifier("m"))), + ParseExpression("(nint?)m.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("PropertyInfo"), + SingleVariableDesignation(Identifier("p"))), + ParseExpression("p.GetMethod?.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("ConstructorInfo"), + SingleVariableDesignation(Identifier("c"))), + ParseExpression("(nint?)c.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DiscardPattern(), + LiteralExpression(SyntaxKind.NullLiteralExpression)) + })))))), + ReturnStatement( + ParseExpression( + "handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null")))); + /// /// Builds the Register private static helper method that all per-entry calls delegate to. /// It handles the null checks and the common reflection lookup pattern once, centrally. /// - private static MethodDeclarationSyntax BuildRegisterHelperMethod() => + private static MethodDeclarationSyntax BuildRegisterHelperMethod() => _registerHelperMethod ??= // private static void Register(Dictionary map, MethodBase m, string exprClass) // { // if (m is null) return; diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt new file mode 100644 index 0000000..9d3de1b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt @@ -0,0 +1,45 @@ +// +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetMethod("Add", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_Add_P0_int"); + Register(map, typeof(global::Foo.C).GetMethod("Add", allFlags, null, new global::System.Type[] { typeof(long) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_Add_P0_long"); + return map; + } + + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint? )m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint? )c.MethodHandle.Value, + _ => null + }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) + return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt new file mode 100644 index 0000000..4618ba0 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt @@ -0,0 +1,45 @@ +// +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + Register(map, typeof(global::Foo.C).GetMethod("AddDelta", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_AddDelta_P0_int"); + return map; + } + + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint? )m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint? )c.MethodHandle.Value, + _ => null + }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) + return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt new file mode 100644 index 0000000..9b74a38 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt @@ -0,0 +1,44 @@ +// +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + return map; + } + + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint? )m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint? )c.MethodHandle.Value, + _ => null + }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) + return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt new file mode 100644 index 0000000..9b74a38 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt @@ -0,0 +1,44 @@ +// +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + return map; + } + + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint? )m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint? )c.MethodHandle.Value, + _ => null + }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) + return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt new file mode 100644 index 0000000..51f6d39 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt @@ -0,0 +1,44 @@ +// +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetMethod("AddDelta", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_AddDelta_P0_int"); + return map; + } + + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint? )m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint? )c.MethodHandle.Value, + _ => null + }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) + return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt new file mode 100644 index 0000000..9b74a38 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt @@ -0,0 +1,44 @@ +// +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + return map; + } + + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint? )m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint? )c.MethodHandle.Value, + _ => null + }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) + return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs index 27fe1f8..f8b8782 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs @@ -1,23 +1,26 @@ +using System.Threading.Tasks; +using VerifyXunit; using Xunit; using Xunit.Abstractions; namespace EntityFrameworkCore.Projectables.Generator.Tests; +[UsesVerify] public class RegistryTests : ProjectionExpressionGeneratorTestsBase { public RegistryTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { } [Fact] - public void NoProjectables_NoRegistry() + public Task NoProjectables_NoRegistry() { var compilation = CreateCompilation(@"class C { }"); var result = RunGenerator(compilation); - Assert.Null(result.RegistryTree); + return Verifier.Verify(result.RegistryTree?.GetText().ToString()); } [Fact] - public void SingleProperty_RegistryContainsEntry() + public Task SingleProperty_RegistryContainsEntry() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -30,19 +33,11 @@ class C { }"); var result = RunGenerator(compilation); - Assert.NotNull(result.RegistryTree); - var src = result.RegistryTree!.GetText().ToString(); - - Assert.Contains("ProjectionRegistry", src); - // Uses the compact Register helper — not a repeated block - Assert.Contains("private static void Register(", src); - Assert.Contains("Register(map,", src); - Assert.Contains("GetProperty(\"IdPlus1\"", src); - Assert.Contains("Foo_C_IdPlus1", src); + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); } [Fact] - public void SingleMethod_RegistryContainsEntry() + public Task SingleMethod_RegistryContainsEntry() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -55,16 +50,11 @@ class C { }"); var result = RunGenerator(compilation); - Assert.NotNull(result.RegistryTree); - var src = result.RegistryTree!.GetText().ToString(); - - Assert.Contains("GetMethod(\"AddDelta\"", src); - Assert.Contains("typeof(int)", src); - Assert.Contains("Foo_C_AddDelta_P0_int", src); + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); } [Fact] - public void MultipleProjectables_AllRegistered() + public Task MultipleProjectables_AllRegistered() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -79,19 +69,11 @@ class C { }"); var result = RunGenerator(compilation); - Assert.NotNull(result.RegistryTree); - var src = result.RegistryTree!.GetText().ToString(); - - // Two separate Register(map, ...) calls — one per projectable - Assert.Contains("GetProperty(\"IdPlus1\"", src); - Assert.Contains("GetMethod(\"AddDelta\"", src); - // Each entry is a single line, not a repeated multi-line block - var registerCallCount = CountOccurrences(src, "Register(map,"); - Assert.Equal(2, registerCallCount); + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); } [Fact] - public void GenericClass_NotIncludedInRegistry() + public Task GenericClass_NotIncludedInRegistry() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -104,12 +86,11 @@ class C { }"); var result = RunGenerator(compilation); - // Generic class members fall back to reflection — no registry emitted - Assert.Null(result.RegistryTree); + return Verifier.Verify(result.RegistryTree?.GetText().ToString()); } [Fact] - public void Registry_ConstBindingFlagsUsedInBuild() + public Task Registry_ConstBindingFlagsUsedInBuild() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -122,16 +103,11 @@ class C { }"); var result = RunGenerator(compilation); - Assert.NotNull(result.RegistryTree); - var src = result.RegistryTree!.GetText().ToString(); - - // Build() uses a single const BindingFlags instead of repeating the flags per entry - Assert.Contains("const BindingFlags allFlags", src); - Assert.Contains("allFlags", src); + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); } [Fact] - public void Registry_RegisterHelperUsesDeclaringTypeAssembly() + public Task Registry_RegisterHelperUsesDeclaringTypeAssembly() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -144,15 +120,11 @@ class C { }"); var result = RunGenerator(compilation); - Assert.NotNull(result.RegistryTree); - var src = result.RegistryTree!.GetText().ToString(); - - // Register helper derives the assembly from m.DeclaringType (no typeof repeated per entry) - Assert.Contains("m.DeclaringType?.Assembly.GetType(exprClass)", src); + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); } [Fact] - public void MethodOverloads_BothRegistered() + public Task MethodOverloads_BothRegistered() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -167,25 +139,6 @@ class C { }"); var result = RunGenerator(compilation); - Assert.NotNull(result.RegistryTree); - var src = result.RegistryTree!.GetText().ToString(); - - // Both overloads registered by parameter-type disambiguation - Assert.Contains("typeof(int)", src); - Assert.Contains("typeof(long)", src); - var registerCallCount = CountOccurrences(src, "Register(map,"); - Assert.Equal(2, registerCallCount); - } - - private static int CountOccurrences(string text, string pattern) - { - int count = 0; - int index = 0; - while ((index = text.IndexOf(pattern, index, System.StringComparison.Ordinal)) >= 0) - { - count++; - index += pattern.Length; - } - return count; + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); } } From 11291116f08bb3783defede23c98020b59117b8a Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Sun, 8 Mar 2026 11:46:45 +0100 Subject: [PATCH 10/12] Remove unused code and small optimizations --- .../ProjectableRegistryEntry.cs | 26 +++++------- .../ProjectableRegistryMemberType.cs | 8 ++++ .../ProjectionExpressionGenerator.cs | 40 +++++++++---------- 3 files changed, 37 insertions(+), 37 deletions(-) create mode 100644 src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs index 37a45e0..c7adfad 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs @@ -10,11 +10,9 @@ namespace EntityFrameworkCore.Projectables.Generator /// sealed internal record ProjectableRegistryEntry( string DeclaringTypeFullName, - string MemberKind, + ProjectableRegistryMemberType MemberKind, string MemberLookupName, string GeneratedClassFullName, - bool IsGenericClass, - bool IsGenericMethod, EquatableImmutableArray ParameterTypeNames ); @@ -25,21 +23,12 @@ EquatableImmutableArray ParameterTypeNames /// produced by two different steps. This wrapper provides element-wise equality so /// that incremental steps are correctly cached and skipped. /// - internal readonly struct EquatableImmutableArray : System.IEquatable + readonly internal struct EquatableImmutableArray(ImmutableArray array) : IEquatable { - public static readonly EquatableImmutableArray Empty = new(ImmutableArray.Empty); - - public readonly ImmutableArray Array; - - public EquatableImmutableArray(ImmutableArray array) - { - Array = array; - } - - public bool IsDefaultOrEmpty => Array.IsDefaultOrEmpty; + private readonly ImmutableArray _array = array; public bool Equals(EquatableImmutableArray other) => - Array.SequenceEqual(other.Array); + _array.SequenceEqual(other._array); public override bool Equals(object? obj) => obj is EquatableImmutableArray other && Equals(other); @@ -49,13 +38,16 @@ public override int GetHashCode() unchecked { var hash = 17; - foreach (var s in Array) + foreach (var s in _array) + { hash = hash * 31 + (s?.GetHashCode() ?? 0); + } + return hash; } } - public static implicit operator ImmutableArray(EquatableImmutableArray e) => e.Array; + public static implicit operator ImmutableArray(EquatableImmutableArray e) => e._array; public static implicit operator EquatableImmutableArray(ImmutableArray a) => new(a); } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs new file mode 100644 index 0000000..4697d72 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs @@ -0,0 +1,8 @@ +namespace EntityFrameworkCore.Projectables.Generator; + +public enum ProjectableRegistryMemberType : byte +{ + Property, + Method, + Constructor, +} \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 0354bbd..30b8a6b 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -287,33 +287,41 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip /// private static ProjectableRegistryEntry? ExtractRegistryEntry(ISymbol memberSymbol) { + var containingType = memberSymbol.ContainingType; + // Skip C# 14 extension type members — they require special handling (fall back to reflection) - if (memberSymbol.ContainingType is { IsExtension: true }) + if (containingType is { IsExtension: true }) { return null; } - var containingType = memberSymbol.ContainingType; - var isGenericClass = containingType.TypeParameters.Length > 0; + // Early exit for generic classes: BuildRegistryEntryStatement returns null for them anyway. + if (containingType.TypeParameters.Length > 0) + { + return null; + } // Determine member kind and lookup name - string memberKind; + ProjectableRegistryMemberType memberKind; string memberLookupName; var parameterTypeNames = ImmutableArray.Empty; - var isGenericMethod = false; if (memberSymbol is IMethodSymbol methodSymbol) { - isGenericMethod = methodSymbol.TypeParameters.Length > 0; + // Early exit for generic methods + if (methodSymbol.TypeParameters.Length > 0) + { + return null; + } if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor) { - memberKind = "Constructor"; + memberKind = ProjectableRegistryMemberType.Constructor; memberLookupName = "_ctor"; } else { - memberKind = "Method"; + memberKind = ProjectableRegistryMemberType.Method; memberLookupName = memberSymbol.Name; } @@ -323,7 +331,7 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip } else { - memberKind = "Property"; + memberKind = ProjectableRegistryMemberType.Property; memberLookupName = memberSymbol.Name; } @@ -349,8 +357,6 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip MemberKind: memberKind, MemberLookupName: memberLookupName, GeneratedClassFullName: generatedClassFullName, - IsGenericClass: isGenericClass, - IsGenericMethod: isGenericMethod, ParameterTypeNames: parameterTypeNames); } @@ -459,20 +465,14 @@ private static void EmitRegistry(ImmutableArray entri /// /// Builds a single compact Register(map, typeof(T).GetXxx(...), "ClassName") /// statement for one projectable entry in Build(). - /// Returns for generic class/method entries (they fall back to reflection). /// private static ExpressionStatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry) { - if (entry.IsGenericClass || entry.IsGenericMethod) - { - return null; - } - // typeof(DeclaringType).GetProperty/Method/Constructor(name, allFlags, ...) ExpressionSyntax? memberCallExpr = entry.MemberKind switch { // typeof(T).GetProperty("Name", allFlags)?.GetMethod - "Property" => ConditionalAccessExpression( + ProjectableRegistryMemberType.Property => ConditionalAccessExpression( InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), @@ -484,7 +484,7 @@ private static void EmitRegistry(ImmutableArray entri MemberBindingExpression(IdentifierName("GetMethod"))), // typeof(T).GetMethod("Name", allFlags, null, new Type[] {...}, null) - "Method" => InvocationExpression( + ProjectableRegistryMemberType.Method => InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), IdentifierName("GetMethod"))) @@ -497,7 +497,7 @@ private static void EmitRegistry(ImmutableArray entri Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), // typeof(T).GetConstructor(allFlags, null, new Type[] {...}, null) - "Constructor" => InvocationExpression( + ProjectableRegistryMemberType.Constructor => InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), IdentifierName("GetConstructor"))) From 0ca5faa860c3a8ae0b3fc3a14df37ec3fae2f160 Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Sun, 8 Mar 2026 14:06:25 +0100 Subject: [PATCH 11/12] Move registry generation to its own class Move code generation to a string builder (much faster) --- .../GeneratorBenchmarks.cs | 40 +-- .../ProjectionExpressionGenerator.cs | 337 +----------------- .../ProjectionRegistryEmitter.cs | 207 +++++++++++ ...ethodOverloads_BothRegistered.verified.txt | 22 +- ...pleProjectables_AllRegistered.verified.txt | 22 +- ..._ConstBindingFlagsUsedInBuild.verified.txt | 22 +- ...lperUsesDeclaringTypeAssembly.verified.txt | 22 +- ...eMethod_RegistryContainsEntry.verified.txt | 22 +- ...roperty_RegistryContainsEntry.verified.txt | 22 +- .../RegistryTests.cs | 13 +- 10 files changed, 326 insertions(+), 403 deletions(-) create mode 100644 src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs index 7b7ebd1..6cc2298 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/GeneratorBenchmarks.cs @@ -11,7 +11,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks; [MemoryDiagnoser] public class GeneratorBenchmarks { - [Params(1, 100)] + [Params(1, 100, 1000)] public int ProjectableCount { get; set; } private Compilation _compilation = null!; @@ -84,25 +84,25 @@ public GeneratorDriver RunGenerator() .Create(new ProjectionExpressionGenerator()) .RunGeneratorsAndUpdateCompilation(_compilation, out _, out _); - // /// - // /// Cold run where the compilation has a trivial edit in a noise file. - // /// Shows whether the cold path is sensitive to which files changed. - // /// - // [Benchmark] - // public GeneratorDriver RunGenerator_NoiseChange() - // => CSharpGeneratorDriver - // .Create(new ProjectionExpressionGenerator()) - // .RunGeneratorsAndUpdateCompilation(_noiseModifiedCompilation, out _, out _); - // - // /// - // /// Cold run where the compilation has a trivial edit in a projectable - // /// file (comment only — no member body change). - // /// - // [Benchmark] - // public GeneratorDriver RunGenerator_ProjectableChange() - // => CSharpGeneratorDriver - // .Create(new ProjectionExpressionGenerator()) - // .RunGeneratorsAndUpdateCompilation(_projectableModifiedCompilation, out _, out _); + /// + /// Cold run where the compilation has a trivial edit in a noise file. + /// Shows whether the cold path is sensitive to which files changed. + /// + [Benchmark] + public GeneratorDriver RunGenerator_NoiseChange() + => CSharpGeneratorDriver + .Create(new ProjectionExpressionGenerator()) + .RunGeneratorsAndUpdateCompilation(_noiseModifiedCompilation, out _, out _); + + /// + /// Cold run where the compilation has a trivial edit in a projectable + /// file (comment only — no member body change). + /// + [Benchmark] + public GeneratorDriver RunGenerator_ProjectableChange() + => CSharpGeneratorDriver + .Create(new ProjectionExpressionGenerator()) + .RunGeneratorsAndUpdateCompilation(_projectableModifiedCompilation, out _, out _); // ------------------------------------------------------------------------- // Incremental benchmarks — the pre-warmed driver processes a single edit. diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 30b8a6b..328615f 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -14,7 +14,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator { private const string ProjectablesAttributeName = "EntityFrameworkCore.Projectables.ProjectableAttribute"; - private readonly static AttributeSyntax _editorBrowsableAttribute = + private readonly static AttributeSyntax _editorBrowsableAttribute = Attribute( ParseName("global::System.ComponentModel.EditorBrowsable"), AttributeArgumentList( @@ -29,10 +29,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator ) ) ); - - private static MethodDeclarationSyntax? _registerHelperMethod; - private static FieldDeclarationSyntax? _mapField; - private static MethodDeclarationSyntax? _tryGetMethod; + public void Initialize(IncrementalGeneratorInitializationContext context) { // Extract only pure stable data from the attribute in the transform. @@ -50,14 +47,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var compilationAndMemberPairs = memberDeclarations .Combine(context.CompilationProvider) .WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer()); - + context.RegisterSourceOutput(compilationAndMemberPairs, static (spc, source) => { var ((member, attribute), compilation) = source; var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); var memberSymbol = semanticModel.GetDeclaredSymbol(member); - + if (memberSymbol is null) { return; @@ -65,12 +62,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) Execute(member, semanticModel, memberSymbol, attribute, compilation, spc); }); - + // Build the projection registry: collect all entries and emit a single registry file var registryEntries = compilationAndMemberPairs.Select( static (source, cancellationToken) => { var ((member, _), compilation) = source; - + var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); var memberSymbol = semanticModel.GetDeclaredSymbol(member, cancellationToken); @@ -78,13 +75,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { return null; } - + return ExtractRegistryEntry(memberSymbol); }); + // Delegate registry file emission to the dedicated ProjectionRegistryEmitter, + // which uses a string-based CodeWriter instead of SyntaxFactory. context.RegisterImplementationSourceOutput( registryEntries.Collect(), - static (spc, entries) => EmitRegistry(entries, spc)); + static (spc, entries) => ProjectionRegistryEmitter.Emit(entries, spc)); } private static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) @@ -221,7 +220,7 @@ private static void Execute( ) ) ) - ) + ) ); #nullable disable @@ -257,7 +256,6 @@ private static void Execute( ) ); - context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable) @@ -288,14 +286,14 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip private static ProjectableRegistryEntry? ExtractRegistryEntry(ISymbol memberSymbol) { var containingType = memberSymbol.ContainingType; - + // Skip C# 14 extension type members — they require special handling (fall back to reflection) if (containingType is { IsExtension: true }) { return null; } - // Early exit for generic classes: BuildRegistryEntryStatement returns null for them anyway. + // Skip generic classes: the registry only supports closed constructed types. if (containingType.TypeParameters.Length > 0) { return null; @@ -308,7 +306,7 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip if (memberSymbol is IMethodSymbol methodSymbol) { - // Early exit for generic methods + // Skip generic methods for the same reason as generic classes if (methodSymbol.TypeParameters.Length > 0) { return null; @@ -371,312 +369,5 @@ private static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol ty } yield return typeSymbol.Name; } - - /// - /// Emits the ProjectionRegistry.g.cs file that aggregates all projectable members - /// into a single static dictionary keyed by . - /// Uses SyntaxFactory for the class/method/field structure, consistent with . - /// The generated Build() method uses a shared Register helper to avoid repeating - /// the lookup boilerplate for every entry. - /// - private static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) - { - // Build the per-entry Register(...) statements first so we can bail out early - // if every entry is generic (they all fall back to reflection, no registry needed). - var entryStatements = entries - .Where(e => e is not null) - .Select(e => BuildRegistryEntryStatement(e!)) - .Where(s => s is not null) - .Select(s => s!) - .ToList(); - - if (entryStatements.Count == 0) - { - return; - } - - // Build() body: - // const BindingFlags allFlags = ...; - // var map = new Dictionary(); - // Register(map, typeof(T).GetXxx(...), "ClassName"); ← one line per entry - // return map; - var buildStatements = new List - { - // const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | ...; - LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("BindingFlags")) - .AddVariables( - VariableDeclarator("allFlags") - .WithInitializer(EqualsValueClause( - ParseExpression( - "BindingFlags.Public | BindingFlags.NonPublic | " + - "BindingFlags.Instance | BindingFlags.Static"))))) - .WithModifiers(TokenList(Token(SyntaxKind.ConstKeyword))), - - // var map = new Dictionary(); - LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("map") - .WithInitializer(EqualsValueClause( - ObjectCreationExpression( - ParseTypeName("Dictionary")) - .WithArgumentList(ArgumentList()))))), - }; - - buildStatements.AddRange(entryStatements); - buildStatements.Add(ReturnStatement(IdentifierName("map"))); - - var classSyntax = ClassDeclaration("ProjectionRegistry") - .WithModifiers(TokenList( - Token(SyntaxKind.InternalKeyword), - Token(SyntaxKind.StaticKeyword))) - .AddAttributeLists(AttributeList().AddAttributes(_editorBrowsableAttribute)) - .AddMembers( - // private static Dictionary Build() { ... } - MethodDeclaration(ParseTypeName("Dictionary"), "Build") - .WithModifiers(TokenList( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.StaticKeyword))) - .WithBody(Block(buildStatements)), - - // Cached members — built once and reused across incremental runs - BuildMapField(), - BuildTryGetMethod(), - BuildRegisterHelperMethod()); - - var compilationUnit = CompilationUnit() - .AddUsings( - UsingDirective(ParseName("System")), - UsingDirective(ParseName("System.Collections.Generic")), - UsingDirective(ParseName("System.Linq.Expressions")), - UsingDirective(ParseName("System.Reflection"))) - .AddMembers( - NamespaceDeclaration(ParseName("EntityFrameworkCore.Projectables.Generated")) - .AddMembers(classSyntax)) - .WithLeadingTrivia(TriviaList( - Comment("// "), - Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))); - - context.AddSource("ProjectionRegistry.g.cs", - SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); - } - - /// - /// Builds a single compact Register(map, typeof(T).GetXxx(...), "ClassName") - /// statement for one projectable entry in Build(). - /// - private static ExpressionStatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry) - { - // typeof(DeclaringType).GetProperty/Method/Constructor(name, allFlags, ...) - ExpressionSyntax? memberCallExpr = entry.MemberKind switch - { - // typeof(T).GetProperty("Name", allFlags)?.GetMethod - ProjectableRegistryMemberType.Property => ConditionalAccessExpression( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), - IdentifierName("GetProperty"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal(entry.MemberLookupName))), - Argument(IdentifierName("allFlags"))), - MemberBindingExpression(IdentifierName("GetMethod"))), - - // typeof(T).GetMethod("Name", allFlags, null, new Type[] {...}, null) - ProjectableRegistryMemberType.Method => InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), - IdentifierName("GetMethod"))) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal(entry.MemberLookupName))), - Argument(IdentifierName("allFlags")), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), - Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), - - // typeof(T).GetConstructor(allFlags, null, new Type[] {...}, null) - ProjectableRegistryMemberType.Constructor => InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), - IdentifierName("GetConstructor"))) - .AddArgumentListArguments( - Argument(IdentifierName("allFlags")), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), - Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), - Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), - - _ => null - }; - - if (memberCallExpr is null) - { - return null; - } - - // Register(map, , ""); - return ExpressionStatement( - InvocationExpression(IdentifierName("Register")) - .AddArgumentListArguments( - Argument(IdentifierName("map")), - Argument(memberCallExpr), - Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, - Literal(entry.GeneratedClassFullName))))); - } - - - /// - /// Builds (and caches) the _map field declaration: - /// private static readonly Dictionary<nint, LambdaExpression> _map = Build(); - /// - private static FieldDeclarationSyntax BuildMapField() => _mapField ??= - FieldDeclaration( - VariableDeclaration(ParseTypeName("Dictionary")) - .AddVariables( - VariableDeclarator("_map") - .WithInitializer(EqualsValueClause( - InvocationExpression(IdentifierName("Build")) - .WithArgumentList(ArgumentList()))))) - .WithModifiers(TokenList( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.StaticKeyword), - Token(SyntaxKind.ReadOnlyKeyword))); - - /// - /// Builds (and caches) the TryGet public static method declaration. - /// The GetHandle logic is inlined as a switch expression on member. - /// - private static MethodDeclarationSyntax BuildTryGetMethod() => _tryGetMethod ??= - // public static LambdaExpression TryGet(MemberInfo member) - // { - // var handle = member switch - // { - // MethodInfo m => (nint?)m.MethodHandle.Value, - // PropertyInfo p => p.GetMethod?.MethodHandle.Value, - // ConstructorInfo c => (nint?)c.MethodHandle.Value, - // _ => null - // }; - // return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; - // } - MethodDeclaration(ParseTypeName("LambdaExpression"), "TryGet") - .WithModifiers(TokenList( - Token(SyntaxKind.PublicKeyword), - Token(SyntaxKind.StaticKeyword))) - .AddParameterListParameters( - Parameter(Identifier("member")) - .WithType(ParseTypeName("MemberInfo"))) - .WithBody(Block( - LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("handle") - .WithInitializer(EqualsValueClause( - SwitchExpression(IdentifierName("member")) - .WithArms(SeparatedList( - new SyntaxNodeOrToken[] - { - SwitchExpressionArm( - DeclarationPattern( - ParseTypeName("MethodInfo"), - SingleVariableDesignation(Identifier("m"))), - ParseExpression("(nint?)m.MethodHandle.Value")), - Token(SyntaxKind.CommaToken), - SwitchExpressionArm( - DeclarationPattern( - ParseTypeName("PropertyInfo"), - SingleVariableDesignation(Identifier("p"))), - ParseExpression("p.GetMethod?.MethodHandle.Value")), - Token(SyntaxKind.CommaToken), - SwitchExpressionArm( - DeclarationPattern( - ParseTypeName("ConstructorInfo"), - SingleVariableDesignation(Identifier("c"))), - ParseExpression("(nint?)c.MethodHandle.Value")), - Token(SyntaxKind.CommaToken), - SwitchExpressionArm( - DiscardPattern(), - LiteralExpression(SyntaxKind.NullLiteralExpression)) - })))))), - ReturnStatement( - ParseExpression( - "handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null")))); - - /// - /// Builds the Register private static helper method that all per-entry calls delegate to. - /// It handles the null checks and the common reflection lookup pattern once, centrally. - /// - private static MethodDeclarationSyntax BuildRegisterHelperMethod() => _registerHelperMethod ??= - // private static void Register(Dictionary map, MethodBase m, string exprClass) - // { - // if (m is null) return; - // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); - // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - // if (exprMethod is not null) - // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; - // } - MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Register") - .WithModifiers(TokenList( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.StaticKeyword))) - .AddParameterListParameters( - Parameter(Identifier("map")) - .WithType(ParseTypeName("Dictionary")), - Parameter(Identifier("m")) - .WithType(ParseTypeName("MethodBase")), - Parameter(Identifier("exprClass")) - .WithType(PredefinedType(Token(SyntaxKind.StringKeyword)))) - .WithBody(Block( - // if (m is null) return; - IfStatement( - IsPatternExpression( - IdentifierName("m"), - ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), - ReturnStatement()), - // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); - LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("exprType") - .WithInitializer(EqualsValueClause( - ParseExpression("m.DeclaringType?.Assembly.GetType(exprClass)"))))), - // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables( - VariableDeclarator("exprMethod") - .WithInitializer(EqualsValueClause( - ParseExpression( - @"exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic)"))))), - // if (exprMethod is not null) - // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; - IfStatement( - ParseExpression("exprMethod is not null"), - ExpressionStatement( - ParseExpression( - "map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!"))))); - - /// - /// Builds the typeof(...)-array expression used for reflection method/constructor lookup. - /// Returns global::System.Type.EmptyTypes when there are no parameters. - /// - private static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parameterTypeNames) - { - if (parameterTypeNames.IsEmpty) - { - return ParseExpression("global::System.Type.EmptyTypes"); - } - - var typeofExprs = parameterTypeNames - .Select(name => (ExpressionSyntax)TypeOfExpression(ParseTypeName(name))) - .ToArray(); - - return ArrayCreationExpression( - ArrayType(ParseTypeName("global::System.Type")) - .AddRankSpecifiers(ArrayRankSpecifier())) - .WithInitializer( - InitializerExpression(SyntaxKind.ArrayInitializerExpression, - SeparatedList(typeofExprs))); - } } } \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs new file mode 100644 index 0000000..16fdb01 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs @@ -0,0 +1,207 @@ +using System.CodeDom.Compiler; +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace EntityFrameworkCore.Projectables.Generator +{ + /// + /// Emits the ProjectionRegistry.g.cs source file that aggregates all projectable + /// members into a single static dictionary keyed by . + /// + /// + /// Code is generated using rather than Roslyn's + /// SyntaxFactory: the output has a fixed template shape that never changes, so + /// string-based generation is simpler, more readable, and easier to maintain. + /// + /// + internal static class ProjectionRegistryEmitter + { + /// + /// Builds and adds ProjectionRegistry.g.cs to the compilation output. + /// The file is only emitted when at least one non-generic, non-extension projectable + /// member is present (i.e. when yields at least one representable entry). + /// + public static void Emit(ImmutableArray entries, SourceProductionContext context) + { + var validEntries = entries + .Where(e => e is not null) + .Select(e => e!) + .ToList(); + + if (validEntries.Count == 0) + { + return; + } + + // IndentedTextWriter wraps a TextWriter; keep a reference to the StringWriter + // so we can read the result back with .ToString() after all writes are done. + var sw = new StringWriter(); + var writer = new IndentedTextWriter(sw, " "); + + writer.WriteLine("// "); + writer.WriteLine("#nullable disable"); + writer.WriteLine(); + writer.WriteLine("using System;"); + writer.WriteLine("using System.Collections.Generic;"); + writer.WriteLine("using System.Linq.Expressions;"); + writer.WriteLine("using System.Reflection;"); + writer.WriteLine(); + writer.WriteLine("namespace EntityFrameworkCore.Projectables.Generated"); + writer.WriteLine("{"); + writer.Indent++; + + writer.WriteLine("[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]"); + writer.WriteLine("internal static class ProjectionRegistry"); + writer.WriteLine("{"); + writer.Indent++; + + EmitBuildMethod(writer, validEntries); + writer.WriteLine(); + EmitMapField(writer); + writer.WriteLine(); + EmitTryGetMethod(writer); + writer.WriteLine(); + EmitRegisterHelper(writer); + + writer.Indent--; + writer.WriteLine("}"); + writer.Indent--; + writer.WriteLine("}"); + + context.AddSource("ProjectionRegistry.g.cs", + SourceText.From(sw.ToString(), Encoding.UTF8)); + } + + /// + /// Emits the private Build() method that populates the runtime registry. + /// Each projectable member is registered via the shared Register(...) helper to keep + /// the method body compact — null-safety and the reflection lookup are handled once, centrally. + /// + private static void EmitBuildMethod(IndentedTextWriter writer, List entries) + { + writer.WriteLine("private static Dictionary Build()"); + writer.WriteLine("{"); + writer.Indent++; + + writer.WriteLine("const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static;"); + writer.WriteLine("var map = new Dictionary();"); + writer.WriteLine(); + + foreach (var entry in entries) + { + WriteRegistryEntryStatement(writer, entry); + } + + writer.WriteLine(); + writer.WriteLine("return map;"); + writer.Indent--; + writer.WriteLine("}"); + } + + /// + /// Emits a single Register(map, typeof(T).GetXxx(...), "ClassName") call + /// for one projectable entry inside Build(). + /// + private static void WriteRegistryEntryStatement(IndentedTextWriter writer, ProjectableRegistryEntry entry) + { + // Build the reflection-lookup expression for the member, switching on its kind. + string? memberCallExpr = entry.MemberKind switch + { + // typeof(T).GetProperty("Name", allFlags)?.GetMethod + ProjectableRegistryMemberType.Property => + $"typeof({entry.DeclaringTypeFullName}).GetProperty(\"{entry.MemberLookupName}\", allFlags)?.GetMethod", + + // typeof(T).GetMethod("Name", allFlags, null, new Type[] { typeof(P1), … }, null) + ProjectableRegistryMemberType.Method => + $"typeof({entry.DeclaringTypeFullName}).GetMethod(\"{entry.MemberLookupName}\", allFlags, null, {BuildTypeArrayExpr(entry.ParameterTypeNames)}, null)", + + // typeof(T).GetConstructor(allFlags, null, new Type[] { typeof(P1), … }, null) + ProjectableRegistryMemberType.Constructor => + $"typeof({entry.DeclaringTypeFullName}).GetConstructor(allFlags, null, {BuildTypeArrayExpr(entry.ParameterTypeNames)}, null)", + + _ => null + }; + + if (memberCallExpr is not null) + { + writer.WriteLine($"Register(map, {memberCallExpr}, \"{entry.GeneratedClassFullName}\");"); + } + } + + /// + /// Emits the _map field that lazily builds the registry once at class-load time: + /// private static readonly Dictionary<nint, LambdaExpression> _map = Build(); + /// + private static void EmitMapField(IndentedTextWriter writer) + { + writer.WriteLine("private static readonly Dictionary _map = Build();"); + } + + /// + /// Emits the public TryGet method. + /// It resolves the runtime for any + /// subtype via a switch expression, + /// then looks it up in _map. + /// + private static void EmitTryGetMethod(IndentedTextWriter writer) + { + writer.WriteLine("public static LambdaExpression TryGet(MemberInfo member)"); + writer.WriteLine("{"); + writer.Indent++; + + writer.WriteLine("var handle = member switch"); + writer.WriteLine("{"); + writer.Indent++; + writer.WriteLine("MethodInfo m => (nint?)m.MethodHandle.Value,"); + writer.WriteLine("PropertyInfo p => p.GetMethod?.MethodHandle.Value,"); + writer.WriteLine("ConstructorInfo c => (nint?)c.MethodHandle.Value,"); + writer.WriteLine("_ => null"); + writer.Indent--; + writer.WriteLine("};"); + writer.WriteLine(); + writer.WriteLine("return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null;"); + + writer.Indent--; + writer.WriteLine("}"); + } + + /// + /// Emits the private Register static helper shared by all per-entry calls in Build(). + /// Centralises the null-check and the common reflection-to-expression lookup so that + /// each entry only needs one compact call site. + /// + private static void EmitRegisterHelper(IndentedTextWriter writer) + { + writer.WriteLine("private static void Register(Dictionary map, MethodBase m, string exprClass)"); + writer.WriteLine("{"); + writer.Indent++; + writer.WriteLine("if (m is null) return;"); + writer.WriteLine("var exprType = m.DeclaringType?.Assembly.GetType(exprClass);"); + writer.WriteLine(@"var exprMethod = exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic);"); + writer.WriteLine("if (exprMethod is not null)"); + writer.Indent++; + writer.WriteLine("map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!;"); + writer.Indent--; + writer.Indent--; + writer.WriteLine("}"); + } + + /// + /// Returns the C# expression for a Type[] used in reflection method/constructor lookups. + /// Returns global::System.Type.EmptyTypes when is empty. + /// + private static string BuildTypeArrayExpr(ImmutableArray parameterTypeNames) + { + if (parameterTypeNames.IsEmpty) + { + return "global::System.Type.EmptyTypes"; + } + + var typeofExprs = string.Join(", ", parameterTypeNames.Select(name => $"typeof({name})")); + return $"new global::System.Type[] {{ {typeofExprs} }}"; + } + } +} + diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt index 9d3de1b..76399cd 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt @@ -1,5 +1,6 @@ // #nullable disable + using System; using System.Collections.Generic; using System.Linq.Expressions; @@ -14,32 +15,35 @@ namespace EntityFrameworkCore.Projectables.Generated { const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetMethod("Add", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_Add_P0_int"); Register(map, typeof(global::Foo.C).GetMethod("Add", allFlags, null, new global::System.Type[] { typeof(long) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_Add_P0_long"); + return map; } - + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) { var handle = member switch { - MethodInfo m => (nint? )m.MethodHandle.Value, - PropertyInfo p => p.GetMethod?.MethodHandle.Value, - ConstructorInfo c => (nint? )c.MethodHandle.Value, - _ => null + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; } - + private static void Register(Dictionary map, MethodBase m, string exprClass) { - if (m is null) - return; + if (m is null) return; var exprType = m.DeclaringType?.Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is not null) map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; } } -} \ No newline at end of file +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt index 4618ba0..8e92fdf 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt @@ -1,5 +1,6 @@ // #nullable disable + using System; using System.Collections.Generic; using System.Linq.Expressions; @@ -14,32 +15,35 @@ namespace EntityFrameworkCore.Projectables.Generated { const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); Register(map, typeof(global::Foo.C).GetMethod("AddDelta", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_AddDelta_P0_int"); + return map; } - + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) { var handle = member switch { - MethodInfo m => (nint? )m.MethodHandle.Value, - PropertyInfo p => p.GetMethod?.MethodHandle.Value, - ConstructorInfo c => (nint? )c.MethodHandle.Value, - _ => null + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; } - + private static void Register(Dictionary map, MethodBase m, string exprClass) { - if (m is null) - return; + if (m is null) return; var exprType = m.DeclaringType?.Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is not null) map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; } } -} \ No newline at end of file +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt index 9b74a38..f21aa56 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt @@ -1,5 +1,6 @@ // #nullable disable + using System; using System.Collections.Generic; using System.Linq.Expressions; @@ -14,31 +15,34 @@ namespace EntityFrameworkCore.Projectables.Generated { const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + return map; } - + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) { var handle = member switch { - MethodInfo m => (nint? )m.MethodHandle.Value, - PropertyInfo p => p.GetMethod?.MethodHandle.Value, - ConstructorInfo c => (nint? )c.MethodHandle.Value, - _ => null + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; } - + private static void Register(Dictionary map, MethodBase m, string exprClass) { - if (m is null) - return; + if (m is null) return; var exprType = m.DeclaringType?.Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is not null) map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; } } -} \ No newline at end of file +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt index 9b74a38..f21aa56 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt @@ -1,5 +1,6 @@ // #nullable disable + using System; using System.Collections.Generic; using System.Linq.Expressions; @@ -14,31 +15,34 @@ namespace EntityFrameworkCore.Projectables.Generated { const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + return map; } - + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) { var handle = member switch { - MethodInfo m => (nint? )m.MethodHandle.Value, - PropertyInfo p => p.GetMethod?.MethodHandle.Value, - ConstructorInfo c => (nint? )c.MethodHandle.Value, - _ => null + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; } - + private static void Register(Dictionary map, MethodBase m, string exprClass) { - if (m is null) - return; + if (m is null) return; var exprType = m.DeclaringType?.Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is not null) map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; } } -} \ No newline at end of file +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt index 51f6d39..5af832d 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt @@ -1,5 +1,6 @@ // #nullable disable + using System; using System.Collections.Generic; using System.Linq.Expressions; @@ -14,31 +15,34 @@ namespace EntityFrameworkCore.Projectables.Generated { const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetMethod("AddDelta", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_AddDelta_P0_int"); + return map; } - + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) { var handle = member switch { - MethodInfo m => (nint? )m.MethodHandle.Value, - PropertyInfo p => p.GetMethod?.MethodHandle.Value, - ConstructorInfo c => (nint? )c.MethodHandle.Value, - _ => null + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; } - + private static void Register(Dictionary map, MethodBase m, string exprClass) { - if (m is null) - return; + if (m is null) return; var exprType = m.DeclaringType?.Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is not null) map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; } } -} \ No newline at end of file +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt index 9b74a38..f21aa56 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt @@ -1,5 +1,6 @@ // #nullable disable + using System; using System.Collections.Generic; using System.Linq.Expressions; @@ -14,31 +15,34 @@ namespace EntityFrameworkCore.Projectables.Generated { const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; var map = new Dictionary(); + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + return map; } - + private static readonly Dictionary _map = Build(); + public static LambdaExpression TryGet(MemberInfo member) { var handle = member switch { - MethodInfo m => (nint? )m.MethodHandle.Value, - PropertyInfo p => p.GetMethod?.MethodHandle.Value, - ConstructorInfo c => (nint? )c.MethodHandle.Value, - _ => null + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null }; + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; } - + private static void Register(Dictionary map, MethodBase m, string exprClass) { - if (m is null) - return; + if (m is null) return; var exprType = m.DeclaringType?.Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is not null) map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; } } -} \ No newline at end of file +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs index f8b8782..3a082b2 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs @@ -1,6 +1,3 @@ -using System.Threading.Tasks; -using VerifyXunit; -using Xunit; using Xunit.Abstractions; namespace EntityFrameworkCore.Projectables.Generator.Tests; @@ -16,7 +13,9 @@ public Task NoProjectables_NoRegistry() var compilation = CreateCompilation(@"class C { }"); var result = RunGenerator(compilation); - return Verifier.Verify(result.RegistryTree?.GetText().ToString()); + Assert.Null(result.RegistryTree); + + return Task.CompletedTask; } [Fact] @@ -85,8 +84,10 @@ class C { } }"); var result = RunGenerator(compilation); - - return Verifier.Verify(result.RegistryTree?.GetText().ToString()); + + Assert.Null(result.RegistryTree); + + return Task.CompletedTask; } [Fact] From 888b7f6dddab1e295882ee7b391a23158766e8e5 Mon Sep 17 00:00:00 2001 From: "fabien.menager" Date: Sun, 8 Mar 2026 14:08:08 +0100 Subject: [PATCH 12/12] Add back generation log --- .../ProjectionExpressionGeneratorTestsBase.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs index 2e42c2e..6508c2b 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs @@ -157,6 +157,8 @@ protected TestGeneratorRunResult RunGenerator(Compilation compilation) var rawResult = driver.GetRunResult(); var result = new TestGeneratorRunResult(rawResult); + + LogGeneratorResult(result, outputCompilation); return (driver, result); }