Skip to content

Commit

Permalink
Add support for List and Dict providers to _locate_dependent_closing_…
Browse files Browse the repository at this point in the history
…args
  • Loading branch information
ZipFile committed Feb 1, 2025
1 parent 72a316c commit 4badee7
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 99 deletions.
36 changes: 18 additions & 18 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
"""Wiring module."""

import functools
import inspect
import importlib
import importlib.machinery
import inspect
import pkgutil
import warnings
import sys
import warnings
from types import ModuleType
from typing import (
Optional,
Iterable,
Iterator,
Callable,
Any,
Tuple,
Callable,
Dict,
Generic,
TypeVar,
Iterable,
Iterator,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
Set,
cast,
)

Expand Down Expand Up @@ -645,17 +645,17 @@ def _fetch_reference_injections( # noqa: C901
def _locate_dependent_closing_args(
provider: providers.Provider,
) -> Dict[str, providers.Provider]:
if not hasattr(provider, "args"):
return {}
closing_deps: Dict[str, providers.Provider] = {}

closing_deps = {}
for arg in [*provider.args, *provider.kwargs.values()]:
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
for arg in [
*getattr(provider, "args", []),
*getattr(provider, "kwargs", {}).values(),
]:
if not isinstance(arg, providers.Provider):
continue
if isinstance(arg, providers.Resource):
return {str(id(arg)): arg}
if arg.args or arg.kwargs:
closing_deps |= _locate_dependent_closing_args(arg)
closing_deps[str(id(arg))] = arg
closing_deps |= _locate_dependent_closing_args(arg)

return closing_deps

Expand Down Expand Up @@ -1030,8 +1030,8 @@ def is_loader_installed() -> bool:
_loader = AutoLoader()

# Optimizations
from ._cwiring import _get_sync_patched # noqa
from ._cwiring import _async_inject # noqa
from ._cwiring import _get_sync_patched # noqa


# Wiring uses the following Python wrapper because there is
Expand Down
96 changes: 48 additions & 48 deletions tests/unit/samples/wiringstringids/resourceclosing.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,78 @@
from typing import Any, Dict, List, Optional

from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing
from dependency_injector.wiring import Closing, Provide, inject


class Counter:
def __init__(self) -> None:
self._init = 0
self._shutdown = 0

def init(self) -> None:
self._init += 1

def shutdown(self) -> None:
self._shutdown += 1

class Singleton:
pass
def reset(self) -> None:
self._init = 0
self._shutdown = 0


class Service:
init_counter: int = 0
shutdown_counter: int = 0
dependency: Singleton = None
def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None:
self.counter = counter or Counter()
self.dependencies = dependencies

@classmethod
def reset_counter(cls):
cls.init_counter = 0
cls.shutdown_counter = 0
def init(self) -> None:
self.counter.init()

@classmethod
def init(cls, dependency: Singleton = None):
if dependency:
cls.dependency = dependency
cls.init_counter += 1
def shutdown(self) -> None:
self.counter.shutdown()

@classmethod
def shutdown(cls):
cls.shutdown_counter += 1
@property
def init_counter(self) -> int:
return self.counter._init

@property
def shutdown_counter(self) -> int:
return self.counter._shutdown


class FactoryService:
def __init__(self, service: Service):
def __init__(self, service: Service, service2: Service):
self.service = service
self.service2 = service2


class NestedService:
def __init__(self, factory_service: FactoryService):
self.factory_service = factory_service


def init_service():
service = Service()
def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]):
service = Service(counter, _list=_list, _dict=_dict)
service.init()
yield service
service.shutdown()


def init_service_with_singleton(singleton: Singleton):
service = Service()
service.init(singleton)
yield service
service.shutdown()


class Container(containers.DeclarativeContainer):

service = providers.Resource(init_service)
factory_service = providers.Factory(FactoryService, service)
factory_service_kwargs = providers.Factory(
FactoryService,
service=service
counter = providers.Singleton(Counter)
_list = providers.List(
providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2)
)
nested_service = providers.Factory(NestedService, factory_service)


