diff --git a/docs/provider/provide.rst b/docs/provider/provide.rst index 18d777e3..7dfe4926 100644 --- a/docs/provider/provide.rst +++ b/docs/provider/provide.rst @@ -136,7 +136,6 @@ It works similar to :ref:`alias`. a = await container.get(AImpl) a is a # True - WithParents generates only one factory and many aliases and is equivalent to ``AnyOf[AImpl, A]``. The following parents are ignored: ``type``, ``object``, ``Enum``, ``ABC``, ``ABCMeta``, ``Generic``, ``Protocol``, ``Exception``, ``BaseException`` * You object's dependencies (and their dependencies) can be simply created by calling their constructors. You do not need to register them manually. Use ``recursive=True`` to register them automatically @@ -178,3 +177,21 @@ WithParents generates only one factory and many aliases and is equivalent to ``A def make_a(self, type_: type[T]) -> A[T]: ... +* Do you want to get "dependencies" by parents which are protocols? Use ``WithProtocols`` as a result hint: + +.. code-block:: python + + from dishka import WithProtocols, provide, Provider, Scope + + class A(Protocol): ... + class AImpl(A): ... + + class MyProvider(Provider): + scope=Scope.APP + + @provide + def a(self) -> WithProtocols[AImpl]: + return A() + + container = make_async_container(MyProvider()) + await container.get(A) diff --git a/src/dishka/__init__.py b/src/dishka/__init__.py index aca426f7..6f624107 100644 --- a/src/dishka/__init__.py +++ b/src/dishka/__init__.py @@ -19,6 +19,7 @@ "provide", "provide_all", "new_scope", + "WithProtocols", "ValidationSettings", "STRICT_VALIDATION", ] @@ -39,4 +40,5 @@ from .entities.scope import BaseScope, Scope, new_scope from .entities.validation_settigs import STRICT_VALIDATION, ValidationSettings from .entities.with_parents import WithParents +from .entities.with_protocols import WithProtocols from .provider import Provider diff --git a/src/dishka/entities/with_parents.py b/src/dishka/entities/with_parents.py index 395ac89b..828e509b 100644 --- a/src/dishka/entities/with_parents.py +++ b/src/dishka/entities/with_parents.py @@ -21,6 +21,8 @@ __all__ = ["WithParents", "ParentsResolver"] +from dishka.text_rendering import get_name + IGNORE_TYPES: Final = ( type, object, @@ -84,8 +86,9 @@ def create_type_vars_map(obj: TypeHint) -> dict[TypeHint, TypeHint]: class ParentsResolver: def get_parents(self, child_type: TypeHint) -> list[TypeHint]: if is_ignored_type(strip_alias(child_type)): + name = get_name(child_type, include_module=False) raise ValueError( - f"The starting class {child_type!r} is in ignored types", + f"The starting class {name} is in ignored types", ) if is_parametrized(child_type) or has_orig_bases(child_type): return self._get_parents_for_generic(child_type) diff --git a/src/dishka/entities/with_protocols.py b/src/dishka/entities/with_protocols.py new file mode 100644 index 00000000..a93b945c --- /dev/null +++ b/src/dishka/entities/with_protocols.py @@ -0,0 +1,42 @@ +__all__ = ["WithProtocols"] + +from typing import TYPE_CHECKING, TypeVar + +from dishka._adaptix.common import TypeHint +from dishka._adaptix.type_tools import is_protocol, strip_alias +from dishka.entities.provides_marker import ProvideMultiple +from dishka.entities.with_parents import ParentsResolver +from dishka.text_rendering import get_name + + +def get_parents_protocols(type_hint: TypeHint) -> list[TypeHint]: + parents = ParentsResolver().get_parents(type_hint) + new_parents = [ + parent for parent in parents + if is_protocol(strip_alias(parent)) + ] + if new_parents: + return new_parents + + name = get_name(type_hint, include_module=False) + error_msg = ( + f"Not a single parent of the protocol was found in {name}.\n" + "Hint:\n" + f" * Maybe you meant just {name}, not WithProtocols[{name}]\n" + ) + if len(parents) > 1: + error_msg += f" * Perhaps you meant WithParents[{name}]?" + raise ValueError(error_msg) + + +T = TypeVar("T") +if TYPE_CHECKING: + from typing import Union + WithProtocols = Union[T, T] # noqa: UP007,PYI016 +else: + class WithProtocols: + def __class_getitem__(cls, item: TypeHint) -> TypeHint: + parents = get_parents_protocols(item) + if len(parents) > 1: + return ProvideMultiple(parents) + return parents[0] diff --git a/tests/unit/container/test_with_protocols.py b/tests/unit/container/test_with_protocols.py new file mode 100644 index 00000000..5eec0f58 --- /dev/null +++ b/tests/unit/container/test_with_protocols.py @@ -0,0 +1,38 @@ +from typing import Protocol + +import pytest + +from dishka import Provider, Scope, WithProtocols, make_container +from dishka.exceptions import NoFactoryError + + +class AProtocol(Protocol): + pass + + +class BProtocol(Protocol): + pass + + +class C(AProtocol, BProtocol): + pass + + +def test_get_parents_protocols() -> None: + provider = Provider(scope=Scope.APP) + provider.provide(C, provides=WithProtocols[C]) + container = make_container(provider) + + assert ( + container.get(BProtocol) + is container.get(AProtocol) + ) + + +def test_get_by_not_protocol() -> None: + provider = Provider(scope=Scope.APP) + provider.provide(C, provides=WithProtocols[C]) + container = make_container(provider) + + with pytest.raises(NoFactoryError): + container.get(C)