Skip to content

Commit

Permalink
Merge pull request #235 from cashapp/bradfol/fix-capture-bug
Browse files Browse the repository at this point in the history
Fix Resolver hierarchy + shadowed registration resolution
  • Loading branch information
bradfol authored Jan 29, 2025
2 parents 204451a + 82cae2f commit e63c5ad
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 24 deletions.
7 changes: 6 additions & 1 deletion Sources/Swinject/Container.Arguments.erb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ extension Container {
<%= arg_param_def %>) -> Service?
{
typealias FactoryType = ((Resolver, <%= arg_types %>)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, <%= arg_param_call %>)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, <%= arg_param_call %>))
}
)
}

<% end %>
Expand Down
63 changes: 54 additions & 9 deletions Sources/Swinject/Container.Arguments.swift
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,12 @@ extension Container {
argument: Arg1) -> Service?
{
typealias FactoryType = ((Resolver, Arg1)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, argument)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, argument))
}
)
}

/// Retrieves the instance with the specified service type and list of 2 arguments to the factory closure.
Expand Down Expand Up @@ -274,7 +279,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2))
}
)
}

/// Retrieves the instance with the specified service type and list of 3 arguments to the factory closure.
Expand Down Expand Up @@ -307,7 +317,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3))
}
)
}

/// Retrieves the instance with the specified service type and list of 4 arguments to the factory closure.
Expand Down Expand Up @@ -340,7 +355,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3, _ arg4: Arg4) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3, Arg4)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3, arg4)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3, arg4))
}
)
}

/// Retrieves the instance with the specified service type and list of 5 arguments to the factory closure.
Expand Down Expand Up @@ -373,7 +393,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3, _ arg4: Arg4, _ arg5: Arg5) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3, Arg4, Arg5)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3, arg4, arg5)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3, arg4, arg5))
}
)
}

/// Retrieves the instance with the specified service type and list of 6 arguments to the factory closure.
Expand Down Expand Up @@ -406,7 +431,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3, _ arg4: Arg4, _ arg5: Arg5, _ arg6: Arg6) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3, arg4, arg5, arg6)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3, arg4, arg5, arg6))
}
)
}

/// Retrieves the instance with the specified service type and list of 7 arguments to the factory closure.
Expand Down Expand Up @@ -439,7 +469,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3, _ arg4: Arg4, _ arg5: Arg5, _ arg6: Arg6, _ arg7: Arg7) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3, arg4, arg5, arg6, arg7)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7))
}
)
}

/// Retrieves the instance with the specified service type and list of 8 arguments to the factory closure.
Expand Down Expand Up @@ -472,7 +507,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3, _ arg4: Arg4, _ arg5: Arg5, _ arg6: Arg6, _ arg7: Arg7, _ arg8: Arg8) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8))
}
)
}