class ContainerSingleton(containers.DeclarativeContainer):

singleton = providers.Singleton(Singleton)
service = providers.Resource(
init_service_with_singleton,
singleton
_dict = providers.Dict(
a=providers.Callable(lambda a: a, a=1), b=providers.Callable(lambda b: b, 2)
)
factory_service = providers.Factory(FactoryService, service)
service = providers.Resource(init_service, counter, _list, _dict=_dict)
service2 = providers.Resource(init_service, counter, _list, _dict=_dict)
factory_service = providers.Factory(FactoryService, service, service2)
factory_service_kwargs = providers.Factory(
FactoryService,
service=service
service=service,
service2=service2,
)
nested_service = providers.Factory(NestedService, factory_service)

Expand All @@ -84,20 +84,20 @@ def test_function(service: Service = Closing[Provide["service"]]):

@inject
def test_function_dependency(
factory: FactoryService = Closing[Provide["factory_service"]]
factory: FactoryService = Closing[Provide["factory_service"]],
):
return factory


@inject
def test_function_dependency_kwargs(
factory: FactoryService = Closing[Provide["factory_service_kwargs"]]
factory: FactoryService = Closing[Provide["factory_service_kwargs"]],
):
return factory


@inject
def test_function_nested_dependency(
nested: NestedService = Closing[Provide["nested_service"]]
nested: NestedService = Closing[Provide["nested_service"]],
):
return nested
54 changes: 21 additions & 33 deletions tests/unit/wiring/string_ids/test_main_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from decimal import Decimal

from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire
from pytest import fixture, mark, raises

from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.service import Service
from samples.wiringstringids.container import Container, SubContainer
from samples.wiringstringids.service import Service

from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire


@fixture(autouse=True)
Expand All @@ -33,14 +33,12 @@ def subcontainer():
container.unwire()


@fixture(params=[
resourceclosing.Container,
resourceclosing.ContainerSingleton,
])
@fixture
def resourceclosing_container(request):
container = request.param()
container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
yield container
with container.reset_singletons():
yield container
container.unwire()


Expand Down Expand Up @@ -277,8 +275,6 @@ def test_wire_multiple_containers():

@mark.usefixtures("resourceclosing_container")
def test_closing_resource():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function()
assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1
Expand All @@ -294,55 +290,48 @@ def test_closing_resource():

@mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function_dependency()
assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1
assert result_1.service.shutdown_counter == 1
assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 2

result_2 = resourceclosing.test_function_dependency()

assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2
assert result_2.service.shutdown_counter == 2
assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 4


@mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource_kwargs():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1
assert result_1.service.shutdown_counter == 1
assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 2

result_2 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2
assert result_2.service.shutdown_counter == 2
assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 4


@mark.usefixtures("resourceclosing_container")
def test_closing_nested_dependency_resource():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_1, resourceclosing.NestedService)
assert result_1.factory_service.service.init_counter == 1
assert result_1.factory_service.service.shutdown_counter == 1
assert result_1.factory_service.service.init_counter == 2
assert result_1.factory_service.service.shutdown_counter == 2

result_2 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_2, resourceclosing.NestedService)
assert result_2.factory_service.service.init_counter == 2
assert result_2.factory_service.service.shutdown_counter == 2
assert result_2.factory_service.service.init_counter == 4
assert result_2.factory_service.service.shutdown_counter == 4

assert result_1 is not result_2


@mark.usefixtures("resourceclosing_container")
def test_closing_resource_bypass_marker_injection():
resourceclosing.Service.reset_counter()

result_1 = resourceclosing.test_function(service=Closing[Provide["service"]])
assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1
Expand All @@ -358,7 +347,6 @@ def test_closing_resource_bypass_marker_injection():

@mark.usefixtures("resourceclosing_container")
def test_closing_resource_context():
resourceclosing.Service.reset_counter()
service = resourceclosing.Service()

result_1 = resourceclosing.test_function(service=service)
Expand Down

0 comments on commit 4badee7

Please sign in to comment.