diff --git a/docs/conf.py b/docs/conf.py index 7b75ce17..f2bcb23e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -2,7 +2,7 @@ # -- Project information project = "that-depends" -copyright = "2024, Modern Python" +copyright = "2025, Modern Python" author = "Shiriev Artur" release = "" diff --git a/docs/index.md b/docs/index.md index 20fdf6d4..2c7ea9e4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -11,6 +11,7 @@ introduction/litestar introduction/faststream introduction/inject-factories + introduction/scopes introduction/multiple-containers introduction/dynamic-container introduction/application-settings diff --git a/docs/introduction/scopes.md b/docs/introduction/scopes.md new file mode 100644 index 00000000..c1cc45c4 --- /dev/null +++ b/docs/introduction/scopes.md @@ -0,0 +1,171 @@ +# Named Scopes + +Named scopes allow you to define the lifecycle of a `ContextResource`. +In essence, they provide a tool to manage when `ContextResources` can be resolved and when they should be finalized. + +Before continuing, make sure you're familiar with `ContextResource` providers by reading their [documentation](../providers/context-resources.md). + +## Quick Start + +By default, `ContextResources` have the named scope `ANY`, meaning they will be re-initialized each time you enter a named scope. +You can change the scope of a `ContextResource` in two ways: + +### Setting the scope for providers + +1. By setting the `default_scope` attribute in the container class: + + ```python + class MyContainer(BaseContainer): + default_scope = ContextScope.APP + p = providers.ContextResource(my_resource) + ``` + +2. By calling the `with_config()` method when creating a `ContextResource`. This also overrides the class default: + + ```python + p = providers.ContextResource(my_resource).with_config(scope=ContextScope.APP) + ``` + +### Entering and exiting scopes + +Once you have assigned scopes to providers, you can enter a named scope using `container_context()`. +After entering a scope, you can resolve resources that have been defined with that scope: + +```python +from that_depends import container_context + +async with container_context(scope=ContextScopes.APP): + # resolve resources with scope APP + await my_app_scoped_provider.async_resolve() +``` + +## Checking the current scope + +If you want to check the current scope, you can use the `get_current_scope()` function: + +```python +from that_depends.providers.context_resources import get_current_scope, ContextScopes + +async with container_context(scope=ContextScopes.APP): + assert get_current_scope() == ContextScopes.APP +``` + +## Understanding resolution & strict scope providers + +In order for a `ContextResource` to be resolved, you must first initialize the context for that resource. +When you call `container_context(scope=ContextScopes.APP)` this both enters the `APP` scope and (re-)initializes context for +all providers that have `APP` scope. Scoped resources will prevent their context initialization if the current scope does +not match their scope: +```python +p = providers.ContextResource(my_resource).with_config(scope=ContextScopes.APP) + +async with p.async_context(): + # will raise an InvalidContextError since current scope is `None` + ... +``` + +Similarly, this will also not work: +```python +async with container_context(p, scope=ContextScopes.REQUEST): + # will raise and InvalidContextError since you are entering `REQUEST` scope + ... +``` + +Once the context has been initialized, a resource can be resolved regardless of the current scope. For example: + +```python +await p.async_resolve() # will raise an exception + +async with container_context(p, scope=ContextScopes.APP): + val_1 = await p.async_resolve() # will resolve + async with container_context(p, scope=ContextScopes.REQUEST): + val_2 = await p.async_resolve() # will resolve + assert val_1 == val_2 # but value stays the same since context is the same +``` + +If you want resources to be resolved **only** in the specified scope, enable strict resolution: + +```python +p = providers.ContextResource(my_resource).with_config(scope=ContextScopes.APP, strict_scope=True) +async with container_context(p, scope=ContextScopes.APP): + await p.async_resolve() # will resolve + + async with container_context(scope=ContextScopes.REQUEST): + await p.async_resolve() # will raise an exception +``` + +## Entering a context by force + +If you for some reason need to (re-)initialize a context for a `ContextResource` outside of its defined scope, +you can force enter its context: +```python +p = providers.ContextResource(my_resource).with_config(scope=ContextScopes.APP) + +async with p.async_context(force=True): + assert get_current_scope() == None + await p.async_resolve() # will resolve +``` +Or similarly using the `context` wrapper (both `ContextResource` providers and containers provide this API): +```python +class Container(BaseContainer): + p = providers.ContextResource(my_resource).with_config(scope=ContextScopes.APP) + +@Container.context(force=True) +@inject +async def injected(val = Provide[Container.p]): + return p + +await injected() # will resolve +``` + +## Predefined scopes + +`that-depends` includes four predefined scopes in the `ContextScopes` class: + +- `ANY`: Indicates that a resource can be resolved in any scope (even `None`). This scope cannot be entered, so it won’t be accepted by any class or method that requires entering a named scope. + +- `APP`: A convenience scope with no special behavior. + +- `REQUEST`: A convenience scope with no special behavior. + +- `INJECT`: The default scope of the `@inject` wrapper. Read more in the [Named scopes with the @inject wrapper](#named-scopes-with-the-inject-wrapper) section. + +> **Note:** The default scope, before entering any named scope, is `None`. You can pass `None` as a scope to providers, but since it cannot be entered, in most scenarios passing `None` simply means you did not specify a scope. + +## Named scopes with the `@inject` wrapper + +The `@inject` wrapper also supports named scopes. Its default scope is `INJECT`, but you can pass any scope you like: + +```python +@inject(scope=ContextScopes.APP) +def foo(...): + get_current_scope() # APP +``` + +When you pass a scope to the `@inject` wrapper, it enters that scope before calling the function, and exits the scope after the function returns. If you do not want to enter any scope, pass `None`. + +## Implementing custom scopes + +If the default scopes don’t fit your needs, you can define custom scopes by creating a `ContextScope` object: + +```python +from that_depends.providers.context_resources import ContextScope + +CUSTOM = ContextScope("CUSTOM") +``` + +If you want to group all of your scopes in one place, you can extend the `ContextScopes` class: + +```python +from that_depends.providers.context_resources import ContextScopes, ContextScope + +class MyContextScopes(ContextScopes): + CUSTOM = ContextScope("CUSTOM") +``` + +## Named scopes with middleware +You can pass a named scope to the `DIContextMiddleware` to set the scope and pre-initialize scoped `ContextResources` for the entire request: + +```python +middleware = DIContextMiddleware(app, scope=ContextScopes.REQUEST) +``` diff --git a/docs/migration/v2.md b/docs/migration/v2.md index 70f35a9d..cc4b26f6 100644 --- a/docs/migration/v2.md +++ b/docs/migration/v2.md @@ -33,7 +33,8 @@ If you want to learn more about the new features introduced in `2.*`, please ref await MyContainer.init_resources() ``` -2. **`that_depends.providers.AsyncResource` removed** +2. **`that_depends.providers.AsyncResource` removed** + The `AsyncResource` class has been removed. Use `providers.Resource` instead. **Example:** @@ -83,6 +84,21 @@ If you want to learn more about the new features introduced in `2.*`, please ref > **Note:** `reset_all_containers=True` only reinitializes the context for `ContextResource` instances defined within containers (i.e., classes inheriting from `BaseContainer`). If you also need to reset contexts for resources defined outside containers, you must handle these explicitly. See the [ContextResource documentation](../providers/context-resources.md) for more details. +3. **Container classes now require you to define `default_scope`** + + In `2.*`, you must define the `default_scope` attribute in your container classes if you plan to define any `ContextResource` providers in that class. This attribute specifies the default scope for all `ContextResource` providers defined within the container. + + **Example:** + + ```python + from that_depends import BaseContainer, providers + + class MyContainer(BaseContainer): + default_scope = None # This will maintain compatibility with 1.* + p = providers.ContextResource(my_resource) + ``` + Setting the value of `default_context = None` maintains the same behaviours as in `1.*`. Please look at the [scopes Documentation](../introduction/scopes.md). + --- ## Potential Issues with `container_context()` diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index 64ff900f..8daa823b 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -187,7 +187,7 @@ async with container_context(): ### Resolving resources whenever a function is called -`container_context` can be used as a decorator: +`ContextResource.context()` can also be used as a decorator: ```python @MyContainer.session.context # wrap with a session-specific context @inject @@ -207,6 +207,8 @@ Each time you call `await insert_into_database()`, a new instance of `session` w | Reset all resources in a container | `async with container_context(my_container):` | `async with my_container.async_context():` | `@my_container.context` | | Reset all sync resources in a container | `with container_context(my_container):` | `with my_container.sync_context():` | `@my_container.context` | +> **Note:** the `context()` wrapper is technically not part of the `SupportsContext` API, however all classes which +> implement this `SupportsContext` also implement this method. --- ## Middleware diff --git a/tests/container.py b/tests/container.py index a2f53507..42c4180e 100644 --- a/tests/container.py +++ b/tests/container.py @@ -54,6 +54,7 @@ class SingletonFactory: class DIContainer(BaseContainer): + default_scope = None sync_resource = providers.Resource(create_sync_resource) async_resource = providers.Resource(create_async_resource) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index e7ef285a..bbd9911e 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -6,12 +6,22 @@ import typing import uuid from contextlib import AsyncExitStack, ExitStack +from unittest.mock import Mock import pytest from that_depends import BaseContainer, Provide, fetch_context_item, inject, providers from that_depends.entities.resource_context import ResourceContext +from that_depends.meta import DefaultScopeNotDefinedError from that_depends.providers import container_context +from that_depends.providers.context_resources import ( + ContextScope, + ContextScopes, + DIContextMiddleware, + InvalidContextError, + _enter_named_scope, + get_current_scope, +) logger = logging.getLogger(__name__) @@ -30,6 +40,7 @@ async def create_async_context_resource() -> typing.AsyncIterator[str]: class DIContainer(BaseContainer): + default_scope = ContextScopes.ANY sync_context_resource = providers.ContextResource(create_sync_context_resource) async_context_resource = providers.ContextResource(create_async_context_resource) dynamic_context_resource = providers.Selector( @@ -40,6 +51,7 @@ class DIContainer(BaseContainer): class DependentDiContainer(BaseContainer): + default_scope = ContextScopes.ANY dependent_sync_context_resource = providers.ContextResource(create_sync_context_resource) dependent_async_context_resource = providers.ContextResource(create_async_context_resource) @@ -618,6 +630,7 @@ async def slow_async_creator() -> typing.AsyncIterator[str]: yield str(uuid.uuid4()) class MyContainer(BaseContainer): + default_scope = None slow_provider = providers.ContextResource(slow_async_creator) async def _injected() -> str: @@ -633,6 +646,7 @@ def slow_sync_creator() -> typing.Iterator[str]: yield str(uuid.uuid4()) class MyContainer(BaseContainer): + default_scope = None slow_provider = providers.ContextResource(slow_sync_creator) def _injected() -> str: @@ -646,3 +660,398 @@ def _injected() -> str: for thread in threads: thread.join() + + +def test_default_named_scope_is_none() -> None: + assert get_current_scope() is None + + +def test_entering_scope_sets_current_scope() -> None: + with _enter_named_scope(ContextScopes.INJECT): + assert get_current_scope() == ContextScopes.INJECT + assert get_current_scope() is None + + +def test_entering_scope_with_container_context_sync() -> None: + with container_context(scope=ContextScopes.INJECT): + assert get_current_scope() == ContextScopes.INJECT + assert get_current_scope() is None + + +async def test_entering_scope_with_container_context_async() -> None: + async with container_context(scope=ContextScopes.INJECT): + assert get_current_scope() == ContextScopes.INJECT + assert get_current_scope() is None + + +def test_scoped_provider_get_scope() -> None: + provider = providers.ContextResource(create_async_context_resource) + assert provider.get_scope() == ContextScopes.ANY + provider = provider.with_config(scope=ContextScopes.INJECT) + assert provider.get_scope() == ContextScopes.INJECT + + +def test_scoped_container_get_scope() -> None: + class _Container(BaseContainer): ... + + assert _Container.get_scope() is ContextScopes.ANY + + class _ScopedContainer(BaseContainer): + default_scope = ContextScopes.INJECT + + assert _ScopedContainer.get_scope() == ContextScopes.INJECT + + +def test_sync_resolve_scoped_resource() -> None: + provider = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.INJECT) + with pytest.raises(RuntimeError): + provider.sync_resolve() + + with container_context(provider, scope=ContextScopes.INJECT): + assert provider.sync_resolve() is not None + + +async def test_async_resolve_scoped_resource() -> None: + provider = providers.ContextResource(create_async_context_resource).with_config(scope=ContextScopes.INJECT) + with pytest.raises(RuntimeError): + await provider.async_resolve() + + async with container_context(provider, scope=ContextScopes.INJECT): + assert await provider.async_resolve() is not None + + +async def test_async_resolve_non_scoped_in_named_context() -> None: + provider = providers.ContextResource(create_async_context_resource) + async with container_context(provider, scope=ContextScopes.INJECT): + assert await provider.async_resolve() is not None + + +def test_sync_resolve_non_scoped_in_named_context() -> None: + provider = providers.ContextResource(create_sync_context_resource) + with container_context(provider, scope=ContextScopes.INJECT): + assert provider.sync_resolve() is not None + + +async def test_async_container_init_context_for_scoped_resources() -> None: + class _Container(BaseContainer): + async_resource = providers.ContextResource(create_async_context_resource).with_config( + scope=ContextScopes.INJECT + ) + + async with container_context(scope=ContextScopes.INJECT): + assert await _Container.async_resource.async_resolve() is not None + with pytest.raises(RuntimeError): + async with container_context(scope=None): + assert await _Container.async_resource.async_resolve() is not None + + +def test_sync_container_init_context_for_scoped_resources() -> None: + class _Container(BaseContainer): + sync_resource = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.INJECT) + + with container_context(scope=ContextScopes.INJECT): + assert _Container.sync_resource.sync_resolve() is not None + with pytest.raises(RuntimeError), container_context(scope=None): + assert _Container.sync_resource.sync_resolve() is not None + + +async def test_sync_container_init_context_for_default_container_resources() -> None: + class _Container(BaseContainer): + default_scope = ContextScopes.INJECT + sync_resource = providers.ContextResource(create_sync_context_resource) + + assert _Container.sync_resource.get_scope() == ContextScopes.INJECT + with container_context(scope=ContextScopes.INJECT): + assert _Container.sync_resource.sync_resolve() is not None + + +def test_container_with_context_resources_must_have_default_scope_set() -> None: + with pytest.raises(DefaultScopeNotDefinedError): + + class _Container(BaseContainer): + sync_resource = providers.ContextResource(create_sync_context_resource) + + +def test_providers_with_explicit_scope_ignore_default_scope() -> None: + class _Container(BaseContainer): + default_scope = None + sync_resource = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.INJECT) + + assert _Container.sync_resource.get_scope() == ContextScopes.INJECT + + +async def test_none_scoped_provider_should_not_be_resolvable_in_named_scope_async() -> None: + provider = providers.ContextResource(create_async_context_resource).with_config(scope=None) + async with container_context(scope=ContextScopes.INJECT): + with pytest.raises(RuntimeError): + await provider.async_resolve() + + +def test_none_scoped_provider_should_not_be_resolvable_in_named_scope_sync() -> None: + provider = providers.ContextResource(create_sync_context_resource).with_config(scope=None) + with container_context(scope=ContextScopes.INJECT), pytest.raises(RuntimeError): + provider.sync_resolve() + + +def test_container_context_does_not_support_scope_any() -> None: + with ( + pytest.raises(ValueError, match=f"{ContextScopes.ANY} cannot be entered!"), + ): + container_context(scope=ContextScopes.ANY) + + +def test_di_middleware_does_not_support_scope_any() -> None: + with ( + pytest.raises(ValueError, match=f"{ContextScopes.ANY} cannot be entered!"), + ): + DIContextMiddleware(Mock(), scope=ContextScopes.ANY) + + +async def test_resource_context_does_not_reset_in_wrong_scope_async() -> None: + class _Container(BaseContainer): + default_scope = ContextScopes.REQUEST + p_app = providers.ContextResource(create_async_context_resource).with_config(scope=ContextScopes.APP) + p_request = providers.ContextResource(create_async_context_resource) + + async with container_context(scope=ContextScopes.APP): + value_app_1 = await _Container.p_app.async_resolve() + with pytest.raises(RuntimeError): + await _Container.p_request.async_resolve() + + async with container_context(scope=ContextScopes.REQUEST): + value_app_2 = await _Container.p_app.async_resolve() + assert await _Container.p_request.async_resolve() is not None + + assert value_app_1 == value_app_2 + + with pytest.raises(RuntimeError): + await _Container.p_request.async_resolve() + + +def test_resource_context_does_not_set_in_wrong_scope_sync() -> None: + class _Container(BaseContainer): + default_scope = ContextScopes.REQUEST + p_app = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.APP) + p_request = providers.ContextResource(create_sync_context_resource) + + with container_context(scope=ContextScopes.APP): + value_app_1 = _Container.p_app.sync_resolve() + with pytest.raises(RuntimeError): + _Container.p_request.sync_resolve() + + with container_context(scope=ContextScopes.REQUEST): + value_app_2 = _Container.p_app.sync_resolve() + assert _Container.p_request.sync_resolve() is not None + + assert value_app_1 == value_app_2 + + with pytest.raises(RuntimeError): + _Container.p_request.sync_resolve() + + +async def test_strict_scope_resource_only_resolvable_in_given_scope_async() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_async_context_resource).with_config( + scope=ContextScopes.APP, strict_scope=True + ) + p_request = providers.ContextResource(create_async_context_resource).with_config( + scope=ContextScopes.REQUEST, strict_scope=True + ) + + with pytest.raises(RuntimeError): + await _Container.p_app.async_resolve() + + with pytest.raises(InvalidContextError): + await container_context(_Container.p_app, _Container.p_request).__aenter__() + + async with container_context(scope=ContextScopes.APP): + assert await _Container.p_app.async_resolve() is not None + with pytest.raises(RuntimeError): + await _Container.p_request.async_resolve() + + async with container_context(scope=ContextScopes.REQUEST): + assert await _Container.p_request.async_resolve() is not None + with pytest.raises(RuntimeError): + await _Container.p_app.async_resolve() + + +def test_strict_scope_resource_only_resolvable_in_given_scope_sync() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_sync_context_resource).with_config( + scope=ContextScopes.APP, strict_scope=True + ) + p_request = providers.ContextResource(create_sync_context_resource).with_config( + scope=ContextScopes.REQUEST, strict_scope=True + ) + + with pytest.raises(RuntimeError): + _Container.p_app.sync_resolve() + + with pytest.raises(InvalidContextError): + container_context(_Container.p_app, _Container.p_request).__enter__() + + with container_context(scope=ContextScopes.APP): + assert _Container.p_app.sync_resolve() is not None + with pytest.raises(RuntimeError): + _Container.p_request.sync_resolve() + + with container_context(scope=ContextScopes.REQUEST): + assert _Container.p_request.sync_resolve() is not None + with pytest.raises(RuntimeError): + _Container.p_app.sync_resolve() + + +def test_strict_scope_not_allowed_with_any_scope() -> None: + with pytest.raises(ValueError, match=f"Cannot set strict_scope with scope {ContextScopes.ANY}."): + providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.ANY, strict_scope=True) + + +async def test_async_resource_with_custom_scope() -> None: + class MyScopes(ContextScopes): + CUSTOM = ContextScope("CUSTOM") + + class _Container(BaseContainer): + p_custom = providers.ContextResource(create_async_context_resource).with_config(scope=MyScopes.CUSTOM) + + assert _Container.p_custom.get_scope() == MyScopes.CUSTOM + + with pytest.raises(RuntimeError): + await _Container.p_custom.async_resolve() + + async with container_context(_Container.p_custom, scope=MyScopes.CUSTOM): + assert await _Container.p_custom.async_resolve() is not None + + +async def test_async_entering_container_context_for_all_containers_correctly_handles_named_scopes() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_async_context_resource).with_config(scope=ContextScopes.APP) + p_request = providers.ContextResource(create_async_context_resource).with_config(scope=ContextScopes.REQUEST) + + async with container_context(reset_all_containers=True): + with pytest.raises(RuntimeError): + await _Container.p_app.async_resolve() + + async with container_context(reset_all_containers=True, scope=ContextScopes.APP): + assert await _Container.p_app.async_resolve() is not None + with pytest.raises(RuntimeError): + await _Container.p_request.async_resolve() + + async with container_context(reset_all_containers=True, scope=ContextScopes.REQUEST): + assert await _Container.p_request.async_resolve() is not None + with pytest.raises(RuntimeError): + await _Container.p_app.async_resolve() + + +def test_sync_entering_container_context_for_all_containers_correctly_handles_named_scopes() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.APP) + p_request = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.REQUEST) + + with container_context(reset_all_containers=True), pytest.raises(RuntimeError): + _Container.p_app.sync_resolve() + + with container_context(reset_all_containers=True, scope=ContextScopes.APP): + assert _Container.p_app.sync_resolve() is not None + with pytest.raises(RuntimeError): + _Container.p_request.sync_resolve() + + with container_context(reset_all_containers=True, scope=ContextScopes.REQUEST): + assert _Container.p_request.sync_resolve() is not None + with pytest.raises(RuntimeError): + _Container.p_app.sync_resolve() + + +async def test_async_force_enter_context_for_scoped_resource() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_async_context_resource).with_config(scope=ContextScopes.APP) + + async with _Container.p_app.async_context(force=True): + assert await _Container.p_app.async_resolve() is not None + + async with _Container.async_context(force=True): + assert await _Container.p_app.async_resolve() is not None + + +def test_sync_force_enter_context_for_scoped_resource() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.APP) + + with _Container.p_app.sync_context(force=True): + assert _Container.p_app.sync_resolve() is not None + + with _Container.sync_context(force=True): + assert _Container.p_app.sync_resolve() is not None + + +async def test_async_force_enter_context_with_context_annotation() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_async_context_resource).with_config(scope=ContextScopes.APP) + + @_Container.context(force=True) + @inject(scope=None) + async def _injected(val: str = Provide[_Container.p_app]) -> str: + return val + + @_Container.p_app.context(force=True) + @inject(scope=None) + async def _injected_p(val: str = Provide[_Container.p_app]) -> str: + return val + + assert await _injected() is not None + assert await _injected_p() is not None + + +def test_sync_force_enter_context_with_context_annotation() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_sync_context_resource).with_config(scope=ContextScopes.APP) + + @_Container.context(force=True) + @inject(scope=None) + def _injected(val: str = Provide[_Container.p_app]) -> str: + return val + + @_Container.p_app.context(force=True) + @inject(scope=None) + def _injected_p(val: str = Provide[_Container.p_app]) -> str: + return val + + assert _injected() is not None + assert _injected_p() is not None + + +async def test_async_container_context_selects_context_items_on_entry() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_async_context_resource).with_config( + scope=ContextScopes.APP, strict_scope=True + ) + p_request = providers.ContextResource(create_async_context_resource).with_config( + scope=ContextScopes.REQUEST, strict_scope=True + ) + + async with container_context(scope=ContextScopes.APP): + cc = container_context() + + async with container_context(scope=ContextScopes.REQUEST): + assert get_current_scope() == ContextScopes.REQUEST + async with cc: + assert get_current_scope() == ContextScopes.REQUEST + assert await _Container.p_request.async_resolve() is not None + + +def test_sync_container_context_selects_context_items_on_entry() -> None: + class _Container(BaseContainer): + p_app = providers.ContextResource(create_sync_context_resource).with_config( + scope=ContextScopes.APP, strict_scope=True + ) + p_request = providers.ContextResource(create_sync_context_resource).with_config( + scope=ContextScopes.REQUEST, strict_scope=True + ) + + with container_context(scope=ContextScopes.APP): + cc = container_context() + + with container_context(scope=ContextScopes.REQUEST): + assert get_current_scope() == ContextScopes.REQUEST + with cc: + assert get_current_scope() == ContextScopes.REQUEST + assert _Container.p_request.sync_resolve() is not None diff --git a/tests/providers/test_local_singleton.py b/tests/providers/test_local_singleton.py index b2c963cf..e83ab4c0 100644 --- a/tests/providers/test_local_singleton.py +++ b/tests/providers/test_local_singleton.py @@ -10,6 +10,9 @@ from that_depends.providers import AsyncFactory, ThreadLocalSingleton +random.seed(23) + + async def _async_factory() -> int: await asyncio.sleep(0.01) return threading.get_ident() diff --git a/tests/test_injection.py b/tests/test_injection.py index b9a8a6b5..de0203cc 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -1,11 +1,13 @@ import asyncio import datetime +import typing import warnings import pytest from tests import container -from that_depends import Provide, inject +from that_depends import BaseContainer, Provide, inject, providers +from that_depends.providers.context_resources import ContextScopes @pytest.fixture(name="fixture_one") @@ -13,6 +15,14 @@ def create_fixture_one() -> int: return 1 +async def _async_creator() -> typing.AsyncIterator[int]: + yield 1 + + +def _sync_creator() -> typing.Iterator[int]: + yield 1 + + @inject async def test_injection( fixture_one: int, @@ -73,8 +83,7 @@ def inner( return _ factory = container.SimpleFactory(dep1="1", dep2=2) - with pytest.warns(RuntimeWarning, match="Expected injection, but nothing found. Remove @inject decorator."): - assert inner(_=factory) == factory + assert inner(_=factory) == factory def test_sync_empty_injection() -> None: @@ -94,3 +103,40 @@ async def main(simple_factory: container.SimpleFactory = Provide[container.DICon assert simple_factory asyncio.run(main()) + + +async def test_async_injection_with_scope() -> None: + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + async_resource = providers.ContextResource(_async_creator).with_config(scope=ContextScopes.INJECT) + + async def _injected(val: int = Provide[_Container.async_resource]) -> int: + return val + + assert await inject(scope=ContextScopes.INJECT)(_injected)() == 1 + assert await inject(_injected)() == 1 + with pytest.raises(RuntimeError): + await inject(scope=None)(_injected)() + with pytest.raises(RuntimeError): + await inject(scope=ContextScopes.REQUEST)(_injected)() + + +async def test_sync_injection_with_scope() -> None: + class _Container(BaseContainer): + default_scope = ContextScopes.ANY + p_inject = providers.ContextResource(_sync_creator).with_config(scope=ContextScopes.INJECT) + + def _injected(val: int = Provide[_Container.p_inject]) -> int: + return val + + assert inject(scope=ContextScopes.INJECT)(_injected)() == 1 + assert inject(_injected)() == 1 + with pytest.raises(RuntimeError): + inject(scope=None)(_injected)() + with pytest.raises(RuntimeError): + inject(scope=ContextScopes.REQUEST)(_injected)() + + +def test_inject_decorator_should_not_allow_any_scope() -> None: + with pytest.raises(ValueError, match=f"{ContextScopes.ANY} is not allowed in inject decorator."): + inject(scope=ContextScopes.ANY) diff --git a/that_depends/__init__.py b/that_depends/__init__.py index aac2a910..c96fbe01 100644 --- a/that_depends/__init__.py +++ b/that_depends/__init__.py @@ -4,7 +4,7 @@ from that_depends.container import BaseContainer from that_depends.injection import Provide, inject from that_depends.providers import container_context -from that_depends.providers.context_resources import fetch_context_item +from that_depends.providers.context_resources import fetch_context_item, get_current_scope __all__ = [ @@ -12,6 +12,7 @@ "Provide", "container_context", "fetch_context_item", + "get_current_scope", "inject", "providers", ] diff --git a/that_depends/container.py b/that_depends/container.py index 7f102f67..3cb0f675 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,12 +1,13 @@ import inspect import typing from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager +from typing import overload from typing_extensions import override from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, Resource, Singleton -from that_depends.providers.context_resources import ContextResource, SupportsContext +from that_depends.providers.context_resources import ContextResource, ContextScope, ContextScopes, SupportsContext if typing.TYPE_CHECKING: @@ -22,6 +23,55 @@ class BaseContainer(SupportsContext[None], metaclass=BaseContainerMeta): providers: dict[str, AbstractProvider[typing.Any]] containers: list[type["BaseContainer"]] + default_scope: ContextScope | None = ContextScopes.ANY + + @classmethod + @overload + def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... + + @classmethod + @overload + def context(cls, *, force: bool = False) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: ... + + @classmethod + def context( + cls, func: typing.Callable[P, T] | None = None, force: bool = False + ) -> typing.Callable[P, T] | typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: + """Wrap a function with this resources' context. + + Args: + func: function to be wrapped. + force: force context initialization, ignoring scope. + + Returns: + wrapped function or wrapper if func is None. + + """ + + def _wrapper(func: typing.Callable[P, T]) -> typing.Callable[P, T]: + if inspect.iscoroutinefunction(func): + + async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async with cls.async_context(force=force): + return await func(*args, **kwargs) # type: ignore[no-any-return, misc] + + return typing.cast(typing.Callable[P, T], _async_wrapper) + + def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with cls.sync_context(force=force): + return func(*args, **kwargs) + + return _sync_wrapper + + if func: + return _wrapper(func) + return _wrapper + + @classmethod + @override + def get_scope(cls) -> ContextScope | None: + """Get default container scope.""" + return cls.default_scope @override def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": @@ -36,44 +86,27 @@ def supports_sync_context(cls) -> bool: @classmethod @contextmanager @override - def sync_context(cls) -> typing.Iterator[None]: + def sync_context(cls, force: bool = False) -> typing.Iterator[None]: with ExitStack() as stack: for container in cls.get_containers(): - stack.enter_context(container.sync_context()) + stack.enter_context(container.sync_context(force=force)) for provider in cls.get_providers().values(): if isinstance(provider, ContextResource) and not provider.is_async: - stack.enter_context(provider.sync_context()) + stack.enter_context(provider.sync_context(force=force)) yield @classmethod @asynccontextmanager @override - async def async_context(cls) -> typing.AsyncIterator[None]: + async def async_context(cls, force: bool = False) -> typing.AsyncIterator[None]: async with AsyncExitStack() as stack: for container in cls.get_containers(): - await stack.enter_async_context(container.async_context()) + await stack.enter_async_context(container.async_context(force=force)) for provider in cls.get_providers().values(): if isinstance(provider, ContextResource): - await stack.enter_async_context(provider.async_context()) + await stack.enter_async_context(provider.async_context(force=force)) yield - @classmethod - @override - def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: - if inspect.iscoroutinefunction(func): - - async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - async with cls.async_context(): - return await func(*args, **kwargs) # type: ignore[no-any-return] - - return typing.cast(typing.Callable[P, T], _async_wrapper) - - def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - with cls.sync_context(): - return func(*args, **kwargs) - - return _sync_wrapper - @classmethod def connect_containers(cls, *containers: type["BaseContainer"]) -> None: """Connect containers. diff --git a/that_depends/injection.py b/that_depends/injection.py index 12d4a267..e7ea19ef 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -4,93 +4,115 @@ import warnings from that_depends.providers import AbstractProvider +from that_depends.providers.context_resources import ContextScope, ContextScopes, container_context P = typing.ParamSpec("P") T = typing.TypeVar("T") -def inject( - func: typing.Callable[P, T], -) -> typing.Callable[P, T]: - """Decorate a function to enable dependency injection. - - Args: - func: sync or async function with dependencies. - - Returns: - function that will resolve dependencies on call. - - - Example: - ```python - @inject - async def func(a: str = Provide[Container.a_provider]) -> str: - ... - ``` - - """ - if inspect.iscoroutinefunction(func): - return typing.cast(typing.Callable[P, T], _inject_to_async(func)) - - return _inject_to_sync(func) - - -def _inject_to_async( - func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], -) -> typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]]: - signature = inspect.signature(func) - - @functools.wraps(func) - async def inner(*args: P.args, **kwargs: P.kwargs) -> T: - injected = False - for i, (field_name, field_value) in enumerate(signature.parameters.items()): - if i < len(args): - continue - - if not isinstance(field_value.default, AbstractProvider): - continue - - if field_name in kwargs: - continue - - kwargs[field_name] = await field_value.default.async_resolve() - injected = True - if not injected: - warnings.warn( - "Expected injection, but nothing found. Remove @inject decorator.", RuntimeWarning, stacklevel=1 - ) - return await func(*args, **kwargs) +@typing.overload +def inject(func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... - return inner - -def _inject_to_sync( - func: typing.Callable[P, T], -) -> typing.Callable[P, T]: +@typing.overload +def inject( + *, + scope: ContextScope | None = ContextScopes.INJECT, +) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: ... + + +def inject( # noqa: C901 + func: typing.Callable[P, T] | None = None, scope: ContextScope | None = ContextScopes.INJECT +) -> typing.Callable[P, T] | typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: + """Inject dependencies into a function.""" + if scope == ContextScopes.ANY: + msg = f"{scope} is not allowed in inject decorator." + raise ValueError(msg) + + def _inject( + func: typing.Callable[P, T], + ) -> typing.Callable[P, T]: + if inspect.iscoroutinefunction(func): + return typing.cast(typing.Callable[P, T], _inject_to_async(func)) + + return _inject_to_sync(func) + + def _inject_to_async( + func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], + ) -> typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]]: + @functools.wraps(func) + async def inner(*args: P.args, **kwargs: P.kwargs) -> T: + if scope: + async with container_context(scope=scope): + return await _resolve_async(func, *args, **kwargs) + return await _resolve_async(func, *args, **kwargs) + + return inner + + def _inject_to_sync( + func: typing.Callable[P, T], + ) -> typing.Callable[P, T]: + @functools.wraps(func) + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + if scope: + with container_context(scope=scope): + return _resolve_sync(func, *args, **kwargs) + return _resolve_sync(func, *args, **kwargs) + + return inner + + if func: + return _inject(func) + + return _inject + + +def _resolve_sync(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + injected = False signature: typing.Final = inspect.signature(func) - - @functools.wraps(func) - def inner(*args: P.args, **kwargs: P.kwargs) -> T: - injected = False - for i, (field_name, field_value) in enumerate(signature.parameters.items()): - if i < len(args): - continue - if not isinstance(field_value.default, AbstractProvider): - continue - if field_name in kwargs: - continue - kwargs[field_name] = field_value.default.sync_resolve() - injected = True - - if not injected: - warnings.warn( - "Expected injection, but nothing found. Remove @inject decorator.", RuntimeWarning, stacklevel=1 - ) - - return func(*args, **kwargs) - - return inner + for i, (field_name, field_value) in enumerate(signature.parameters.items()): + if i < len(args): + continue + if not isinstance(field_value.default, AbstractProvider): + continue + if field_name in kwargs: + if isinstance(field_value.default, AbstractProvider): + injected = True + continue + kwargs[field_name] = field_value.default.sync_resolve() + injected = True + + if not injected: + warnings.warn("Expected injection, but nothing found. Remove @inject decorator.", RuntimeWarning, stacklevel=1) + + return func(*args, **kwargs) + + +async def _resolve_async( + func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], *args: P.args, **kwargs: P.kwargs +) -> T: + injected = False + signature = inspect.signature(func) + for i, (field_name, field_value) in enumerate(signature.parameters.items()): + if i < len(args): + if isinstance(field_value.default, AbstractProvider): + injected = True + continue + + if not isinstance(field_value.default, AbstractProvider): + continue + + if field_name in kwargs: + if isinstance(field_value.default, AbstractProvider): + injected = True + continue + + kwargs[field_name] = await field_value.default.async_resolve() + injected = True + if not injected: + warnings.warn("Expected injection, but nothing found. Remove @inject decorator.", RuntimeWarning, stacklevel=1) + return await func(*args, **kwargs) class ClassGetItemMeta(type): diff --git a/that_depends/meta.py b/that_depends/meta.py index 0315a6df..edb73207 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -1,5 +1,6 @@ import abc import typing +from collections.abc import MutableMapping from threading import Lock from typing_extensions import override @@ -9,6 +10,28 @@ from that_depends.container import BaseContainer +class DefaultScopeNotDefinedError(Exception): + """Exception raised when default_scope is not defined.""" + + +class _ContainerMetaDict(dict[str, typing.Any]): + """Implements custom logic for the container metaclass.""" + + @override + def __setitem__(self, key: str, value: typing.Any) -> None: + from that_depends.providers.context_resources import ContextResource, ContextScopes + + if isinstance(value, ContextResource) and value.get_scope() == ContextScopes.ANY: + try: + default_scope = self.__getitem__("default_scope") + super().__setitem__(key, value.with_config(default_scope)) + except KeyError as e: + msg = "Explicitly define default_scope before defining ContextResource providers." + raise DefaultScopeNotDefinedError(msg) from e + else: + super().__setitem__(key, value) + + class BaseContainerMeta(abc.ABCMeta): """Metaclass for BaseContainer.""" @@ -16,6 +39,11 @@ class BaseContainerMeta(abc.ABCMeta): _lock: Lock = Lock() + @classmethod + @override + def __prepare__(cls, name: str, bases: tuple[type, ...], /, **kwds: typing.Any) -> MutableMapping[str, object]: + return _ContainerMetaDict() + @override def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type: new_cls = super().__new__(cls, name, bases, namespace) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index f546acd3..38937031 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -23,9 +23,9 @@ class AbstractProvider(typing.Generic[T_co], abc.ABC): """Base class for all providers.""" - def __init__(self) -> None: + def __init__(self, **kwargs: typing.Any) -> None: # noqa: ANN401 """Create a new provider.""" - super().__init__() + super().__init__(**kwargs) self._override: typing.Any = None def __deepcopy__(self, *_: object, **__: object) -> typing_extensions.Self: @@ -132,12 +132,14 @@ def __init__( Args: creator: sync or async iterator or context manager that yields resource. - *args: arguments to pass to the creator. + *args: positional arguments to pass to the creator. **kwargs: keyword arguments to pass to the creator. + """ super().__init__() self._creator: typing.Any + if inspect.isasyncgenfunction(creator): self.is_async = True self._creator = contextlib.asynccontextmanager(creator) @@ -153,9 +155,8 @@ def __init__( else: msg = "Unsupported resource type" raise TypeError(msg) - - self._args: typing.Final[P.args] = args - self._kwargs: typing.Final[P.kwargs] = kwargs + self._args: P.args = args + self._kwargs: P.kwargs = kwargs @abc.abstractmethod def _fetch_context(self) -> ResourceContext[T_co]: ... diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 83e54f82..3c138dc9 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -6,11 +6,11 @@ import threading import typing from abc import abstractmethod -from contextlib import AbstractAsyncContextManager, AbstractContextManager +from contextlib import AbstractAsyncContextManager, AbstractContextManager, contextmanager from contextvars import ContextVar, Token from functools import wraps from types import TracebackType -from typing import Final +from typing import Final, overload from typing_extensions import TypeIs, override @@ -26,7 +26,8 @@ logger: typing.Final = logging.getLogger(__name__) T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") -_CONTAINER_CONTEXT: typing.Final[ContextVar[dict[str, typing.Any]]] = ContextVar("CONTAINER_CONTEXT") +_CONTAINER_CONTEXT: typing.Final[ContextVar[dict[str, typing.Any]]] = ContextVar("__CONTAINER_CONTEXT__") + AppType = typing.TypeVar("AppType") Scope = typing.MutableMapping[str, typing.Any] Message = typing.MutableMapping[str, typing.Any] @@ -38,6 +39,45 @@ ContextType = dict[str, typing.Any] +class InvalidContextError(RuntimeError): + """Raised when an invalid context is being used.""" + + +class ContextScope: + """A named context scope.""" + + def __init__(self, name: str) -> None: + """Initialize a new context scope.""" + self._name = name + + @property + def name(self) -> str: + """Get the name of the context scope.""" + return self._name + + @override + def __eq__(self, other: object) -> bool: + if isinstance(other, ContextScope): + return self.name == other.name + return False + + @override + def __repr__(self) -> str: + return f"{self.name!r}" + + +class ContextScopes: + """Enumeration of context scopes.""" + + ANY = ContextScope("ANY") # special scope that can be used in any context + APP = ContextScope("APP") # application scope + REQUEST = ContextScope("REQUEST") # request scope + INJECT = ContextScope("INJECT") # inject scope + + +_CONTAINER_SCOPE: typing.Final[ContextVar[ContextScope | None]] = ContextVar("__CONTAINER_SCOPE__", default=None) + + def _get_container_context() -> dict[str, typing.Any]: try: return _CONTAINER_CONTEXT.get() @@ -46,6 +86,27 @@ def _get_container_context() -> dict[str, typing.Any]: raise RuntimeError(msg) from exc +def get_current_scope() -> ContextScope | None: + """Get the current context scope. + + Returns: + ContextScope | None: The current context scope. + + """ + return _CONTAINER_SCOPE.get() + + +def _set_current_scope(scope: ContextScope | None) -> Token[ContextScope | None]: + return _CONTAINER_SCOPE.set(scope) + + +@contextmanager +def _enter_named_scope(scope: ContextScope) -> typing.Iterator[ContextScope]: + token = _set_current_scope(scope) + yield scope + _CONTAINER_SCOPE.reset(token) + + def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # noqa: ANN401 """Retrieve a value from the global context. @@ -78,31 +139,16 @@ class SupportsContext(typing.Generic[CT], abc.ABC): """ @abstractmethod - def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: - """Wrap a function with a new context. - - The returned function will automatically initialize and tear down - the context whenever it is called. - - Args: - func (Callable[P, T]): The function to wrap. - - Returns: - Callable[P, T]: The wrapped function. - - Example: - ```python - @my_resource.context - def my_function(): - return do_something() - ``` - - """ + def get_scope(self) -> ContextScope | None: + """Return the scope of the resource.""" @abstractmethod - def async_context(self) -> typing.AsyncContextManager[CT]: + def async_context(self, force: bool = False) -> typing.AsyncContextManager[CT]: """Create an async context manager for this resource. + Args: + force (bool): If True, the context will be entered regardless of the current scope. + Returns: AsyncContextManager[CT]: An async context manager. @@ -115,9 +161,12 @@ def async_context(self) -> typing.AsyncContextManager[CT]: """ @abstractmethod - def sync_context(self) -> typing.ContextManager[CT]: + def sync_context(self, force: bool = False) -> typing.ContextManager[CT]: """Create a sync context manager for this resource. + Args: + force (bool): If True, the context will be entered regardless of the current scope. + Returns: ContextManager[CT]: A sync context manager. @@ -149,6 +198,26 @@ class ContextResource( and ensures they are properly torn down when the context exits. """ + @override + async def async_resolve(self) -> T_co: + current_scope = get_current_scope() + if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope): + return await super().async_resolve() + msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" + raise RuntimeError(msg) + + @override + def sync_resolve(self) -> T_co: + current_scope = get_current_scope() + if not self._strict_scope or self._scope in (ContextScopes.ANY, current_scope): + return super().sync_resolve() + msg = f"Cannot resolve resource with scope `{self._scope}` in scope `{current_scope}`" + raise RuntimeError(msg) + + @override + def get_scope(self) -> ContextScope | None: + return self._scope + __slots__ = ( "_args", "_context", @@ -156,6 +225,7 @@ class ContextResource( "_internal_name", "_kwargs", "_override", + "_scope", "_token", "is_async", ) @@ -171,30 +241,95 @@ def __init__( Args: creator (Callable[P, Iterator[T_co] | AsyncIterator[T_co]]): A sync or async iterator that yields the resource to be provided. - *args: Positional arguments to pass to the creator. - **kwargs: Keyword arguments to pass to the creator. + *args (P.args): Positional arguments to pass to the creator. + **kwargs (P.kwargs): Keyword arguments to pass to the creator. """ super().__init__(creator, *args, **kwargs) + self._from_creator = creator self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") self._token: Token[ResourceContext[T_co]] | None = None self._async_lock: Final = asyncio.Lock() self._lock: Final = threading.Lock() + self._scope: ContextScope | None = ContextScopes.ANY + self._strict_scope: bool = False + + @overload + def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... + + @overload + def context(self, *, force: bool = False) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: ... + + def context( + self, func: typing.Callable[P, T] | None = None, force: bool = False + ) -> typing.Callable[P, T] | typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: + """Create a new context manager for the resource, the context manager will be async if the resource is async. + + Returns: + typing.ContextManager[ResourceContext[T_co]] | typing.AsyncContextManager[ResourceContext[T_co]]: + A context manager for the resource. + + """ + + def _wrapper(func: typing.Callable[P, T]) -> typing.Callable[P, T]: + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async with self.async_context(force=force): + return await func(*args, **kwargs) # type: ignore[no-any-return, misc] + + return typing.cast(typing.Callable[P, T], _async_wrapper) + + # wrapped function is sync + @wraps(func) + def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with self.sync_context(force=force): + return func(*args, **kwargs) + + return typing.cast(typing.Callable[P, T], _sync_wrapper) + + if func: + return _wrapper(func) + return _wrapper + + def with_config(self, scope: ContextScope | None, strict_scope: bool = False) -> "ContextResource[T_co]": + """Create a new context-resource with the specified scope. + + Args: + scope: named scope where resource is resolvable. + strict_scope: if True, the resource will only be resolvable in the specified scope. + + Returns: + new context resource with the specified scope. + + """ + if strict_scope and scope == ContextScopes.ANY: + msg = f"Cannot set strict_scope with scope {scope}." + raise ValueError(msg) + r = ContextResource(self._from_creator, *self._args, **self._kwargs) + r._scope = scope # noqa: SLF001 + r._strict_scope = strict_scope # noqa: SLF001 + + return r @override def supports_sync_context(self) -> bool: return not self.is_async - def _enter_sync_context(self) -> ResourceContext[T_co]: + def _enter_sync_context(self, force: bool = False) -> ResourceContext[T_co]: if self.is_async: msg = "You must enter async context for async creators." raise RuntimeError(msg) - return self._enter() + return self._enter(force) - async def _enter_async_context(self) -> ResourceContext[T_co]: - return self._enter() + async def _enter_async_context(self, force: bool = False) -> ResourceContext[T_co]: + return self._enter(force) - def _enter(self) -> ResourceContext[T_co]: + def _enter(self, force: bool = False) -> ResourceContext[T_co]: + if not force and self._scope not in (ContextScopes.ANY, get_current_scope()): + msg = f"Cannot enter context for resource with scope {self._scope} in scope {get_current_scope()!r}" + raise InvalidContextError(msg) self._token = self._context.set(ResourceContext(is_async=self.is_async)) return self._context.get() @@ -225,13 +360,13 @@ async def _exit_async_context(self) -> None: @contextlib.contextmanager @override - def sync_context(self) -> typing.Iterator[ResourceContext[T_co]]: + def sync_context(self, force: bool = False) -> typing.Iterator[ResourceContext[T_co]]: if self.is_async: msg = "Please use async context instead." raise RuntimeError(msg) token = self._token with self._lock: - val = self._enter_sync_context() + val = self._enter_sync_context(force=force) temp_token = self._token yield val with self._lock: @@ -241,11 +376,11 @@ def sync_context(self) -> typing.Iterator[ResourceContext[T_co]]: @contextlib.asynccontextmanager @override - async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: + async def async_context(self, force: bool = False) -> typing.AsyncIterator[ResourceContext[T_co]]: token = self._token async with self._async_lock: - val = await self._enter_async_context() + val = await self._enter_async_context(force=force) temp_token = self._token yield val async with self._async_lock: @@ -253,32 +388,6 @@ async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: await self._exit_async_context() self._token = token - @override - def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: - """Create a new context manager for the resource, the context manager will be async if the resource is async. - - Returns: - typing.ContextManager[ResourceContext[T_co]] | typing.AsyncContextManager[ResourceContext[T_co]]: - A context manager for the resource. - - """ - if inspect.iscoroutinefunction(func): - - @wraps(func) - async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - async with self.async_context(): - return await func(*args, **kwargs) # type: ignore[no-any-return] - - return typing.cast(typing.Callable[P, T], _async_wrapper) - - # wrapped function is sync - @wraps(func) - def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - with self.sync_context(): - return func(*args, **kwargs) - - return typing.cast(typing.Callable[P, T], _sync_wrapper) - def _fetch_context(self) -> ResourceContext[T_co]: try: return self._context.get() @@ -304,6 +413,7 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex "_initial_context", "_context_token", "_reset_resource_context", + "_scope", ) def __init__( @@ -312,6 +422,7 @@ def __init__( global_context: ContextType | None = None, preserve_global_context: bool = False, reset_all_containers: bool = False, + scope: ContextScope | None = None, ) -> None: """Initialize a new container context. @@ -320,6 +431,7 @@ def __init__( global_context (dict[str, Any] | None): A dictionary representing the global context. preserve_global_context (bool): If True, merges the existing global context with the new one. reset_all_containers (bool): If True, creates a new context for all containers in this scope. + scope (ContextScope | None): The named scope that should be initialized. Example: ```python @@ -328,31 +440,46 @@ def __init__( ``` """ - if preserve_global_context and global_context: - self._initial_context = {**_get_container_context(), **global_context} - else: - self._initial_context: ContextType = ( # type: ignore[no-redef] - _get_container_context() if preserve_global_context else global_context or {} - ) + if scope == ContextScopes.ANY: + msg = f"{scope} cannot be entered!" + raise ValueError(msg) + self._scope = scope + self._preserve_global_context = preserve_global_context + self._global_context = global_context self._context_token: Token[ContextType] | None = None self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._reset_resource_context: typing.Final[bool] = ( not context_items and not global_context ) or reset_all_containers - if self._reset_resource_context: - self._add_providers_from_containers(BaseContainerMeta.get_instances()) - self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None + self._scope_token: Token[ContextScope | None] | None = None + + def _resolve_initial_conditions(self) -> None: + self._scope = self._scope if self._scope else get_current_scope() + if self._preserve_global_context and self._global_context: + self._initial_context = {**_get_container_context(), **self._global_context} + else: + self._initial_context: ContextType = ( # type: ignore[no-redef] + _get_container_context() if self._preserve_global_context else self._global_context or {} + ) + if self._reset_resource_context: # equivalent to reset_all_containers + self._add_providers_from_containers(BaseContainerMeta.get_instances(), self._scope) - def _add_providers_from_containers(self, containers: list[ContainerType]) -> None: + def _add_providers_from_containers( + self, containers: list[ContainerType], scope: ContextScope | None = ContextScopes.ANY + ) -> None: for container in containers: for container_provider in container.get_providers().values(): if isinstance(container_provider, ContextResource): - self._context_items.add(container_provider) + provider_scope = container_provider.get_scope() + if provider_scope in (scope, ContextScopes.ANY): + self._context_items.add(container_provider) @override def __enter__(self) -> ContextType: + self._resolve_initial_conditions() self._context_stack = contextlib.ExitStack() + self._scope_token = _set_current_scope(self._scope) for item in self._context_items: if item.supports_sync_context(): self._context_stack.enter_context(item.sync_context()) @@ -360,7 +487,9 @@ def __enter__(self) -> ContextType: @override async def __aenter__(self) -> ContextType: + self._resolve_initial_conditions() self._context_stack = contextlib.AsyncExitStack() + self._scope_token = _set_current_scope(self._scope) for item in self._context_items: await self._context_stack.enter_async_context(item.async_context()) return self._enter_globals() @@ -370,13 +499,19 @@ def _enter_globals(self) -> ContextType: return _CONTAINER_CONTEXT.get() def _is_context_token(self, _: Token[ContextType] | None) -> TypeIs[Token[ContextType]]: - return isinstance(_, Token) + return _ is not None + + def _is_scope_token(self, _: Token[ContextScope | None] | None) -> TypeIs[Token[ContextScope | None]]: + return _ is not None def _exit_globals(self) -> None: if self._is_context_token(self._context_token): - return _CONTAINER_CONTEXT.reset(self._context_token) - msg = "No context token set for global vars, use __enter__ or __aenter__ first." - raise RuntimeError(msg) + _CONTAINER_CONTEXT.reset(self._context_token) + else: + msg = "No context token set for global vars, use __enter__ or __aenter__ first." + raise RuntimeError(msg) + if self._is_scope_token(self._scope_token): + _CONTAINER_SCOPE.reset(self._scope_token) def _has_async_exit_stack( self, @@ -419,7 +554,7 @@ def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: """Decorate a function to run within this container context. The context is automatically initialized before the function is called and - torn down afterwards. + torn down afterward. Args: func (Callable[P, T_co]): A sync or async callable. @@ -467,6 +602,7 @@ def __init__( *context_items: SupportsContext[typing.Any], global_context: dict[str, typing.Any] | None = None, reset_all_containers: bool = False, + scope: ContextScope | None = None, ) -> None: """Initialize the DIContextMiddleware. @@ -476,6 +612,7 @@ def __init__( need context initialization prior to a request. global_context (dict[str, Any] | None): A global context dictionary to set before requests. reset_all_containers (bool): Whether to reset all containers in the current scope before the request. + scope (ContextScope | None): The scope in which the context should be initialized. Example: ```python @@ -487,6 +624,10 @@ def __init__( self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context: dict[str, typing.Any] | None = global_context self._reset_all_containers: bool = reset_all_containers + if scope == ContextScopes.ANY: + msg = f"{scope} cannot be entered!" + raise ValueError(msg) + self._scope = scope async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """Handle the incoming ASGI request by initializing and tearing down context. @@ -495,7 +636,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: closed after the request is completed. Args: - scope (Scope): The ASGI scope. + scope (ContextScope): The ASGI scope. receive (Receive): The receive call. send (Send): The send call. @@ -503,11 +644,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: None """ - if self._context_items: - pass async with ( - container_context(*self._context_items, global_context=self._global_context) + container_context(*self._context_items, global_context=self._global_context, scope=self._scope) if self._context_items - else container_context(global_context=self._global_context, reset_all_containers=self._reset_all_containers) + else container_context( + global_context=self._global_context, reset_all_containers=self._reset_all_containers, scope=self._scope + ) ): return await self.app(scope, receive, send) diff --git a/that_depends/providers/factories.py b/that_depends/providers/factories.py index 01d1e93b..53549156 100644 --- a/that_depends/providers/factories.py +++ b/that_depends/providers/factories.py @@ -142,14 +142,15 @@ def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P Args: factory (Callable[P, Awaitable[T_co]]): Async function that returns the resource. - *args: Arguments to pass to the async factory function. - **kwargs: Keyword arguments to pass to the async factory function. + *args: Arguments to pass to the factory function. + **kwargs: Keyword arguments to pass to the factory + """ super().__init__() self._factory: typing.Final = factory - self._args: typing.Final = args - self._kwargs: typing.Final = kwargs + self._args: typing.Final[P.args] = args + self._kwargs: typing.Final[P.kwargs] = kwargs @override async def async_resolve(self) -> T_co: @@ -157,12 +158,8 @@ async def async_resolve(self) -> T_co: return typing.cast(T_co, self._override) return await self._factory( - *[ # type: ignore[arg-type] - await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) @override diff --git a/that_depends/providers/local_singleton.py b/that_depends/providers/local_singleton.py index ebe271ea..f552134a 100644 --- a/that_depends/providers/local_singleton.py +++ b/that_depends/providers/local_singleton.py @@ -45,15 +45,15 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P Args: factory: A callable that returns a new instance of the dependency. *args: Positional arguments to pass to the factory. - **kwargs: Keyword arguments to pass to the factory. + **kwargs: Keyword arguments to pass to the factory """ super().__init__() self._factory: typing.Final = factory - self._args: typing.Final = args - self._kwargs: typing.Final = kwargs self._thread_local = threading.local() self._asyncio_lock = asyncio.Lock() + self._args: typing.Final[P.args] = args + self._kwargs: typing.Final[P.kwargs] = kwargs @property def _instance(self) -> T_co | None: @@ -73,8 +73,8 @@ async def async_resolve(self) -> T_co: return self._instance self._instance = self._factory( - *[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{ # type: ignore[arg-type] + *[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{ k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() }, @@ -90,8 +90,8 @@ def sync_resolve(self) -> T_co: return self._instance self._instance = self._factory( - *[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] - **{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] + *[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) return self._instance diff --git a/that_depends/providers/singleton.py b/that_depends/providers/singleton.py index 5cd4332e..26fe4512 100644 --- a/that_depends/providers/singleton.py +++ b/that_depends/providers/singleton.py @@ -46,11 +46,11 @@ def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P """ super().__init__() self._factory: typing.Final = factory - self._args: typing.Final = args - self._kwargs: typing.Final = kwargs self._instance: T_co | None = None self._asyncio_lock: typing.Final = asyncio.Lock() self._threading_lock: typing.Final = threading.Lock() + self._args: typing.Final[P.args] = args + self._kwargs: typing.Final[P.kwargs] = kwargs @override async def async_resolve(self) -> T_co: @@ -66,10 +66,8 @@ async def async_resolve(self) -> T_co: return self._instance self._instance = self._factory( - *[ # type: ignore[arg-type] - await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] + *[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{ k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() }, @@ -90,12 +88,8 @@ def sync_resolve(self) -> T_co: return self._instance self._instance = self._factory( - *[ # type: ignore[arg-type] - x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args - ], - **{ # type: ignore[arg-type] - k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() - }, + *[x.sync_resolve() if isinstance(x, AbstractProvider) else x for x in self._args], + **{k: v.sync_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, ) return self._instance @@ -146,10 +140,10 @@ def __init__( """ super().__init__() self._factory: typing.Final[typing.Callable[P, typing.Awaitable[T_co]]] = factory - self._args: typing.Final[P.args] = args - self._kwargs: typing.Final[P.kwargs] = kwargs self._instance: T_co | None = None self._asyncio_lock: typing.Final = asyncio.Lock() + self._args: typing.Final[P.args] = args + self._kwargs: typing.Final[P.kwargs] = kwargs @override async def async_resolve(self) -> T_co: