diff --git a/CHANGELOG.md b/CHANGELOG.md index 904b543..7d7f7d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Change Log +## [0.12.0] - 2025-01-02 + +- Improved handling of async functions by `TransactionContext`. Will raise `TypeError` if sync middleware is used with async handler. + ## [0.11.1] - 2024-05-29 - Fix type annotations in `ApplicationModule` diff --git a/examples/example6.py b/examples/example6.py index 0b5773e..b1081ef 100644 --- a/examples/example6.py +++ b/examples/example6.py @@ -1,5 +1,5 @@ -from lato import Application, ApplicationModule, Command, TransactionContext -from lato.compositon import compose +from lato import Application, ApplicationModule, Command + class GetAllItemDetails(Command): pass @@ -7,17 +7,19 @@ class GetAllItemDetails(Command): pricing_module = ApplicationModule("pricing") + @pricing_module.handler(GetAllItemDetails) def get_item_price(command: GetAllItemDetails): - prices = {'pencil': 1, 'pen': 2} + prices = {"pencil": 1, "pen": 2} return prices warehouse_module = ApplicationModule("warehouse") + @warehouse_module.handler(GetAllItemDetails) def get_item_stock(command: GetAllItemDetails): - stocks = {'pencil': 100, 'pen': 80} + stocks = {"pencil": 100, "pen": 80} return stocks @@ -25,16 +27,19 @@ def get_item_stock(command: GetAllItemDetails): app.include_submodule(pricing_module) app.include_submodule(warehouse_module) + @app.compose(GetAllItemDetails) def compose_item_details(pricing, warehouse): - assert pricing == {'pencil': 1, 'pen': 2} - assert warehouse == {'pencil': 100, 'pen': 80} + assert pricing == {"pencil": 1, "pen": 2} + assert warehouse == {"pencil": 100, "pen": 80} - details = [dict(item_id=x, price=pricing[x], stock=warehouse[x]) for x in pricing.keys()] + details = [ + dict(item_id=x, price=pricing[x], stock=warehouse[x]) for x in pricing.keys() + ] return details assert app.execute(GetAllItemDetails()) == [ - {'item_id': 'pencil', 'price': 1, 'stock': 100}, - {'item_id': 'pen', 'price': 2, 'stock': 80} + {"item_id": "pencil", "price": 1, "stock": 100}, + {"item_id": "pen", "price": 2, "stock": 80}, ] diff --git a/lato/application.py b/lato/application.py index 1377707..f5cc3cd 100644 --- a/lato/application.py +++ b/lato/application.py @@ -7,6 +7,7 @@ from lato.message import Event, Message from lato.transaction_context import ( ComposerFunction, + MessageHandler, MiddlewareFunction, OnEnterTransactionContextCallback, OnExitTransactionContextCallback, @@ -147,14 +148,14 @@ async def execute_async(self, message: Message) -> Any: :raises: ValueError: If no handlers are found for the message. """ async with self.transaction_context() as ctx: - result = await ctx.execute(message) + result = await ctx.execute_async(message) return result - def emit(self, event: Event) -> dict[Callable, Any]: + def emit(self, event: Event) -> dict[MessageHandler, Any]: """Deprecated. Use `publish()` instead.""" return self.publish(event) - def publish(self, event: Event) -> dict[Callable, Any]: + def publish(self, event: Event) -> dict[MessageHandler, Any]: """ Publish an event by calling all handlers for that event. @@ -165,7 +166,7 @@ def publish(self, event: Event) -> dict[Callable, Any]: result = ctx.publish(event) return result - async def publish_async(self, event: Event) -> dict[Callable, Any]: + async def publish_async(self, event: Event) -> dict[MessageHandler, Any]: """ Asynchronously publish an event by calling all handlers for that event. @@ -251,35 +252,35 @@ def transaction_middleware(self, middleware_func): Decorator for registering a middleware function to be called when executing a function in a transaction context :param middleware_func: :return: the decorated function - + **Example:** - + >>> from typing import Callable >>> from lato import Application, TransactionContext >>> >>> app = Application() - + >>> @app.transaction_middleware ... def middleware1(ctx: TransactionContext, call_next: Callable): ... ... - + """ - self._transaction_middlewares.insert(0, middleware_func) + self._transaction_middlewares.append(middleware_func) return middleware_func def compose(self, alias): """ Decorator for composing results of handlers identified by an alias. - + **Example:** - + >>> from lato import Application, Command, TransactionContext - + >>> class SomeCommand(Command): ... pass >>> >>> app = Application() - + >>> @app.compose(SomeCommand) ... def middleware1(**kwargs): ... ... diff --git a/lato/transaction_context.py b/lato/transaction_context.py index 440453e..6868f30 100644 --- a/lato/transaction_context.py +++ b/lato/transaction_context.py @@ -14,6 +14,7 @@ ) from lato.message import Message from lato.types import HandlerAlias +from lato.utils import maybe_await log = logging.getLogger(__name__) @@ -175,6 +176,20 @@ async def end_async(self, exception: Optional[Exception] = None): else: log.debug("Ended transaction") + def is_async_context_manager(self) -> bool: + """ + Determine if the transaction context requires `async with` context manager. + + This method checks if either the `on_enter` or `on_exit` callbacks + associated with the transaction context are asynchronous. + + :return: True if `on_enter` or `on_exit` are asynchronous callbacks, False otherwise. + """ + return any( + [asyncio.iscoroutinefunction(self._on_enter_transaction_context)] + + [asyncio.iscoroutinefunction(self._on_exit_transaction_context)] + ) + def iterate_handlers_for(self, alias: str): yield from self._handlers_iterator(alias) @@ -196,40 +211,41 @@ async def __aexit__(self, exc_type=None, exc_val=None, exc_tb=None): if asyncio.iscoroutine(result): await result - def _wrap_with_middlewares(self, handler_func): - p = handler_func - for middleware in self._middlewares: - if not asyncio.iscoroutinefunction( - middleware - ) and asyncio.iscoroutinefunction(handler_func): - raise ValueError( - "Cannot use synchronous middleware with async handler", - middleware, - handler_func, - ) - - p = partial(middleware, self, p) - return p - def call(self, func: Callable, *func_args: Any, **func_kwargs: Any) -> Any: """Call a function with the arguments and keyword arguments. Missing arguments will be resolved with the dependency provider. + If func is coroutine, or any of the middleware functions is coroutine, TypeError will be raised. + :param func: The function to call. :param func_args: Positional arguments to pass to the function. :param func_kwargs: Keyword arguments to pass to the function. :return: The result of the function call. """ + if asyncio.iscoroutinefunction(func): + raise TypeError( + f"Using async function ({func}) with {self.__class__.__name__}.call() is not allowed. Use call_async() instead." + ) + self.dependency_provider.update(ctx=as_type(self, TransactionContext)) resolved_kwargs = self.dependency_provider.resolve_func_params( func, func_args, func_kwargs ) self.resolved_kwargs.update(resolved_kwargs) - p = partial(func, **resolved_kwargs) - wrapped_handler = self._wrap_with_middlewares(p) - result = wrapped_handler() - return result + + call_next = partial(func, **resolved_kwargs) + + for m in self._middlewares[::-1]: + # middleware is async, which is not allowed + if asyncio.iscoroutinefunction(m): + raise TypeError( + f"Using async middleware ({m}) with {self.__class__.__name__}.call() is not allowed. Use call_async() instead." + ) + + call_next = partial(m, self, call_next) + + return call_next() async def call_async( self, func: Callable[..., Awaitable[Any]], *func_args: Any, **func_kwargs: Any @@ -237,23 +253,45 @@ async def call_async( """Call an async function with the arguments and keyword arguments. Missing arguments will be resolved with the dependency provider. + Edge cases: + - middlewares and func are sync - this will behave like call() + - middleware is sync, and call_next is async - will raise TypeError, as middleware will not be able to wait for call_next() + :param func: The function to call. :param func_args: Positional arguments to pass to the function. :param func_kwargs: Keyword arguments to pass to the function. :return: The result of the function call. """ + self.dependency_provider.update(ctx=as_type(self, TransactionContext)) resolved_kwargs = self.dependency_provider.resolve_func_params( func, func_args, func_kwargs ) self.resolved_kwargs.update(resolved_kwargs) - p = partial(func, **resolved_kwargs) - wrapped_handler = self._wrap_with_middlewares(p) - result = wrapped_handler() - if asyncio.iscoroutine(result): - result = await result - return result + + call_next = partial(func, **resolved_kwargs) + + for m in self._middlewares[::-1]: + if asyncio.iscoroutinefunction(m) and not asyncio.iscoroutinefunction( + call_next + ): + # async middleware is expecting an awaitable, so we need convert call_next to async + call_next = partial(maybe_await, call_next) + + if not asyncio.iscoroutinefunction(m) and asyncio.iscoroutinefunction( + call_next + ): + # middleware is not able to retrieve call_next, as call_next is awaitable + raise TypeError( + f"Using sync middleware ({m}) with async call_next ({call_next}) is not allowed." + ) + call_next = partial(m, self, call_next) + + if asyncio.iscoroutinefunction(call_next): + return await call_next() + else: + return call_next() def execute(self, message: Message) -> tuple[Any, ...]: """Executes all handlers bound to the message. Returns a tuple of handlers' return values. diff --git a/lato/utils.py b/lato/utils.py index 8021636..d9b2bf0 100644 --- a/lato/utils.py +++ b/lato/utils.py @@ -1,3 +1,4 @@ +import asyncio import re from collections import OrderedDict from typing import TypeVar @@ -29,3 +30,10 @@ def string_to_kwarg_name(string): valid_string = "_" + valid_string return valid_string + + +async def maybe_await(func, *args, **kwargs): + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 6f16ed9..a6375a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lato" -version = "0.11.1" +version = "0.12.0" description = "Lato is a Python microframework designed for building modular monoliths and loosely coupled applications." authors = ["Przemysław Górecki "] readme = "README.md" diff --git a/tests/test_application_async.py b/tests/test_application_async.py index 17cd3ce..f74a6a9 100644 --- a/tests/test_application_async.py +++ b/tests/test_application_async.py @@ -190,6 +190,6 @@ async def async_foo(): await asyncio.sleep(0.001) return 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): # cannot use synchronous middleware with async handler await app.call_async("async_foo") diff --git a/tests/test_transaction_context_async.py b/tests/test_transaction_context_async.py new file mode 100644 index 0000000..4a69b04 --- /dev/null +++ b/tests/test_transaction_context_async.py @@ -0,0 +1,169 @@ +import asyncio +import itertools + +import pytest + +from lato import TransactionContext + + +def enter_hook_fn(ctx): + ctx["call_log"].append("sync_enter") + + +async def async_enter_hook_fn(ctx): + await asyncio.sleep(0.02) + ctx["call_log"].append("async_enter") + + +def exit_hook_fn(ctx, exc=None): + ctx["call_log"].append("sync_exit") + + +async def async_exit_hook_fn(ctx, exc=None): + await asyncio.sleep(0.01) + ctx["call_log"].append("async_exit") + + +def middleware1_fn(ctx, call_next): + ctx["call_log"].append("sync_middleware1_enter") + result = call_next() + ctx["call_log"].append(f"sync_middleware1_exit with {type(result).__name__}") + return f"1:{result}:1" + + +async def async_middleware1_fn(ctx, call_next): + await asyncio.sleep(0.01) + ctx["call_log"].append("async_middleware1_enter") + await asyncio.sleep(0.01) + result = await call_next() + await asyncio.sleep(0.01) + ctx["call_log"].append(f"async_middleware1_exit with {type(result).__name__}") + await asyncio.sleep(0.01) + return f"1:{result}:1" + + +def middleware2_fn(ctx, call_next): + ctx["call_log"].append("sync_middleware2_enter") + result = call_next() + ctx["call_log"].append(f"sync_middleware2_exit with {type(result).__name__}") + return f"2:{result}:2" + + +async def async_middleware2_fn(ctx, call_next): + await asyncio.sleep(0.01) + ctx["call_log"].append("async_middleware2_enter") + await asyncio.sleep(0.01) + result = await call_next() + await asyncio.sleep(0.01) + ctx["call_log"].append(f"async_middleware2_exit with {type(result).__name__}") + await asyncio.sleep(0.01) + return f"2:{result}:2" + + +def sync_handler(text, call_log): + call_log.append("sync_handler") + return text.upper() + + +async def async_handler(text, call_log): + call_log.append("async_handler") + return text.upper() + + +async def run(use_async_context_manager, use_call_async, ctx, handler_fn): + if use_async_context_manager: + async with ctx: + if use_call_async: + result = await ctx.call_async(handler_fn, "foo") + else: + result = ctx.call(handler_fn, "foo") + + else: + with ctx: + if use_call_async: + result = await ctx.call_async(handler_fn, "foo") + else: + result = ctx.call(handler_fn, "foo") + return result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "use_async_context_manager, use_call_async, enter_type, exit_type, middleware1_type, middleware2_type, handler_type", + list( + itertools.product( + [True, False], + [True, False], + ["sync", "async"], # Enter hook types + ["sync", "async"], # Exit hook types + ["sync", "async"], # Middleware 1 types + ["sync", "async"], # Middleware 2 types + ["sync", "async"], # Handler types + ) + ), +) +async def test_transaction_context_async( + use_async_context_manager, + use_call_async, + enter_type, + exit_type, + middleware1_type, + middleware2_type, + handler_type, +): + call_log = [] + ctx = TransactionContext(call_log=call_log) + ctx.configure( + on_enter_transaction_context=enter_hook_fn + if enter_type == "sync" + else async_enter_hook_fn, + on_exit_transaction_context=exit_hook_fn + if exit_type == "sync" + else async_exit_hook_fn, + middlewares=[ + middleware1_fn if middleware1_type == "sync" else async_middleware1_fn, + middleware2_fn if middleware2_type == "sync" else async_middleware2_fn, + ], + ) + use_async_context_manager = any( + x == "async" + for x in [ + enter_type, + exit_type, + middleware1_type, + middleware2_type, + handler_type, + ] + ) + + handler_fn = sync_handler if handler_type == "sync" else async_handler + + should_raise_exception = any( + [ + (middleware1_type == "async" or middleware2_type == "async") + and not use_call_async, # using call_async is required with async middleware + middleware2_type == "sync" + and handler_type + == "async", # using sync middleware with async handler is forbidden + middleware1_type == "sync" + and middleware2_type + == "async", # using sync middleware with async call_next is forbidden + ] + ) + + if should_raise_exception: + with pytest.raises(TypeError): + await run(use_async_context_manager, use_call_async, ctx, handler_fn) + else: + result = await run(use_async_context_manager, use_call_async, ctx, handler_fn) + + assert result == "1:2:FOO:2:1" + assert call_log == [ + f"{enter_type}_enter", + f"{middleware1_type}_middleware1_enter", + f"{middleware2_type}_middleware2_enter", + f"{handler_type}_handler", + f"{middleware2_type}_middleware2_exit with str", + f"{middleware1_type}_middleware1_exit with str", + f"{exit_type}_exit", + ]