diff --git a/VContainer/Assets/VContainer/Runtime/ContainerBuilderExtensions.cs b/VContainer/Assets/VContainer/Runtime/ContainerBuilderExtensions.cs index 113b949f..63eed63e 100644 --- a/VContainer/Assets/VContainer/Runtime/ContainerBuilderExtensions.cs +++ b/VContainer/Assets/VContainer/Runtime/ContainerBuilderExtensions.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Runtime.CompilerServices; using VContainer.Internal; @@ -57,6 +56,45 @@ public static RegistrationBuilder Register( Func implementationConfiguration, Lifetime lifetime) => builder.Register(new FuncRegistrationBuilder(container => implementationConfiguration(container), typeof(TInterface), lifetime)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static RegistrationBuilder Register( + this IContainerBuilder builder, + Type interfaceType, + Func implementationFactory, + Lifetime lifetime) + { + return builder.Register(new OpenGenericFuncRegistrationBuilder( + interfaceType, + (resolver, args) => implementationFactory(resolver, args[0]), + lifetime)).As(interfaceType); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static RegistrationBuilder Register( + this IContainerBuilder builder, + Type interfaceType, + Func implementationFactory, + Lifetime lifetime) + { + return builder.Register(new OpenGenericFuncRegistrationBuilder( + interfaceType, + (resolver, args) => implementationFactory(resolver, args[0], args[1]), + lifetime)).As(interfaceType); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static RegistrationBuilder Register( + this IContainerBuilder builder, + Type interfaceType, + Func implementationFactory, + Lifetime lifetime) + { + return builder.Register(new OpenGenericFuncRegistrationBuilder( + interfaceType, + (resolver, args) => implementationFactory(resolver, args[0], args[1], args[2]), + lifetime)).As(interfaceType); + } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static RegistrationBuilder RegisterInstance( diff --git a/VContainer/Assets/VContainer/Runtime/IClosedRegistrationProvider.cs b/VContainer/Assets/VContainer/Runtime/IClosedRegistrationProvider.cs new file mode 100644 index 00000000..f1ba6b08 --- /dev/null +++ b/VContainer/Assets/VContainer/Runtime/IClosedRegistrationProvider.cs @@ -0,0 +1,9 @@ +using System; + +namespace VContainer +{ + public interface IClosedRegistrationProvider + { + Registration GetClosedRegistration(Type closedInterfaceType, Type[] typeParameters, object key = null); + } +} \ No newline at end of file diff --git a/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericFuncInstanceProvider.cs b/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericFuncInstanceProvider.cs new file mode 100644 index 00000000..6e2f0744 --- /dev/null +++ b/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericFuncInstanceProvider.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using TypeParametersKey = VContainer.Internal.OpenGenericTypeParametersKey; + +namespace VContainer.Internal +{ + public class OpenGenericFuncInstanceProvider : IInstanceProvider, IClosedRegistrationProvider + { + readonly Type implementationType; + readonly Lifetime lifetime; + readonly Func factory; + + readonly ConcurrentDictionary constructedRegistrations = new ConcurrentDictionary(); + readonly Func createRegistrationFunc; + + public OpenGenericFuncInstanceProvider(Type implementationType, Lifetime lifetime, Func factory) + { + this.implementationType = implementationType; + this.lifetime = lifetime; + this.factory = factory; + createRegistrationFunc = CreateRegistration; + } + + public Registration GetClosedRegistration(Type closedInterfaceType, Type[] typeParameters, object key = null) + { + var typeParametersKey = new TypeParametersKey(typeParameters, key); + return constructedRegistrations.GetOrAdd(typeParametersKey, createRegistrationFunc); + } + + Registration CreateRegistration(TypeParametersKey key) + { + var newType = implementationType.MakeGenericType(key.TypeParameters); + var spawner = new FuncInstanceProvider(resolver => factory(resolver, key.TypeParameters)); + return new Registration(newType, lifetime, new List(1) { newType }, spawner, key.Key); + } + + public object SpawnInstance(IObjectResolver resolver) + { + throw new InvalidOperationException(); + } + } +} \ No newline at end of file diff --git a/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericInstanceProvider.cs b/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericInstanceProvider.cs index f2d7f7db..48bde902 100644 --- a/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericInstanceProvider.cs +++ b/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/OpenGenericInstanceProvider.cs @@ -1,54 +1,12 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using TypeParametersKey = VContainer.Internal.OpenGenericTypeParametersKey; namespace VContainer.Internal { - public class OpenGenericInstanceProvider : IInstanceProvider + public class OpenGenericInstanceProvider : IInstanceProvider, IClosedRegistrationProvider { - class TypeParametersKey - { - public readonly Type[] TypeParameters; - public readonly object Key; - - public TypeParametersKey(Type[] typeParameters, object key) - { - TypeParameters = typeParameters; - Key = key; - } - - public override bool Equals(object obj) - { - if (obj is TypeParametersKey other) - { - if (Key != other.Key) - return false; - - if (TypeParameters.Length != other.TypeParameters.Length) - return false; - - for (var i = 0; i < TypeParameters.Length; i++) - { - if (TypeParameters[i] != other.TypeParameters[i]) - return false; - } - return true; - } - return false; - } - - public override int GetHashCode() - { - var hash = 5381; - foreach (var typeParameter in TypeParameters) - { - hash = ((hash << 5) + hash) ^ typeParameter.GetHashCode(); - } - hash = ((hash << 5) + hash) ^ (Key?.GetHashCode() ?? 0); - return hash; - } - } - readonly Lifetime lifetime; readonly Type implementationType; readonly IReadOnlyList customParameters; diff --git a/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/Shared/OpenGenericTypeParametersKey.cs b/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/Shared/OpenGenericTypeParametersKey.cs new file mode 100644 index 00000000..95420c2b --- /dev/null +++ b/VContainer/Assets/VContainer/Runtime/Internal/InstanceProviders/Shared/OpenGenericTypeParametersKey.cs @@ -0,0 +1,47 @@ +using System; + +namespace VContainer.Internal +{ + public class OpenGenericTypeParametersKey + { + public readonly Type[] TypeParameters; + public readonly object Key; + + public OpenGenericTypeParametersKey(Type[] typeParameters, object key) + { + TypeParameters = typeParameters; + Key = key; + } + + public override bool Equals(object obj) + { + if (obj is OpenGenericTypeParametersKey other) + { + if (Key != other.Key) + return false; + + if (TypeParameters.Length != other.TypeParameters.Length) + return false; + + for (var i = 0; i < TypeParameters.Length; i++) + { + if (TypeParameters[i] != other.TypeParameters[i]) + return false; + } + return true; + } + return false; + } + + public override int GetHashCode() + { + var hash = 5381; + foreach (var typeParameter in TypeParameters) + { + hash = ((hash << 5) + hash) ^ typeParameter.GetHashCode(); + } + hash = ((hash << 5) + hash) ^ (Key?.GetHashCode() ?? 0); + return hash; + } + } +} diff --git a/VContainer/Assets/VContainer/Runtime/Internal/OpenGenericFuncRegistrationBuilder.cs b/VContainer/Assets/VContainer/Runtime/Internal/OpenGenericFuncRegistrationBuilder.cs new file mode 100644 index 00000000..19328cda --- /dev/null +++ b/VContainer/Assets/VContainer/Runtime/Internal/OpenGenericFuncRegistrationBuilder.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; + +namespace VContainer.Internal +{ + public class OpenGenericFuncRegistrationBuilder : RegistrationBuilder + { + readonly Func factory; + + public OpenGenericFuncRegistrationBuilder( + Type openGenericType, + Func factory, + Lifetime lifetime) : base(openGenericType, lifetime) + { + if (!openGenericType.IsGenericType || openGenericType.IsConstructedGenericType) + throw new VContainerException(openGenericType, "Type is not open generic type."); + + this.factory = factory; + } + + public override Registration Build() + { + var provider = new OpenGenericFuncInstanceProvider(ImplementationType, Lifetime, factory); + return new Registration(ImplementationType, Lifetime, InterfaceTypes, provider); + } + + public override RegistrationBuilder AsImplementedInterfaces() + { + InterfaceTypes ??= new List(); + foreach (var i in ImplementationType.GetInterfaces()) + { + if (!i.IsGenericType) + continue; + + var def = i.GetGenericTypeDefinition(); + if (!InterfaceTypes.Contains(def)) + InterfaceTypes.Add(def); + } + return this; + } + + protected override void AddInterfaceType(Type interfaceType) + { + if (interfaceType.IsConstructedGenericType) + throw new VContainerException(interfaceType, "Type is not open generic type."); + + foreach (var i in ImplementationType.GetInterfaces()) + { + if (!i.IsGenericType || i.GetGenericTypeDefinition() != interfaceType) + continue; + + InterfaceTypes ??= new List(); + + if (!InterfaceTypes.Contains(interfaceType)) + InterfaceTypes.Add(interfaceType); + + return; + } + + base.AddInterfaceType(interfaceType); + } + } +} \ No newline at end of file diff --git a/VContainer/Assets/VContainer/Runtime/Registry.cs b/VContainer/Assets/VContainer/Runtime/Registry.cs index 8a939354..b52a3038 100644 --- a/VContainer/Assets/VContainer/Runtime/Registry.cs +++ b/VContainer/Assets/VContainer/Runtime/Registry.cs @@ -131,9 +131,9 @@ bool TryGetClosedGenericRegistration(Type interfaceType, object key, Type openGe { if (hashTable.TryGet(openGenericType, key, out var openGenericRegistration)) { - if (openGenericRegistration.Provider is OpenGenericInstanceProvider implementationRegistration) + if (openGenericRegistration.Provider is IClosedRegistrationProvider implementationRegistration) { - registration = implementationRegistration.GetClosedRegistration(interfaceType, typeParameters); + registration = implementationRegistration.GetClosedRegistration(interfaceType, typeParameters, key); return true; } }