/// Retrieves the instance with the specified service type and list of 9 arguments to the factory closure.
Expand Down Expand Up @@ -505,7 +545,12 @@ extension Container {
arguments arg1: Arg1, _ arg2: Arg2, _ arg3: Arg3, _ arg4: Arg4, _ arg5: Arg5, _ arg6: Arg6, _ arg7: Arg7, _ arg8: Arg8, _ arg9: Arg9) -> Service?
{
typealias FactoryType = ((Resolver, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9)) -> Any
return _resolve(name: name) { (factory: FactoryType) in factory((self, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9)) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: FactoryType) in
factory((resolver, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9))
}
)
}

}
38 changes: 26 additions & 12 deletions Sources/Swinject/Container.swift
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,13 @@ public final class Container {
// MARK: - _Resolver

extension Container: _Resolver {

/// See documentation on `_Resolver` protocol where this method is declared.
// swiftlint:disable:next identifier_name
public func _resolve<Service, Arguments>(
name: String?,
option: ServiceKeyOption? = nil,
invoker: @escaping ((Arguments) -> Any) -> Any
invoker: @escaping (Resolver, (Arguments) -> Any) -> Any
) -> Service? {
// No need to use weak self since the resolution will be executed before
// this function exits.
Expand All @@ -246,8 +248,8 @@ extension Container: _Resolver {
return currentObjectGraph as? Service
}

if let entry = getEntry(for: key) {
resolvedInstance = resolve(entry: entry, invoker: invoker)
if let (entry, resolver) = getEntry(for: key) {
resolvedInstance = resolve(entry: entry, invoker: invoker, resolver: resolver)
}

if resolvedInstance == nil {
Expand All @@ -269,15 +271,15 @@ extension Container: _Resolver {
fileprivate func resolveAsWrapper<Wrapper, Arguments>(
name: String?,
option: ServiceKeyOption?,
invoker: @escaping ((Arguments) -> Any) -> Any
invoker: @escaping (Resolver, (Arguments) -> Any) -> Any
) -> Wrapper? {
guard let wrapper = Wrapper.self as? InstanceWrapper.Type else { return nil }

let key = ServiceKey(
serviceType: wrapper.wrappedType, argumentsType: Arguments.self, name: name, option: option
)

if let entry = getEntry(for: key) {
if let (entry, resolver) = getEntry(for: key) {
let factory = { [weak self] (graphIdentifier: GraphIdentifier?) -> Any? in
self?.syncIfEnabled { [weak self] () -> Any? in
guard let self else { return nil }
Expand All @@ -286,7 +288,7 @@ extension Container: _Resolver {
if let graphIdentifier = graphIdentifier {
self.restoreObjectGraph(graphIdentifier)
}
return self.resolve(entry: entry, invoker: invoker) as Any?
return self.resolve(entry: entry, invoker: invoker, resolver: resolver) as Any?
}
}
return wrapper.init(inContainer: self, withInstanceFactory: factory) as? Wrapper
Expand Down Expand Up @@ -352,20 +354,32 @@ extension Container: Resolver {
/// - Returns: The resolved service type instance, or nil if no registration for the service type and name
/// is found in the ``Container``.
public func resolve<Service>(_: Service.Type, name: String?) -> Service? {
return _resolve(name: name) { (factory: (Resolver) -> Any) in factory(self) }
return _resolve(
name: name,
invoker: { (resolver: Resolver, factory: (Resolver) -> Any) in
factory(resolver)
}
)
}

fileprivate func getEntry(for key: ServiceKey) -> ServiceEntryProtocol? {
/// Retrieve the service entry for a given service key.
///
/// - Returns: An optional tuple of the service entry and the source resolver.
fileprivate func getEntry(for key: ServiceKey) -> (ServiceEntryProtocol, Resolver)? {
if let entry = services[key] {
return entry
return (entry, self)
} else if let parentResult = parent?.getEntry(for: key) {
// An entry from a parent container uses that same parent container as the source resolver
return parentResult
} else {
return parent?.getEntry(for: key)
return nil
}
}

fileprivate func resolve<Service, Factory>(
entry: ServiceEntryProtocol,
invoker: @escaping (Factory) -> Any
invoker: @escaping (Resolver, Factory) -> Any,
resolver: Resolver
) -> Service? {
self.incrementResolutionDepth()
defer { self.decrementResolutionDepth() }
Expand All @@ -378,7 +392,7 @@ extension Container: Resolver {
return persistedInstance
}

let resolvedInstance = invoker(entry.factory as! Factory)
let resolvedInstance = invoker(resolver, entry.factory as! Factory)
if let persistedInstance = self.persistedInstance(Service.self, from: entry, in: currentObjectGraph) {
// An instance for the key might be added by the factory invocation.
return persistedInstance
Expand Down
4 changes: 3 additions & 1 deletion Sources/Swinject/_Resolver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ public protocol _Resolver {
/// - Parameter name: The registration name.
/// - Parameter option: A service key option for an extension/plugin.
/// - Parameter invoker: A closure to execute service resolution.
/// The primary responsibility of the invoker is to close over the values
/// of any arguments passed in during the resolve call.
///
/// - Returns: The resolved service type instance, or nil if no service is found.
// swiftlint:disable:next identifier_name
func _resolve<Service, Arguments>(
name: String?,
option: ServiceKeyOption?,
invoker: @escaping ((Arguments) -> Any) -> Any
invoker: @escaping (Resolver, (Arguments) -> Any) -> Any
) -> Service?
}
2 changes: 1 addition & 1 deletion Tests/SwinjectTests/ContainerTests.DebugHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ContainerTests_DebugHelper: XCTestCase {
func testContainerShouldCallDebugHelperWithFailingServiceAndKey() {
let container = Container(debugHelper: spy)

_ = container._resolve(name: "name") { (_: (Int) -> Any) in 1 as Double } as Double?
_ = container._resolve(name: "name") { (_: Resolver, _: (Int) -> Any) in 1 as Double } as Double?

XCTAssertEqual("\(spy.serviceType)", "Double")
XCTAssertEqual(spy.key, ServiceKey(
Expand Down
33 changes: 33 additions & 0 deletions Tests/SwinjectTests/ContainerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,39 @@ class ContainerTests: XCTestCase {
XCTAssertNil(weakCat)
}

func testShadowedRegistration_owningContainerHierarchyAccess() {
let parent = Container()
let child = Container(parent: parent)

parent.register(Animal.self, factory: { _ in Dog()})
child.register(Animal.self, factory: { _ in Cat()})

// Parent registration should not be able to see into any child registrations
parent.register(Animal.self, name: "Spot", factory: { resolver in
resolver.resolve(Animal.self)!
})

XCTAssert(child.resolve(Animal.self, name: "Spot") is Dog)
XCTAssert(parent.resolve(Animal.self, name: "Spot") is Dog)
}

func testShadowedRegistration_owningContainerHierarchyAccess_inObjectScopeContainer() {
let parent = Container()
let child = Container(parent: parent)

parent.register(Animal.self, factory: { _ in Dog()})
child.register(Animal.self, factory: { _ in Cat()})

// Parent registration should not be able to see into any child registrations
parent.register(Animal.self, name: "Spot", factory: { resolver in
resolver.resolve(Animal.self)!
})
.inObjectScope(.container)

XCTAssert(child.resolve(Animal.self, name: "Spot") is Dog)
XCTAssert(parent.resolve(Animal.self, name: "Spot") is Dog)
}

#if !SWIFT_PACKAGE
func testContainerDoesNotTerminateGraphPrematurely() {
let parent = Container()
Expand Down

0 comments on commit e63c5ad

Please sign in to comment.