diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs new file mode 100644 index 0000000..3ab54d7 --- /dev/null +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs @@ -0,0 +1,56 @@ +using System.Linq.Expressions; +using System.Reflection; +using BenchmarkDotNet.Attributes; +using EntityFrameworkCore.Projectables.Benchmarks.Helpers; +using EntityFrameworkCore.Projectables.Services; + +namespace EntityFrameworkCore.Projectables.Benchmarks +{ + /// + /// Micro-benchmarks in + /// isolation (no EF Core overhead) to directly compare the static registry path against + /// the reflection-based path (). + /// + [MemoryDiagnoser] + public class ExpressionResolverBenchmark + { + private static readonly MemberInfo _propertyMember = + typeof(TestEntity).GetProperty(nameof(TestEntity.IdPlus1))!; + + private static readonly MemberInfo _methodMember = + typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlus1Method))!; + + private static readonly MemberInfo _methodWithParamMember = + typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlusDelta), new[] { typeof(int) })!; + + private readonly ProjectionExpressionResolver _resolver = new(); + + // ── Registry (source-generated) path ───────────────────────────────── + + [Benchmark(Baseline = true)] + public LambdaExpression? ResolveProperty_Registry() + => _resolver.FindGeneratedExpression(_propertyMember); + + [Benchmark] + public LambdaExpression? ResolveMethod_Registry() + => _resolver.FindGeneratedExpression(_methodMember); + + [Benchmark] + public LambdaExpression? ResolveMethodWithParam_Registry() + => _resolver.FindGeneratedExpression(_methodWithParamMember); + + // ── Reflection path ─────────────────────────────────────────────────── + + [Benchmark] + public LambdaExpression? ResolveProperty_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_propertyMember); + + [Benchmark] + public LambdaExpression? ResolveMethod_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodMember); + + [Benchmark] + public LambdaExpression? ResolveMethodWithParam_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodWithParamMember); + } +} diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs index 4bc741c..68a04d8 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs @@ -15,5 +15,8 @@ public class TestEntity [Projectable] public int IdPlus1Method() => Id + 1; + + [Projectable] + public int IdPlusDelta(int delta) => Id + delta; } } diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs index d064b7d..b9cda85 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs @@ -9,6 +9,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class PlainOverhead { [Benchmark(Baseline = true)] diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs index fdb0f9f..467d470 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs @@ -9,6 +9,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableExtensionMethods { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs index 785b52c..0618627 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs @@ -8,6 +8,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableMethods { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs index 6f2fa57..3a2e2be 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs @@ -8,6 +8,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableProperties { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs new file mode 100644 index 0000000..3876641 --- /dev/null +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs @@ -0,0 +1,72 @@ +using System.Linq; +using BenchmarkDotNet.Attributes; +using EntityFrameworkCore.Projectables.Benchmarks.Helpers; +using Microsoft.EntityFrameworkCore; + +namespace EntityFrameworkCore.Projectables.Benchmarks +{ + /// + /// Measures the per-DbContext cold-start cost of resolver lookup by creating a new + /// on every iteration. The previous benchmarks reuse a single + /// DbContext for 10 000 iterations, so the resolver cache is warm after the first query — + /// these benchmarks expose the cost of the very first query per context. + /// + [MemoryDiagnoser] + public class ResolverOverhead + { + const int Iterations = 1000; + + /// Baseline: no projectables, new DbContext per query. + [Benchmark(Baseline = true)] + public void WithoutProjectables_FreshDbContext() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(false); + dbContext.Entities.Select(x => x.Id + 1).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable property. + /// After the registry is in place this should approach baseline overhead. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_Property() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlus1).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable method. + /// After the registry is in place this should approach baseline overhead. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_Method() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlus1Method()).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable method that takes a parameter, + /// exercising parameter-type disambiguation in the registry key. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_MethodWithParam() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlusDelta(5)).ToQueryString(); + } + } + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs new file mode 100644 index 0000000..bd4930e --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs @@ -0,0 +1,6 @@ +// Polyfill for C# 9 record types when targeting netstandard2.0 or netstandard2.1 +// The compiler requires this type to exist in order to use init-only setters (used by records). +namespace System.Runtime.CompilerServices +{ + internal sealed class IsExternalInit { } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs new file mode 100644 index 0000000..c7adfad --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs @@ -0,0 +1,53 @@ +using System.Collections.Immutable; +using System.Linq; + +namespace EntityFrameworkCore.Projectables.Generator +{ + /// + /// Incremental-pipeline-safe representation of a single projectable member. + /// Contains only primitive types and an equatable wrapper around + /// so that structural value equality works correctly across incremental generation steps. + /// + sealed internal record ProjectableRegistryEntry( + string DeclaringTypeFullName, + ProjectableRegistryMemberType MemberKind, + string MemberLookupName, + string GeneratedClassFullName, + EquatableImmutableArray ParameterTypeNames + ); + + /// + /// A structural-equality wrapper around of strings. + /// uses reference equality by default, which breaks + /// Roslyn's incremental-source-generator caching when the same logical array is + /// produced by two different steps. This wrapper provides element-wise equality so + /// that incremental steps are correctly cached and skipped. + /// + readonly internal struct EquatableImmutableArray(ImmutableArray array) : IEquatable + { + private readonly ImmutableArray _array = array; + + public bool Equals(EquatableImmutableArray other) => + _array.SequenceEqual(other._array); + + public override bool Equals(object? obj) => + obj is EquatableImmutableArray other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = 17; + foreach (var s in _array) + { + hash = hash * 31 + (s?.GetHashCode() ?? 0); + } + + return hash; + } + } + + public static implicit operator ImmutableArray(EquatableImmutableArray e) => e._array; + public static implicit operator EquatableImmutableArray(ImmutableArray a) => new(a); + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs new file mode 100644 index 0000000..4697d72 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryMemberType.cs @@ -0,0 +1,8 @@ +namespace EntityFrameworkCore.Projectables.Generator; + +public enum ProjectableRegistryMemberType : byte +{ + Property, + Method, + Constructor, +} \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 09f6370..328615f 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -3,6 +3,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using System.Collections.Immutable; using System.Text; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -13,7 +14,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator { private const string ProjectablesAttributeName = "EntityFrameworkCore.Projectables.ProjectableAttribute"; - static readonly AttributeSyntax _editorBrowsableAttribute = + private readonly static AttributeSyntax _editorBrowsableAttribute = Attribute( ParseName("global::System.ComponentModel.EditorBrowsable"), AttributeArgumentList( @@ -46,14 +47,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var compilationAndMemberPairs = memberDeclarations .Combine(context.CompilationProvider) .WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer()); - + context.RegisterSourceOutput(compilationAndMemberPairs, static (spc, source) => { var ((member, attribute), compilation) = source; var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); var memberSymbol = semanticModel.GetDeclaredSymbol(member); - + if (memberSymbol is null) { return; @@ -61,9 +62,31 @@ public void Initialize(IncrementalGeneratorInitializationContext context) Execute(member, semanticModel, memberSymbol, attribute, compilation, spc); }); + + // Build the projection registry: collect all entries and emit a single registry file + var registryEntries = compilationAndMemberPairs.Select( + static (source, cancellationToken) => { + var ((member, _), compilation) = source; + + var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); + var memberSymbol = semanticModel.GetDeclaredSymbol(member, cancellationToken); + + if (memberSymbol is null) + { + return null; + } + + return ExtractRegistryEntry(memberSymbol); + }); + + // Delegate registry file emission to the dedicated ProjectionRegistryEmitter, + // which uses a string-based CodeWriter instead of SyntaxFactory. + context.RegisterImplementationSourceOutput( + registryEntries.Collect(), + static (spc, entries) => ProjectionRegistryEmitter.Emit(entries, spc)); } - static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) + private static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) { var chain = CollectConstructorChain(ctor, compilation); @@ -104,7 +127,7 @@ void AddLine(string text) /// then its delegate's delegate, …). Stops when a delegated constructor has no source /// available in the compilation (e.g. a compiler-synthesised parameterless constructor). /// - static IReadOnlyList CollectConstructorChain( + private static List CollectConstructorChain( ConstructorDeclarationSyntax ctor, Compilation compilation) { var result = new List { ctor }; @@ -115,7 +138,9 @@ static IReadOnlyList CollectConstructorChain( { var semanticModel = compilation.GetSemanticModel(current.SyntaxTree); if (semanticModel.GetSymbolInfo(initializer).Symbol is not IMethodSymbol delegated) + { break; + } var delegatedSyntax = delegated.DeclaringSyntaxReferences .Select(r => r.GetSyntax()) @@ -123,7 +148,9 @@ static IReadOnlyList CollectConstructorChain( .FirstOrDefault(); if (delegatedSyntax is null || !visited.Add(delegatedSyntax)) + { break; + } result.Add(delegatedSyntax); current = delegatedSyntax; @@ -132,7 +159,7 @@ static IReadOnlyList CollectConstructorChain( return result; } - static void Execute( + private static void Execute( MemberDeclarationSyntax member, SemanticModel semanticModel, ISymbol memberSymbol, @@ -193,14 +220,14 @@ static void Execute( ) ) ) - ) + ) ); #nullable disable var compilationUnit = CompilationUnit(); - foreach (var usingDirective in projectable.UsingDirectives) + foreach (var usingDirective in projectable.UsingDirectives!) { compilationUnit = compilationUnit.AddUsings(usingDirective); } @@ -229,7 +256,6 @@ static void Execute( ) ); - context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable) @@ -249,5 +275,99 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip return lambdaTypeArguments; } } + +#nullable restore + + /// + /// Extracts a from a member declaration. + /// Returns null when the member does not have [Projectable], is an extension member, + /// or cannot be represented in the registry (e.g. a generic class member or generic method). + /// + private static ProjectableRegistryEntry? ExtractRegistryEntry(ISymbol memberSymbol) + { + var containingType = memberSymbol.ContainingType; + + // Skip C# 14 extension type members — they require special handling (fall back to reflection) + if (containingType is { IsExtension: true }) + { + return null; + } + + // Skip generic classes: the registry only supports closed constructed types. + if (containingType.TypeParameters.Length > 0) + { + return null; + } + + // Determine member kind and lookup name + ProjectableRegistryMemberType memberKind; + string memberLookupName; + var parameterTypeNames = ImmutableArray.Empty; + + if (memberSymbol is IMethodSymbol methodSymbol) + { + // Skip generic methods for the same reason as generic classes + if (methodSymbol.TypeParameters.Length > 0) + { + return null; + } + + if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor) + { + memberKind = ProjectableRegistryMemberType.Constructor; + memberLookupName = "_ctor"; + } + else + { + memberKind = ProjectableRegistryMemberType.Method; + memberLookupName = memberSymbol.Name; + } + + parameterTypeNames = [ + ..methodSymbol.Parameters.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + ]; + } + else + { + memberKind = ProjectableRegistryMemberType.Property; + memberLookupName = memberSymbol.Name; + } + + // Build the generated class name using the same logic as Execute + var classNamespace = containingType.ContainingNamespace.IsGlobalNamespace + ? null + : containingType.ContainingNamespace.ToDisplayString(); + + var nestedTypePath = GetRegistryNestedTypePath(containingType); + + var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName( + classNamespace, + nestedTypePath, + memberLookupName, + parameterTypeNames.IsEmpty ? null : parameterTypeNames); + + var generatedClassFullName = "EntityFrameworkCore.Projectables.Generated." + generatedClassName; + + var declaringTypeFullName = containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + return new ProjectableRegistryEntry( + DeclaringTypeFullName: declaringTypeFullName, + MemberKind: memberKind, + MemberLookupName: memberLookupName, + GeneratedClassFullName: generatedClassFullName, + ParameterTypeNames: parameterTypeNames); + } + + private static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol) + { + if (typeSymbol.ContainingType is not null) + { + foreach (var name in GetRegistryNestedTypePath(typeSymbol.ContainingType)) + { + yield return name; + } + } + yield return typeSymbol.Name; + } } -} +} \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs new file mode 100644 index 0000000..16fdb01 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionRegistryEmitter.cs @@ -0,0 +1,207 @@ +using System.CodeDom.Compiler; +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace EntityFrameworkCore.Projectables.Generator +{ + /// + /// Emits the ProjectionRegistry.g.cs source file that aggregates all projectable + /// members into a single static dictionary keyed by . + /// + /// + /// Code is generated using rather than Roslyn's + /// SyntaxFactory: the output has a fixed template shape that never changes, so + /// string-based generation is simpler, more readable, and easier to maintain. + /// + /// + internal static class ProjectionRegistryEmitter + { + /// + /// Builds and adds ProjectionRegistry.g.cs to the compilation output. + /// The file is only emitted when at least one non-generic, non-extension projectable + /// member is present (i.e. when yields at least one representable entry). + /// + public static void Emit(ImmutableArray entries, SourceProductionContext context) + { + var validEntries = entries + .Where(e => e is not null) + .Select(e => e!) + .ToList(); + + if (validEntries.Count == 0) + { + return; + } + + // IndentedTextWriter wraps a TextWriter; keep a reference to the StringWriter + // so we can read the result back with .ToString() after all writes are done. + var sw = new StringWriter(); + var writer = new IndentedTextWriter(sw, " "); + + writer.WriteLine("// "); + writer.WriteLine("#nullable disable"); + writer.WriteLine(); + writer.WriteLine("using System;"); + writer.WriteLine("using System.Collections.Generic;"); + writer.WriteLine("using System.Linq.Expressions;"); + writer.WriteLine("using System.Reflection;"); + writer.WriteLine(); + writer.WriteLine("namespace EntityFrameworkCore.Projectables.Generated"); + writer.WriteLine("{"); + writer.Indent++; + + writer.WriteLine("[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]"); + writer.WriteLine("internal static class ProjectionRegistry"); + writer.WriteLine("{"); + writer.Indent++; + + EmitBuildMethod(writer, validEntries); + writer.WriteLine(); + EmitMapField(writer); + writer.WriteLine(); + EmitTryGetMethod(writer); + writer.WriteLine(); + EmitRegisterHelper(writer); + + writer.Indent--; + writer.WriteLine("}"); + writer.Indent--; + writer.WriteLine("}"); + + context.AddSource("ProjectionRegistry.g.cs", + SourceText.From(sw.ToString(), Encoding.UTF8)); + } + + /// + /// Emits the private Build() method that populates the runtime registry. + /// Each projectable member is registered via the shared Register(...) helper to keep + /// the method body compact — null-safety and the reflection lookup are handled once, centrally. + /// + private static void EmitBuildMethod(IndentedTextWriter writer, List entries) + { + writer.WriteLine("private static Dictionary Build()"); + writer.WriteLine("{"); + writer.Indent++; + + writer.WriteLine("const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static;"); + writer.WriteLine("var map = new Dictionary();"); + writer.WriteLine(); + + foreach (var entry in entries) + { + WriteRegistryEntryStatement(writer, entry); + } + + writer.WriteLine(); + writer.WriteLine("return map;"); + writer.Indent--; + writer.WriteLine("}"); + } + + /// + /// Emits a single Register(map, typeof(T).GetXxx(...), "ClassName") call + /// for one projectable entry inside Build(). + /// + private static void WriteRegistryEntryStatement(IndentedTextWriter writer, ProjectableRegistryEntry entry) + { + // Build the reflection-lookup expression for the member, switching on its kind. + string? memberCallExpr = entry.MemberKind switch + { + // typeof(T).GetProperty("Name", allFlags)?.GetMethod + ProjectableRegistryMemberType.Property => + $"typeof({entry.DeclaringTypeFullName}).GetProperty(\"{entry.MemberLookupName}\", allFlags)?.GetMethod", + + // typeof(T).GetMethod("Name", allFlags, null, new Type[] { typeof(P1), … }, null) + ProjectableRegistryMemberType.Method => + $"typeof({entry.DeclaringTypeFullName}).GetMethod(\"{entry.MemberLookupName}\", allFlags, null, {BuildTypeArrayExpr(entry.ParameterTypeNames)}, null)", + + // typeof(T).GetConstructor(allFlags, null, new Type[] { typeof(P1), … }, null) + ProjectableRegistryMemberType.Constructor => + $"typeof({entry.DeclaringTypeFullName}).GetConstructor(allFlags, null, {BuildTypeArrayExpr(entry.ParameterTypeNames)}, null)", + + _ => null + }; + + if (memberCallExpr is not null) + { + writer.WriteLine($"Register(map, {memberCallExpr}, \"{entry.GeneratedClassFullName}\");"); + } + } + + /// + /// Emits the _map field that lazily builds the registry once at class-load time: + /// private static readonly Dictionary<nint, LambdaExpression> _map = Build(); + /// + private static void EmitMapField(IndentedTextWriter writer) + { + writer.WriteLine("private static readonly Dictionary _map = Build();"); + } + + /// + /// Emits the public TryGet method. + /// It resolves the runtime for any + /// subtype via a switch expression, + /// then looks it up in _map. + /// + private static void EmitTryGetMethod(IndentedTextWriter writer) + { + writer.WriteLine("public static LambdaExpression TryGet(MemberInfo member)"); + writer.WriteLine("{"); + writer.Indent++; + + writer.WriteLine("var handle = member switch"); + writer.WriteLine("{"); + writer.Indent++; + writer.WriteLine("MethodInfo m => (nint?)m.MethodHandle.Value,"); + writer.WriteLine("PropertyInfo p => p.GetMethod?.MethodHandle.Value,"); + writer.WriteLine("ConstructorInfo c => (nint?)c.MethodHandle.Value,"); + writer.WriteLine("_ => null"); + writer.Indent--; + writer.WriteLine("};"); + writer.WriteLine(); + writer.WriteLine("return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null;"); + + writer.Indent--; + writer.WriteLine("}"); + } + + /// + /// Emits the private Register static helper shared by all per-entry calls in Build(). + /// Centralises the null-check and the common reflection-to-expression lookup so that + /// each entry only needs one compact call site. + /// + private static void EmitRegisterHelper(IndentedTextWriter writer) + { + writer.WriteLine("private static void Register(Dictionary map, MethodBase m, string exprClass)"); + writer.WriteLine("{"); + writer.Indent++; + writer.WriteLine("if (m is null) return;"); + writer.WriteLine("var exprType = m.DeclaringType?.Assembly.GetType(exprClass);"); + writer.WriteLine(@"var exprMethod = exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic);"); + writer.WriteLine("if (exprMethod is not null)"); + writer.Indent++; + writer.WriteLine("map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!;"); + writer.Indent--; + writer.Indent--; + writer.WriteLine("}"); + } + + /// + /// Returns the C# expression for a Type[] used in reflection method/constructor lookups. + /// Returns global::System.Type.EmptyTypes when is empty. + /// + private static string BuildTypeArrayExpr(ImmutableArray parameterTypeNames) + { + if (parameterTypeNames.IsEmpty) + { + return "global::System.Type.EmptyTypes"; + } + + var typeofExprs = string.Join(", ", parameterTypeNames.Select(name => $"typeof({name})")); + return $"new global::System.Type[] {{ {typeofExprs} }}"; + } + } +} + diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 16e547a..e342a6e 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -21,27 +21,19 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor private readonly bool _trackingByDefault; private IEntityType? _entityType; - private readonly MethodInfo _select; - private readonly MethodInfo _where; + // Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain) + private static readonly MethodInfo _select = + ((MethodCallExpression)((Expression, IQueryable>>) + (q => q.Select(x => x))).Body).Method.GetGenericMethodDefinition(); + + private static readonly MethodInfo _where = + ((MethodCallExpression)((Expression, IQueryable>>) + (q => q.Where(x => true))).Body).Method.GetGenericMethodDefinition(); public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false) { _trackingByDefault = trackByDefault; _resolver = projectionExpressionResolver; - _select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Select)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ); - _where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Where)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ); } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 7e26d69..a4b8cd6 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -8,6 +9,35 @@ namespace EntityFrameworkCore.Projectables.Services { public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver { + // We never store null in the dictionary; assemblies without a registry use a sentinel delegate. + private static readonly Func _nullRegistry = static _ => null!; + private static readonly ConcurrentDictionary> _assemblyRegistries = new(); + + /// + /// Looks up the generated ProjectionRegistry class in an assembly (once, then caches it). + /// Returns a delegate that calls TryGet(MemberInfo) on the registry, or null if the registry + /// is not present in that assembly (e.g. if the source generator was not run against it). + /// + private static Func? GetAssemblyRegistry(Assembly assembly) + { + var registry = _assemblyRegistries.GetOrAdd(assembly, static asm => + { + var registryType = asm.GetType("EntityFrameworkCore.Projectables.Generated.ProjectionRegistry"); + var tryGetMethod = registryType?.GetMethod("TryGet", BindingFlags.Static | BindingFlags.Public); + + if (tryGetMethod is null) + { + // Use sentinel to indicate "no registry for this assembly" + return _nullRegistry; + } + + return (Func)Delegate.CreateDelegate(typeof(Func), tryGetMethod); + }); + + // Translate sentinel back to null for callers, preserving existing behavior. + return ReferenceEquals(registry, _nullRegistry) ? null : registry; + } + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) { var projectableAttribute = projectableMemberInfo.GetCustomAttribute() @@ -62,165 +92,186 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + + // Fast path: check the per-assembly static registry (generated by source generator). + // The first call per assembly does a reflection lookup to find the registry class and + // caches it as a delegate; subsequent calls use the cached delegate for an O(1) dictionary lookup. + var registry = GetAssemblyRegistry(declaringType.Assembly); + var registeredExpr = registry?.Invoke(projectableMemberInfo); - // Keep track of the original declaring type's generic arguments for later use - var originalDeclaringType = declaringType; - - // For generic types, use the generic type definition to match the generated name - // which is based on the open generic type - if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) - { - declaringType = declaringType.GetGenericTypeDefinition(); - } - - // Get parameter types for method overload disambiguation - // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat - // which uses C# keywords for primitive types (int, string, etc.) - string[]? parameterTypeNames = null; - string memberLookupName = projectableMemberInfo.Name; - if (projectableMemberInfo is MethodInfo method) - { - // For generic methods, use the generic definition to get parameter types - // This ensures type parameters like TEntity are used instead of concrete types - var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; - - parameterTypeNames = methodToInspect.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); - } - else if (projectableMemberInfo is ConstructorInfo ctor) + return registeredExpr ?? + // Slow path: reflection fallback for open-generic class members and generic methods + // that are not yet in the registry. + FindGeneratedExpressionViaReflection(projectableMemberInfo); + } + } + + /// + /// Resolves the for a [Projectable] member using the + /// reflection-based slow path only, bypassing the static registry. + /// Useful for benchmarking and for members not yet in the registry (e.g. open-generic types). + /// + public static LambdaExpression? FindGeneratedExpressionViaReflection(MemberInfo projectableMemberInfo) + { + var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + + // Keep track of the original declaring type's generic arguments for later use + var originalDeclaringType = declaringType; + + // For generic types, use the generic type definition to match the generated name + // which is based on the open generic type + if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) + { + declaringType = declaringType.GetGenericTypeDefinition(); + } + + // Get parameter types for method overload disambiguation + // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat + // which uses C# keywords for primitive types (int, string, etc.) + string[]? parameterTypeNames = null; + string memberLookupName = projectableMemberInfo.Name; + if (projectableMemberInfo is MethodInfo method) + { + // For generic methods, use the generic definition to get parameter types + // This ensures type parameters like TEntity are used instead of concrete types + var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; + + parameterTypeNames = methodToInspect.GetParameters() + .Select(p => GetFullTypeName(p.ParameterType)) + .ToArray(); + } + else if (projectableMemberInfo is ConstructorInfo ctor) + { + // Constructors are stored under the synthetic name "_ctor" + memberLookupName = "_ctor"; + parameterTypeNames = ctor.GetParameters() + .Select(p => GetFullTypeName(p.ParameterType)) + .ToArray(); + } + + var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); + + var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + + if (expressionFactoryType is not null) + { + if (expressionFactoryType.IsGenericTypeDefinition) { - // Constructors are stored under the synthetic name "_ctor" - memberLookupName = "_ctor"; - parameterTypeNames = ctor.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); + expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); } - - var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); - var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - if (expressionFactoryType is not null) + var methodGenericArguments = projectableMemberInfo switch { + MethodInfo methodInfo => methodInfo.GetGenericArguments(), + _ => null + }; + + if (expressionFactoryMethod is not null) { - if (expressionFactoryType.IsGenericTypeDefinition) + if (methodGenericArguments is { Length: > 0 }) { - expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); + expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); } - var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - - var methodGenericArguments = projectableMemberInfo switch { - MethodInfo methodInfo => methodInfo.GetGenericArguments(), - _ => null - }; + return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); + } + } - if (expressionFactoryMethod is not null) - { - if (methodGenericArguments is { Length: > 0 }) - { - expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); - } + return null; + } - return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); - } - } + private static string GetFullTypeName(Type type) + { + // Handle generic type parameters (e.g., T, TEntity) + if (type.IsGenericParameter) + { + return type.Name; + } - return null; + // Handle nullable value types (e.g., int? -> int?) + var underlyingType = Nullable.GetUnderlyingType(type); + if (underlyingType != null) + { + return $"{GetFullTypeName(underlyingType)}?"; } - - static string GetFullTypeName(Type type) + + // Handle array types + if (type.IsArray) { - // Handle generic type parameters (e.g., T, TEntity) - if (type.IsGenericParameter) + var elementType = type.GetElementType(); + if (elementType == null) { + // Fallback for edge cases where GetElementType() might return null return type.Name; } - - // Handle nullable value types (e.g., int? -> int?) - var underlyingType = Nullable.GetUnderlyingType(type); - if (underlyingType != null) - { - return $"{GetFullTypeName(underlyingType)}?"; - } - - // Handle array types - if (type.IsArray) - { - var elementType = type.GetElementType(); - if (elementType == null) - { - // Fallback for edge cases where GetElementType() might return null - return type.Name; - } - - var rank = type.GetArrayRank(); - var elementTypeName = GetFullTypeName(elementType); - - if (rank == 1) - { - return $"{elementTypeName}[]"; - } - else - { - var commas = new string(',', rank - 1); - return $"{elementTypeName}[{commas}]"; - } - } - - // Map primitive types to their C# keyword equivalents to match Roslyn's output - var typeKeyword = GetCSharpKeyword(type); - if (typeKeyword != null) + + var rank = type.GetArrayRank(); + var elementTypeName = GetFullTypeName(elementType); + + if (rank == 1) { - return typeKeyword; + return $"{elementTypeName}[]"; } - - // For generic types, construct the full name matching Roslyn's format - if (type.IsGenericType) + else { - var genericTypeDef = type.GetGenericTypeDefinition(); - var genericArgs = type.GetGenericArguments(); - var baseName = genericTypeDef.FullName ?? genericTypeDef.Name; - - // Remove the `n suffix (e.g., `1, `2) - var backtickIndex = baseName.IndexOf('`'); - if (backtickIndex > 0) - { - baseName = baseName.Substring(0, backtickIndex); - } - - var args = string.Join(", ", genericArgs.Select(GetFullTypeName)); - return $"{baseName}<{args}>"; + var commas = new string(',', rank - 1); + return $"{elementTypeName}[{commas}]"; } - - if (type.FullName != null) + } + + // Map primitive types to their C# keyword equivalents to match Roslyn's output + var typeKeyword = GetCSharpKeyword(type); + if (typeKeyword != null) + { + return typeKeyword; + } + + // For generic types, construct the full name matching Roslyn's format + if (type.IsGenericType) + { + var genericTypeDef = type.GetGenericTypeDefinition(); + var genericArgs = type.GetGenericArguments(); + var baseName = genericTypeDef.FullName ?? genericTypeDef.Name; + + // Remove the `n suffix (e.g., `1, `2) + var backtickIndex = baseName.IndexOf('`'); + if (backtickIndex > 0) { - // Replace + with . for nested types to match Roslyn's format - return type.FullName.Replace('+', '.'); + baseName = baseName.Substring(0, backtickIndex); } - - return type.Name; + + var args = string.Join(", ", genericArgs.Select(GetFullTypeName)); + return $"{baseName}<{args}>"; } - - static string? GetCSharpKeyword(Type type) + + if (type.FullName != null) { - if (type == typeof(bool)) return "bool"; - if (type == typeof(byte)) return "byte"; - if (type == typeof(sbyte)) return "sbyte"; - if (type == typeof(char)) return "char"; - if (type == typeof(decimal)) return "decimal"; - if (type == typeof(double)) return "double"; - if (type == typeof(float)) return "float"; - if (type == typeof(int)) return "int"; - if (type == typeof(uint)) return "uint"; - if (type == typeof(long)) return "long"; - if (type == typeof(ulong)) return "ulong"; - if (type == typeof(short)) return "short"; - if (type == typeof(ushort)) return "ushort"; - if (type == typeof(object)) return "object"; - if (type == typeof(string)) return "string"; - return null; + // Replace + with . for nested types to match Roslyn's format + return type.FullName.Replace('+', '.'); } + + return type.Name; + } + + private static string? GetCSharpKeyword(Type type) + { + if (type == typeof(bool)) return "bool"; + if (type == typeof(byte)) return "byte"; + if (type == typeof(sbyte)) return "sbyte"; + if (type == typeof(char)) return "char"; + if (type == typeof(decimal)) return "decimal"; + if (type == typeof(double)) return "double"; + if (type == typeof(float)) return "float"; + if (type == typeof(int)) return "int"; + if (type == typeof(uint)) return "uint"; + if (type == typeof(long)) return "long"; + if (type == typeof(ulong)) return "ulong"; + if (type == typeof(short)) return "short"; + if (type == typeof(ushort)) return "ushort"; + if (type == typeof(object)) return "object"; + if (type == typeof(string)) return "string"; + return null; } } } diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs index 563acc5..6508c2b 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs @@ -1,5 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -17,6 +18,47 @@ protected ProjectionExpressionGeneratorTestsBase(ITestOutputHelper testOutputHel _testOutputHelper = testOutputHelper; } + /// + /// Wraps and exposes + /// as a filtered view that excludes the generated ProjectionRegistry.g.cs file. + /// This keeps all existing tests working without modification after the registry was added. + /// + protected sealed class TestGeneratorRunResult + { + private readonly GeneratorDriverRunResult _inner; + + public TestGeneratorRunResult(GeneratorDriverRunResult inner) + { + _inner = inner; + } + + /// + /// Diagnostics from the generator run. + /// + public ImmutableArray Diagnostics => _inner.Diagnostics; + + /// + /// Generated trees excluding ProjectionRegistry.g.cs. + /// Existing tests use this and should continue to work without modification. + /// + public ImmutableArray GeneratedTrees => + _inner.GeneratedTrees + .Where(t => !t.FilePath.EndsWith("ProjectionRegistry.g.cs", StringComparison.Ordinal)) + .ToImmutableArray(); + + /// + /// All generated trees including ProjectionRegistry.g.cs. + /// Use this in new tests that need to verify the registry. + /// + public ImmutableArray AllGeneratedTrees => _inner.GeneratedTrees; + + /// + /// The generated ProjectionRegistry.g.cs tree, or null if it was not generated. + /// + public SyntaxTree? RegistryTree => + _inner.GeneratedTrees.FirstOrDefault(t => t.FilePath.EndsWith("ProjectionRegistry.g.cs", StringComparison.Ordinal)); + } + protected IReadOnlyList GetDefaultReferences() { var references = Basic.Reference.Assemblies. @@ -64,7 +106,7 @@ protected Compilation CreateCompilation([StringSyntax("csharp")] string source) return compilation; } - protected GeneratorDriverRunResult RunGenerator(Compilation compilation) + protected TestGeneratorRunResult RunGenerator(Compilation compilation) { _testOutputHelper.WriteLine("Running generator and updating compilation..."); @@ -73,7 +115,8 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) .Create(subject) .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); LogGeneratorResult(result, outputCompilation); @@ -85,7 +128,7 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) /// returning both the driver and the run result. The driver can be passed to subsequent /// calls to to test incremental caching behavior. /// - protected (GeneratorDriver Driver, GeneratorDriverRunResult Result) CreateAndRunGenerator(Compilation compilation) + protected (GeneratorDriver Driver, TestGeneratorRunResult Result) CreateAndRunGenerator(Compilation compilation) { _testOutputHelper.WriteLine("Creating generator driver and running on initial compilation..."); @@ -93,7 +136,9 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) GeneratorDriver driver = CSharpGeneratorDriver.Create(subject); driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); + LogGeneratorResult(result, outputCompilation); return (driver, result); @@ -103,19 +148,22 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) /// Runs the generator using an existing driver (preserving incremental state from previous runs) /// on a new compilation, returning the updated driver and run result. /// - protected (GeneratorDriver Driver, GeneratorDriverRunResult Result) RunGeneratorWithDriver( + protected (GeneratorDriver Driver, TestGeneratorRunResult Result) RunGeneratorWithDriver( GeneratorDriver driver, Compilation compilation) { _testOutputHelper.WriteLine("Running generator with existing driver on updated compilation..."); driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); + + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); + LogGeneratorResult(result, outputCompilation); return (driver, result); } - private void LogGeneratorResult(GeneratorDriverRunResult result, Compilation outputCompilation) + private void LogGeneratorResult(TestGeneratorRunResult result, Compilation outputCompilation) { if (result.Diagnostics.IsEmpty) { @@ -131,7 +179,7 @@ private void LogGeneratorResult(GeneratorDriverRunResult result, Compilation out } } - foreach (var newSyntaxTree in result.GeneratedTrees) + foreach (var newSyntaxTree in result.AllGeneratedTrees) { _testOutputHelper.WriteLine($"Produced syntax tree with path produced: {newSyntaxTree.FilePath}"); _testOutputHelper.WriteLine(newSyntaxTree.GetText().ToString()); @@ -139,7 +187,7 @@ private void LogGeneratorResult(GeneratorDriverRunResult result, Compilation out // Verify that the generated code compiles without errors var hasGeneratorErrors = result.Diagnostics.Any(d => d.Severity == DiagnosticSeverity.Error); - if (!hasGeneratorErrors && result.GeneratedTrees.Length > 0) + if (!hasGeneratorErrors && result.AllGeneratedTrees.Length > 0) { _testOutputHelper.WriteLine("Checking that generated code compiles..."); diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt new file mode 100644 index 0000000..76399cd --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MethodOverloads_BothRegistered.verified.txt @@ -0,0 +1,49 @@ +// +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + + Register(map, typeof(global::Foo.C).GetMethod("Add", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_Add_P0_int"); + Register(map, typeof(global::Foo.C).GetMethod("Add", allFlags, null, new global::System.Type[] { typeof(long) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_Add_P0_long"); + + return map; + } + + private static readonly Dictionary _map = Build(); + + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null + }; + + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt new file mode 100644 index 0000000..8e92fdf --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.MultipleProjectables_AllRegistered.verified.txt @@ -0,0 +1,49 @@ +// +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + Register(map, typeof(global::Foo.C).GetMethod("AddDelta", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_AddDelta_P0_int"); + + return map; + } + + private static readonly Dictionary _map = Build(); + + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null + }; + + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt new file mode 100644 index 0000000..f21aa56 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_ConstBindingFlagsUsedInBuild.verified.txt @@ -0,0 +1,48 @@ +// +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + + return map; + } + + private static readonly Dictionary _map = Build(); + + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null + }; + + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt new file mode 100644 index 0000000..f21aa56 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.Registry_RegisterHelperUsesDeclaringTypeAssembly.verified.txt @@ -0,0 +1,48 @@ +// +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + + return map; + } + + private static readonly Dictionary _map = Build(); + + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null + }; + + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt new file mode 100644 index 0000000..5af832d --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt @@ -0,0 +1,48 @@ +// +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + + Register(map, typeof(global::Foo.C).GetMethod("AddDelta", allFlags, null, new global::System.Type[] { typeof(int) }, null), "EntityFrameworkCore.Projectables.Generated.Foo_C_AddDelta_P0_int"); + + return map; + } + + private static readonly Dictionary _map = Build(); + + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null + }; + + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt new file mode 100644 index 0000000..f21aa56 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt @@ -0,0 +1,48 @@ +// +#nullable disable + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class ProjectionRegistry + { + private static Dictionary Build() + { + const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static; + var map = new Dictionary(); + + Register(map, typeof(global::Foo.C).GetProperty("IdPlus1", allFlags)?.GetMethod, "EntityFrameworkCore.Projectables.Generated.Foo_C_IdPlus1"); + + return map; + } + + private static readonly Dictionary _map = Build(); + + public static LambdaExpression TryGet(MemberInfo member) + { + var handle = member switch + { + MethodInfo m => (nint?)m.MethodHandle.Value, + PropertyInfo p => p.GetMethod?.MethodHandle.Value, + ConstructorInfo c => (nint?)c.MethodHandle.Value, + _ => null + }; + + return handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null; + } + + private static void Register(Dictionary map, MethodBase m, string exprClass) + { + if (m is null) return; + var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + if (exprMethod is not null) + map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs new file mode 100644 index 0000000..3a082b2 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs @@ -0,0 +1,145 @@ +using Xunit.Abstractions; + +namespace EntityFrameworkCore.Projectables.Generator.Tests; + +[UsesVerify] +public class RegistryTests : ProjectionExpressionGeneratorTestsBase +{ + public RegistryTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { } + + [Fact] + public Task NoProjectables_NoRegistry() + { + var compilation = CreateCompilation(@"class C { }"); + var result = RunGenerator(compilation); + + Assert.Null(result.RegistryTree); + + return Task.CompletedTask; + } + + [Fact] + public Task SingleProperty_RegistryContainsEntry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); + } + + [Fact] + public Task SingleMethod_RegistryContainsEntry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int AddDelta(int delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); + } + + [Fact] + public Task MultipleProjectables_AllRegistered() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + [Projectable] + public int AddDelta(int delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); + } + + [Fact] + public Task GenericClass_NotIncludedInRegistry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.Null(result.RegistryTree); + + return Task.CompletedTask; + } + + [Fact] + public Task Registry_ConstBindingFlagsUsedInBuild() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); + } + + [Fact] + public Task Registry_RegisterHelperUsesDeclaringTypeAssembly() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); + } + + [Fact] + public Task MethodOverloads_BothRegistered() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int Add(int delta) => Id + delta; + [Projectable] + public long Add(long delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + return Verifier.Verify(result.RegistryTree!.GetText().ToString()); + } +}