From 33cf3a417fb829217e633e0205b0df568b6307f1 Mon Sep 17 00:00:00 2001 From: "alont@superplay.co" Date: Sun, 9 Nov 2025 23:41:04 +0200 Subject: [PATCH] Add resolve delegate support --- src/Jab.Tests/DiagnosticsTest.cs | 70 +++++++++++++++++ src/Jab/AnalyzerReleases.Unshipped.md | 6 +- src/Jab/Attributes.cs | 16 ++++ src/Jab/ContainerGenerator.cs | 13 ++++ src/Jab/DiagnosticDescriptors.cs | 8 ++ src/Jab/KnownTypes.cs | 10 +++ src/Jab/ResolveDelegateCallSite.cs | 14 ++++ src/Jab/ServiceProviderBuilder.cs | 104 ++++++++++++++++++++++++++ 8 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 src/Jab/ResolveDelegateCallSite.cs diff --git a/src/Jab.Tests/DiagnosticsTest.cs b/src/Jab.Tests/DiagnosticsTest.cs index a113b98..ae2616b 100644 --- a/src/Jab.Tests/DiagnosticsTest.cs +++ b/src/Jab.Tests/DiagnosticsTest.cs @@ -129,6 +129,76 @@ await Verify.VerifyAnalyzerAsync(testCode, .WithArguments("Dependency", "Named", "Service")); } + [Fact] + public async Task DoesNotProduceDiagnosticForResolveDelegateWhenServiceRegistered() + { + string testCode = @" +interface IService { } +class Service : IService { } +class Consumer { public Consumer(Resolve resolver) { } } + +[ServiceProvider] +[Transient(typeof(IService), typeof(Service))] +[Transient(typeof(Consumer))] +public partial class Container { } +"; + await Verify.VerifyAnalyzerAsync(testCode); + } + + [Fact] + public async Task ProducesDiagnosticWhenResolveDelegateServiceNotRegistered() + { + string testCode = @" +interface IService { } +class Consumer { public Consumer({|#1:Resolve|} resolver) { } } + +[ServiceProvider] +[Transient(typeof(Consumer))] +public partial class Container { } +"; + await Verify.VerifyAnalyzerAsync(testCode, + DiagnosticResult + .CompilerError("JAB0021") + .WithLocation(1) + .WithArguments("IService")); + } + + [Fact] + public async Task ProducesDiagnosticWhenNamedResolveDelegateHasNoNamedService() + { + string testCode = @" +interface IService { } +class Service : IService { } +class Consumer { public Consumer({|#1:NamedResolve|} resolver) { } } + +[ServiceProvider] +[Transient(typeof(IService), typeof(Service))] +[Transient(typeof(Consumer))] +public partial class Container { } +"; + await Verify.VerifyAnalyzerAsync(testCode, + DiagnosticResult + .CompilerError("JAB0022") + .WithLocation(1) + .WithArguments("IService")); + } + + [Fact] + public async Task DoesNotProduceDiagnosticForNamedResolveWhenNamedServiceRegistered() + { + string testCode = @" +interface IService { } +class Service : IService { } +class Consumer { public Consumer(NamedResolve resolver) { } } + +[ServiceProvider] +[Singleton(typeof(IService), typeof(Service), Name = "Named")] +[Transient(typeof(Consumer))] +public partial class Container { } +"; + await Verify.VerifyAnalyzerAsync(testCode); + } + [Fact] public async Task ProducesJAB0002WhenRequiredDependenciesNotFound() { diff --git a/src/Jab/AnalyzerReleases.Unshipped.md b/src/Jab/AnalyzerReleases.Unshipped.md index 5f28270..4608f70 100644 --- a/src/Jab/AnalyzerReleases.Unshipped.md +++ b/src/Jab/AnalyzerReleases.Unshipped.md @@ -1 +1,5 @@ - \ No newline at end of file +### New Rules +Rule ID | Category | Severity | Notes +--------|----------|----------|------ +JAB0021 | Usage | Error | Resolve delegate requires registered service +JAB0022 | Usage | Error | NamedResolve delegate requires named service registration diff --git a/src/Jab/Attributes.cs b/src/Jab/Attributes.cs index f38f7b3..f486e49 100644 --- a/src/Jab/Attributes.cs +++ b/src/Jab/Attributes.cs @@ -280,6 +280,22 @@ interface INamedServiceProvider T GetService(string name); } +#if JAB_ATTRIBUTES_PACKAGE + public +#else + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Jab", null)] + internal +#endif + delegate T Resolve(string name); + +#if JAB_ATTRIBUTES_PACKAGE + public +#else + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Jab", null)] + internal +#endif + delegate T NamedResolve(string name); + #if JAB_ATTRIBUTES_PACKAGE public #else diff --git a/src/Jab/ContainerGenerator.cs b/src/Jab/ContainerGenerator.cs index 819db06..225c65c 100644 --- a/src/Jab/ContainerGenerator.cs +++ b/src/Jab/ContainerGenerator.cs @@ -193,6 +193,19 @@ private void GenerateCallSite( } }); break; + case ResolveDelegateCallSite resolveDelegateCallSite: + valueCallback(codeWriter, w => + { + if (resolveDelegateCallSite.UsesName) + { + w.Append($"new global::Jab.NamedResolve<{resolveDelegateCallSite.ResolvedType}>(GetService<{resolveDelegateCallSite.ResolvedType}>)"); + } + else + { + w.Append($"new global::Jab.Resolve<{resolveDelegateCallSite.ResolvedType}>(_ => GetService<{resolveDelegateCallSite.ResolvedType}>())"); + } + }); + break; case ServiceProviderCallSite: valueCallback(codeWriter, w => w.AppendRaw("this")); break; diff --git a/src/Jab/DiagnosticDescriptors.cs b/src/Jab/DiagnosticDescriptors.cs index f32026d..6197840 100644 --- a/src/Jab/DiagnosticDescriptors.cs +++ b/src/Jab/DiagnosticDescriptors.cs @@ -74,6 +74,14 @@ internal static class DiagnosticDescriptors "Only string service keys are supported", "Service key '{0}' is not supported, only string keys are supported", "Usage", DiagnosticSeverity.Error, true); + public static readonly DiagnosticDescriptor ResolveDelegateServiceNotRegistered = new("JAB0021", + "Resolve delegate requires registered service", + "Resolve delegate requires the service '{0}' to be registered", "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor NamedResolveDelegateServiceNotRegistered = new("JAB0022", + "NamedResolve delegate requires named service registration", + "NamedResolve delegate requires the service '{0}' to be registered with a name", "Usage", DiagnosticSeverity.Error, true); + public static readonly DiagnosticDescriptor NullableServiceNotRegistered = new("JAB0013", "Not registered nullable dependency without a default value", "The nullable service '{0}' requested to construct '{1}' is not registered. Add a default value to make the service reference optional", "Usage", DiagnosticSeverity.Error, true); diff --git a/src/Jab/KnownTypes.cs b/src/Jab/KnownTypes.cs index 9b6fce8..b62013e 100644 --- a/src/Jab/KnownTypes.cs +++ b/src/Jab/KnownTypes.cs @@ -10,6 +10,8 @@ internal class KnownTypes public const string ServiceProviderModuleAttributeShortName = "ServiceProviderModule"; public const string ImportAttributeShortName = "Import"; public const string FromNamedServicesAttributeShortName = "FromNamedServices"; + public const string ResolveDelegateShortName = "Resolve"; + public const string NamedResolveDelegateShortName = "NamedResolve"; public const string TransientAttributeTypeName = $"{TransientAttributeShortName}Attribute"; public const string SingletonAttributeTypeName = $"{SingletonAttributeShortName}Attribute"; @@ -19,6 +21,8 @@ internal class KnownTypes public const string ImportAttributeTypeName = $"{ImportAttributeShortName}Attribute"; public const string FromNamedServicesAttributeName = $"{FromNamedServicesAttributeShortName}Attribute"; + public const string ResolveDelegateTypeName = ResolveDelegateShortName; + public const string NamedResolveDelegateTypeName = NamedResolveDelegateShortName; public const string TransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}"; public const string GenericTransientAttributeMetadataName = $"Jab.{TransientAttributeTypeName}`1"; @@ -51,6 +55,8 @@ internal class KnownTypes private const string IKeyedServiceProviderMetadataName = "Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider"; private const string FromKeyedServicesAttributeMetadataName = "Microsoft.Extensions.DependencyInjection.FromKeyedServicesAttribute"; private const string FromNamedServicesAttributeMetadataName = $"Jab.{FromNamedServicesAttributeName}"; + private const string ResolveDelegateMetadataName = $"Jab.{ResolveDelegateTypeName}`1"; + private const string NamedResolveDelegateMetadataName = $"Jab.{NamedResolveDelegateTypeName}`1"; private const string IServiceScopeFactoryMetadataName = "Microsoft.Extensions.DependencyInjection.IServiceScopeFactory"; @@ -83,6 +89,8 @@ internal class KnownTypes public INamedTypeSymbol? IKeyedServiceProviderType { get; } public INamedTypeSymbol? FromKeyedServicesAttribute { get; } public INamedTypeSymbol? FromNamedServicesAttribute { get; } + public INamedTypeSymbol ResolveDelegateType { get; } + public INamedTypeSymbol NamedResolveDelegateType { get; } public KnownTypes(Compilation compilation, IModuleSymbol module, IAssemblySymbol assemblySymbol) { @@ -130,6 +138,8 @@ static INamedTypeSymbol GetTypeFromCompilationByMetadataNameOrThrow(Compilation ModuleAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, ServiceProviderModuleAttributeMetadataName); FromNamedServicesAttribute = GetTypeByMetadataNameOrThrow(assemblySymbol, FromNamedServicesAttributeMetadataName); + ResolveDelegateType = GetTypeByMetadataNameOrThrow(assemblySymbol, ResolveDelegateMetadataName); + NamedResolveDelegateType = GetTypeByMetadataNameOrThrow(assemblySymbol, NamedResolveDelegateMetadataName); } public static bool HasKnownTypes(IModuleSymbol sourceModule) diff --git a/src/Jab/ResolveDelegateCallSite.cs b/src/Jab/ResolveDelegateCallSite.cs new file mode 100644 index 0000000..b1bc3c9 --- /dev/null +++ b/src/Jab/ResolveDelegateCallSite.cs @@ -0,0 +1,14 @@ +namespace Jab; + +internal record ResolveDelegateCallSite : ServiceCallSite +{ + public ResolveDelegateCallSite(ServiceIdentity identity, ITypeSymbol resolvedType, bool usesName) + : base(identity, identity.Type, ServiceLifetime.Transient, false) + { + ResolvedType = resolvedType; + UsesName = usesName; + } + + public ITypeSymbol ResolvedType { get; } + public bool UsesName { get; } +} diff --git a/src/Jab/ServiceProviderBuilder.cs b/src/Jab/ServiceProviderBuilder.cs index 2b9af50..66772d5 100644 --- a/src/Jab/ServiceProviderBuilder.cs +++ b/src/Jab/ServiceProviderBuilder.cs @@ -265,6 +265,31 @@ ServiceCallSite BuiltInCallSite(ServiceCallSite callSite) return callSite; } + if (serviceType is INamedTypeSymbol { IsGenericType: true } delegateType) + { + if (SymbolEqualityComparer.Default.Equals(delegateType.ConstructedFrom, _knownTypes.ResolveDelegateType)) + { + var identity = new ServiceIdentity(serviceType, name, null); + if (CheckNotNamed(identity) is { } error) + { + return error; + } + + return CreateResolveDelegateCallSite(identity, delegateType.TypeArguments[0], requiresNamed: false, context); + } + + if (SymbolEqualityComparer.Default.Equals(delegateType.ConstructedFrom, _knownTypes.NamedResolveDelegateType)) + { + var identity = new ServiceIdentity(serviceType, name, null); + if (CheckNotNamed(identity) is { } error) + { + return error; + } + + return CreateResolveDelegateCallSite(identity, delegateType.TypeArguments[0], requiresNamed: true, context); + } + } + if (SymbolEqualityComparer.Default.Equals(serviceType, _knownTypes.IServiceProviderType)) { return BuiltInCallSite(_serviceProviderCallsite); @@ -723,6 +748,47 @@ private ServiceCallSite CreateConstructorCallSite( return (callSites, namedParameters, diagnostics); } + private ServiceCallSite? CreateResolveDelegateCallSite(ServiceIdentity identity, ITypeSymbol resolvedType, bool requiresNamed, ServiceResolutionContext context) + { + if (context.CallSiteCache.TryGet(identity, out var existing)) + { + return existing; + } + + Diagnostic? diagnostic = null; + + if (requiresNamed) + { + if (!HasNamedRegistration(resolvedType, context.ProviderDescription)) + { + diagnostic = Diagnostic.Create( + DiagnosticDescriptors.NamedResolveDelegateServiceNotRegistered, + context.RequestLocation, + resolvedType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + } + } + else + { + if (!CanSatisfy(resolvedType, context.ProviderDescription)) + { + diagnostic = Diagnostic.Create( + DiagnosticDescriptors.ResolveDelegateServiceNotRegistered, + context.RequestLocation, + resolvedType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat)); + } + } + + if (diagnostic != null) + { + _context.ReportDiagnostic(diagnostic); + return new ErrorCallSite(identity, diagnostic); + } + + var callSite = new ResolveDelegateCallSite(identity, resolvedType, requiresNamed); + context.CallSiteCache.Add(callSite); + return callSite; + } + private bool CanSatisfy(ITypeSymbol serviceType, ServiceProviderDescription description) { INamedTypeSymbol? genericType = null; @@ -739,6 +805,19 @@ private bool CanSatisfy(ITypeSymbol serviceType, ServiceProviderDescription desc return true; } + if (genericType != null) + { + if (SymbolEqualityComparer.Default.Equals(genericType.ConstructedFrom, _knownTypes.ResolveDelegateType)) + { + return CanSatisfy(genericType.TypeArguments[0], description); + } + + if (SymbolEqualityComparer.Default.Equals(genericType.ConstructedFrom, _knownTypes.NamedResolveDelegateType)) + { + return HasNamedRegistration(genericType.TypeArguments[0], description); + } + } + foreach (var registration in description.ServiceRegistrations) { if (SymbolEqualityComparer.Default.Equals(registration.ServiceType.ConstructedFrom, serviceType)) @@ -758,6 +837,31 @@ private bool CanSatisfy(ITypeSymbol serviceType, ServiceProviderDescription desc return false; } + private bool HasNamedRegistration(ITypeSymbol serviceType, ServiceProviderDescription description) + { + foreach (var registration in description.ServiceRegistrations) + { + if (registration.Name == null) + { + continue; + } + + if (SymbolEqualityComparer.Default.Equals(registration.ServiceType, serviceType)) + { + return true; + } + + if (serviceType is INamedTypeSymbol { IsGenericType: true } genericServiceType && + registration.ServiceType.IsUnboundGenericType && + SymbolEqualityComparer.Default.Equals(registration.ServiceType.ConstructedFrom, genericServiceType.ConstructedFrom)) + { + return true; + } + } + + return false; + } + private IMethodSymbol? SelectConstructor(INamedTypeSymbol implementationType, ServiceProviderDescription description) {