Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using VContainer.Internal;

Expand Down Expand Up @@ -57,6 +56,45 @@ public static RegistrationBuilder Register<TInterface>(
Func<IObjectResolver, TInterface> implementationConfiguration,
Lifetime lifetime)
=> builder.Register(new FuncRegistrationBuilder(container => implementationConfiguration(container), typeof(TInterface), lifetime));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static RegistrationBuilder Register(
this IContainerBuilder builder,
Type interfaceType,
Func<IObjectResolver, Type, object> implementationFactory,
Lifetime lifetime)
{
return builder.Register(new OpenGenericFuncRegistrationBuilder(
interfaceType,
(resolver, args) => implementationFactory(resolver, args[0]),
lifetime)).As(interfaceType);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static RegistrationBuilder Register(
this IContainerBuilder builder,
Type interfaceType,
Func<IObjectResolver, Type, Type, object> implementationFactory,
Lifetime lifetime)
{
return builder.Register(new OpenGenericFuncRegistrationBuilder(
interfaceType,
(resolver, args) => implementationFactory(resolver, args[0], args[1]),
lifetime)).As(interfaceType);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static RegistrationBuilder Register(
this IContainerBuilder builder,
Type interfaceType,
Func<IObjectResolver, Type, Type, Type, object> implementationFactory,
Lifetime lifetime)
{
return builder.Register(new OpenGenericFuncRegistrationBuilder(
interfaceType,
(resolver, args) => implementationFactory(resolver, args[0], args[1], args[2]),
lifetime)).As(interfaceType);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static RegistrationBuilder RegisterInstance(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using System;

namespace VContainer
{
public interface IClosedRegistrationProvider
{
Registration GetClosedRegistration(Type closedInterfaceType, Type[] typeParameters, object key = null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using TypeParametersKey = VContainer.Internal.OpenGenericTypeParametersKey;

namespace VContainer.Internal
{
public class OpenGenericFuncInstanceProvider : IInstanceProvider, IClosedRegistrationProvider
{
readonly Type implementationType;
readonly Lifetime lifetime;
readonly Func<IObjectResolver, Type[], object> factory;

readonly ConcurrentDictionary<TypeParametersKey, Registration> constructedRegistrations = new ConcurrentDictionary<TypeParametersKey, Registration>();
readonly Func<TypeParametersKey, Registration> createRegistrationFunc;

public OpenGenericFuncInstanceProvider(Type implementationType, Lifetime lifetime, Func<IObjectResolver, Type[], object> factory)
{
this.implementationType = implementationType;
this.lifetime = lifetime;
this.factory = factory;
createRegistrationFunc = CreateRegistration;
}

public Registration GetClosedRegistration(Type closedInterfaceType, Type[] typeParameters, object key = null)
{
var typeParametersKey = new TypeParametersKey(typeParameters, key);
return constructedRegistrations.GetOrAdd(typeParametersKey, createRegistrationFunc);
}

Registration CreateRegistration(TypeParametersKey key)
{
var newType = implementationType.MakeGenericType(key.TypeParameters);
var spawner = new FuncInstanceProvider(resolver => factory(resolver, key.TypeParameters));
return new Registration(newType, lifetime, new List<Type>(1) { newType }, spawner, key.Key);
}

public object SpawnInstance(IObjectResolver resolver)
{
throw new InvalidOperationException();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,54 +1,12 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using TypeParametersKey = VContainer.Internal.OpenGenericTypeParametersKey;

namespace VContainer.Internal
{
public class OpenGenericInstanceProvider : IInstanceProvider
public class OpenGenericInstanceProvider : IInstanceProvider, IClosedRegistrationProvider
{
class TypeParametersKey
{
public readonly Type[] TypeParameters;
public readonly object Key;

public TypeParametersKey(Type[] typeParameters, object key)
{
TypeParameters = typeParameters;
Key = key;
}

public override bool Equals(object obj)
{
if (obj is TypeParametersKey other)
{
if (Key != other.Key)
return false;

if (TypeParameters.Length != other.TypeParameters.Length)
return false;

for (var i = 0; i < TypeParameters.Length; i++)
{
if (TypeParameters[i] != other.TypeParameters[i])
return false;
}
return true;
}
return false;
}

public override int GetHashCode()
{
var hash = 5381;
foreach (var typeParameter in TypeParameters)
{
hash = ((hash << 5) + hash) ^ typeParameter.GetHashCode();
}
hash = ((hash << 5) + hash) ^ (Key?.GetHashCode() ?? 0);
return hash;
}
}

readonly Lifetime lifetime;
readonly Type implementationType;
readonly IReadOnlyList<IInjectParameter> customParameters;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using System;

namespace VContainer.Internal
{
public class OpenGenericTypeParametersKey
{
public readonly Type[] TypeParameters;
public readonly object Key;

public OpenGenericTypeParametersKey(Type[] typeParameters, object key)
{
TypeParameters = typeParameters;
Key = key;
}

public override bool Equals(object obj)
{
if (obj is OpenGenericTypeParametersKey other)
{
if (Key != other.Key)
return false;

if (TypeParameters.Length != other.TypeParameters.Length)
return false;

for (var i = 0; i < TypeParameters.Length; i++)
{
if (TypeParameters[i] != other.TypeParameters[i])
return false;
}
return true;
}
return false;
}

public override int GetHashCode()
{
var hash = 5381;
foreach (var typeParameter in TypeParameters)
{
hash = ((hash << 5) + hash) ^ typeParameter.GetHashCode();
}
hash = ((hash << 5) + hash) ^ (Key?.GetHashCode() ?? 0);
return hash;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System;
using System.Collections.Generic;

namespace VContainer.Internal
{
public class OpenGenericFuncRegistrationBuilder : RegistrationBuilder
{
readonly Func<IObjectResolver, Type[], object> factory;

public OpenGenericFuncRegistrationBuilder(
Type openGenericType,
Func<IObjectResolver, Type[], object> factory,
Lifetime lifetime) : base(openGenericType, lifetime)
{
if (!openGenericType.IsGenericType || openGenericType.IsConstructedGenericType)
throw new VContainerException(openGenericType, "Type is not open generic type.");

this.factory = factory;
}

public override Registration Build()
{
var provider = new OpenGenericFuncInstanceProvider(ImplementationType, Lifetime, factory);
return new Registration(ImplementationType, Lifetime, InterfaceTypes, provider);
}

public override RegistrationBuilder AsImplementedInterfaces()
{
InterfaceTypes ??= new List<Type>();
foreach (var i in ImplementationType.GetInterfaces())
{
if (!i.IsGenericType)
continue;

var def = i.GetGenericTypeDefinition();
if (!InterfaceTypes.Contains(def))
InterfaceTypes.Add(def);
}
return this;
}

protected override void AddInterfaceType(Type interfaceType)
{
if (interfaceType.IsConstructedGenericType)
throw new VContainerException(interfaceType, "Type is not open generic type.");

foreach (var i in ImplementationType.GetInterfaces())
{
if (!i.IsGenericType || i.GetGenericTypeDefinition() != interfaceType)
continue;

InterfaceTypes ??= new List<Type>();

if (!InterfaceTypes.Contains(interfaceType))
InterfaceTypes.Add(interfaceType);

return;
}

base.AddInterfaceType(interfaceType);
}
}
}
4 changes: 2 additions & 2 deletions VContainer/Assets/VContainer/Runtime/Registry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ bool TryGetClosedGenericRegistration(Type interfaceType, object key, Type openGe
{
if (hashTable.TryGet(openGenericType, key, out var openGenericRegistration))
{
if (openGenericRegistration.Provider is OpenGenericInstanceProvider implementationRegistration)
if (openGenericRegistration.Provider is IClosedRegistrationProvider implementationRegistration)
{
registration = implementationRegistration.GetClosedRegistration(interfaceType, typeParameters);
registration = implementationRegistration.GetClosedRegistration(interfaceType, typeParameters, key);
return true;
}
}
Expand Down