diff --git a/src/Jab.FunctionalTests.Common/ConstructorSelectionTests.cs b/src/Jab.FunctionalTests.Common/ConstructorSelectionTests.cs index 823c7b6..87cffdf 100644 --- a/src/Jab.FunctionalTests.Common/ConstructorSelectionTests.cs +++ b/src/Jab.FunctionalTests.Common/ConstructorSelectionTests.cs @@ -1,3 +1,4 @@ +using System; using Xunit; using Jab; diff --git a/src/Jab.FunctionalTests.Common/ContainerTests.cs b/src/Jab.FunctionalTests.Common/ContainerTests.cs index b81b634..10b33fa 100644 --- a/src/Jab.FunctionalTests.Common/ContainerTests.cs +++ b/src/Jab.FunctionalTests.Common/ContainerTests.cs @@ -1082,6 +1082,115 @@ public void CanGetMultipleOpenGenericScoped() partial class CanGetMultipleOpenGenericScopedContainer { } + + [Fact] + public void SupportsImplicitFunc() + { + SupportsImplicitFuncFactoryContainer c = new(); + var transientFunc = c.GetService>(); + var transientFunc2 = c.GetService>(); + var transientService1 = transientFunc(); + var transientService2 = transientFunc(); + + var scope1 = c.CreateScope(); + var scopedFunc = scope1.GetService>(); + var scopedFunc2 = scope1.GetService>(); + var scopedService1 = scopedFunc(); + var scopedService2 = scopedFunc(); + + var scope2 = c.CreateScope(); + var scopedFunc3 = scope2.GetService>(); + var scopedService3 = scopedFunc3(); + + var singletonFunc = c.GetService>(); + var singletonFunc2 = c.GetService>(); + + var singletonService1 = singletonFunc(); + var singletonService2 = singletonFunc2(); + + Assert.Equal(2, c.TransientCount); + Assert.Equal(2, c.ScopedCount); + Assert.Equal(1, c.SingletonCount); + + Assert.Same(singletonFunc, singletonFunc2); + Assert.Same(transientFunc, transientFunc2); + Assert.Same(scopedFunc, scopedFunc2); + Assert.NotSame(scopedFunc2, scopedFunc3); + + Assert.Same(singletonService1, singletonService2); + Assert.Same(scopedService1, scopedService2); + Assert.NotSame(scopedService1, scopedService3); + + Assert.NotSame(transientService1, transientService2); + } + + [Fact] + public void SupportsImplicitNamedFunc() + { + SupportsImplicitFuncFactoryContainer c = new(); + var transientFunc = c.GetService>("named"); + var transientFunc2 = c.GetService>("named"); + var transientService1 = transientFunc(); + var transientService2 = transientFunc(); + + var singletonFunc = c.GetService>("named"); + var singletonFunc2 = c.GetService>("named"); + + var singletonService1 = singletonFunc(); + var singletonService2 = singletonFunc2(); + + Assert.Equal(2, c.TransientNamedCount); + Assert.Equal(1, c.SingletonNamedCount); + + Assert.Same(singletonFunc, singletonFunc2); + Assert.Same(transientFunc, transientFunc2); + + Assert.Same(singletonService1, singletonService2); + Assert.NotSame(transientService1, transientService2); + } + + [ServiceProvider(RootServices = new [] { typeof(Func) })] + [Transient(typeof(IService), Factory=nameof(TransientNamedFactory), Name = "named")] + [Singleton(typeof(IService2), Factory=nameof(SingletonNamedFactory), Name = "named")] + [Transient(typeof(IService), Factory=nameof(TransientFactory))] + [Scoped(typeof(IService1), Factory=nameof(ScopedFactory))] + [Singleton(typeof(IService2), Factory=nameof(SingletonFactory))] + internal partial class SupportsImplicitFuncFactoryContainer + { + internal int TransientCount = 0; + internal int ScopedCount = 0; + internal int SingletonCount = 0; + + internal int TransientNamedCount = 0; + internal int SingletonNamedCount = 0; + + internal ServiceImplementation TransientFactory() + { + TransientCount++; + return new(); + } + internal ServiceImplementation ScopedFactory() + { + ScopedCount++; + return new(); + } + internal ServiceImplementation SingletonFactory() + { + SingletonCount++; + return new(); + } + + internal ServiceImplementation TransientNamedFactory() + { + TransientNamedCount++; + return new(); + } + internal ServiceImplementation SingletonNamedFactory() + { + SingletonNamedCount++; + return new(); + } + } #region Non-generic member factory with parameters [Fact] diff --git a/src/Jab/ConstructorCallSite.cs b/src/Jab/ConstructorCallSite.cs index aa4f560..9af044a 100644 --- a/src/Jab/ConstructorCallSite.cs +++ b/src/Jab/ConstructorCallSite.cs @@ -2,7 +2,7 @@ internal record ConstructorCallSite : ServiceCallSite { - public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair[] optionalParameters, ServiceLifetime lifetime, int? reverseIndex, bool? isDisposable) + public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair[] optionalParameters, ServiceLifetime lifetime, bool? isDisposable) : base(identity, implementationType, lifetime, isDisposable) { Parameters = parameters; diff --git a/src/Jab/ContainerGenerator.cs b/src/Jab/ContainerGenerator.cs index 5f503f0..bbcd42d 100644 --- a/src/Jab/ContainerGenerator.cs +++ b/src/Jab/ContainerGenerator.cs @@ -24,19 +24,20 @@ private void GenerateCallSiteWithCache(CodeWriter codeWriter, string rootReferen if (serviceCallSite.Lifetime != ServiceLifetime.Transient) { var cacheLocation = GetCacheLocation(serviceCallSite.Identity); - codeWriter.Line($"if ({cacheLocation} == null)"); - codeWriter.Line($"lock (this)"); - using (codeWriter.Scope($"if ({cacheLocation} == null)")) + var locking = serviceCallSite is not FuncCallSite; + if (locking) { - GenerateCallSite( - codeWriter, - rootReference, - serviceCallSite, - (w, v) => - { - w.Line($"{cacheLocation} = {v};"); - }); + codeWriter.Line($"if ({cacheLocation} == null)"); + codeWriter.Line($"lock (this)"); } + GenerateCallSite( + codeWriter, + rootReference, + serviceCallSite, + (w, v) => + { + w.Line($"{cacheLocation} ??= {v};"); + }); if (serviceCallSite.ImplementationType.IsValueType) { @@ -146,6 +147,14 @@ private void GenerateCallSite(CodeWriter codeWriter, string rootReference, Servi w.Append($")"); }); break; + + case FuncCallSite funcCallSite: + valueCallback(codeWriter, w => + { + w.Append($"() => "); + WriteResolutionCall(codeWriter, funcCallSite.Inner.Identity, "this"); + }); + break; case MemberCallSite memberCallSite: valueCallback(codeWriter, w => { diff --git a/src/Jab/FuncCallSite.cs b/src/Jab/FuncCallSite.cs new file mode 100644 index 0000000..1a4f19b --- /dev/null +++ b/src/Jab/FuncCallSite.cs @@ -0,0 +1,18 @@ +namespace Jab; + +internal record FuncCallSite : ServiceCallSite +{ + public FuncCallSite(ServiceIdentity identity, ServiceCallSite inner) + : base(identity, identity.Type, GetFuncLifetime(inner.Lifetime), false) + { + Inner = inner; + } + + public ServiceCallSite Inner { get; } + + private static ServiceLifetime GetFuncLifetime(ServiceLifetime innerLifetime) => innerLifetime switch + { + ServiceLifetime.Scoped => ServiceLifetime.Scoped, + _ => ServiceLifetime.Singleton + }; +} \ No newline at end of file diff --git a/src/Jab/KnownTypes.cs b/src/Jab/KnownTypes.cs index 9b6fce8..98905b2 100644 --- a/src/Jab/KnownTypes.cs +++ b/src/Jab/KnownTypes.cs @@ -46,6 +46,7 @@ internal class KnownTypes private const string IAsyncDisposableMetadataName = "System.IAsyncDisposable"; private const string IEnumerableMetadataName = "System.Collections.Generic.IEnumerable`1"; + private const string FuncMetadataName = "System.Func`1"; private const string IServiceProviderMetadataName = "System.IServiceProvider"; private const string IServiceScopeMetadataName = "Microsoft.Extensions.DependencyInjection.IServiceScope"; private const string IKeyedServiceProviderMetadataName = "Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider"; @@ -59,6 +60,7 @@ internal class KnownTypes "Microsoft.Extensions.DependencyInjection.IServiceProviderIsService"; public INamedTypeSymbol IEnumerableType { get; } + public INamedTypeSymbol FuncType { get; } public INamedTypeSymbol IServiceProviderType { get; } public INamedTypeSymbol CompositionRootAttributeType { get; } public INamedTypeSymbol TransientAttributeType { get; } @@ -102,6 +104,7 @@ static INamedTypeSymbol GetTypeFromCompilationByMetadataNameOrThrow(Compilation ?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found"); IEnumerableType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IEnumerableMetadataName); + FuncType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, FuncMetadataName); IServiceProviderType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IServiceProviderMetadataName); IServiceScopeType = compilation.GetTypeByMetadataName(IServiceScopeMetadataName); IAsyncDisposableType = compilation.GetTypeByMetadataName(IAsyncDisposableMetadataName); diff --git a/src/Jab/ServiceProviderBuilder.cs b/src/Jab/ServiceProviderBuilder.cs index 2b9af50..9887d15 100644 --- a/src/Jab/ServiceProviderBuilder.cs +++ b/src/Jab/ServiceProviderBuilder.cs @@ -228,6 +228,7 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol) return TryCreateSpecial(serviceType, name, context) ?? TryCreateExact(serviceType, name, null, context) ?? TryCreateEnumerable(serviceType, name, context) ?? + TryCreateFunc(serviceType, name, context) ?? TryCreateGeneric(serviceType, name, context); } finally @@ -432,6 +433,37 @@ static ServiceLifetime GetCommonLifetime(IEnumerable callSites) return null; } + + private ServiceCallSite? TryCreateFunc(ITypeSymbol serviceType, string? name, ServiceResolutionContext context) + { + if (serviceType is INamedTypeSymbol { IsGenericType: true } genericType && + SymbolEqualityComparer.Default.Equals(genericType.ConstructedFrom, _knownTypes.FuncType)) + { + var identity = new ServiceIdentity(genericType, name, null); + + if (context.CallSiteCache.TryGet(identity, out var callSite)) + { + return callSite; + } + + var innerType = genericType.TypeArguments[0]; + var inner = GetCallSite(innerType, name, context); + + if (inner == null) + { + return null; + } + + callSite = new FuncCallSite(identity, inner); + + context.CallSiteCache.Add(callSite); + + return callSite; + } + + return null; + } + private ServiceCallSite? TryCreateExact( ITypeSymbol serviceType, string? name, @@ -612,7 +644,6 @@ private ServiceCallSite CreateConstructorCallSite( parameters.ToArray(), namedParameters.ToArray(), registration.Lifetime, - identity.ReverseIndex, // TODO: this can be optimized to avoid check for all the types isDisposable: null );