Skip to content

Commit

Permalink
Bump to 0.12.0, async improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
pgorecki committed Jan 2, 2025
1 parent cf43020 commit 689e475
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 49 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Change Log

## [0.12.0] - 2025-01-02

- Improved handling of async functions by `TransactionContext`. Will raise `TypeError` if sync middleware is used with async handler.

## [0.11.1] - 2024-05-29

- Fix type annotations in `ApplicationModule`
Expand Down
23 changes: 14 additions & 9 deletions examples/example6.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
from lato import Application, ApplicationModule, Command, TransactionContext
from lato.compositon import compose
from lato import Application, ApplicationModule, Command


class GetAllItemDetails(Command):
pass


pricing_module = ApplicationModule("pricing")


@pricing_module.handler(GetAllItemDetails)
def get_item_price(command: GetAllItemDetails):
prices = {'pencil': 1, 'pen': 2}
prices = {"pencil": 1, "pen": 2}
return prices


warehouse_module = ApplicationModule("warehouse")


@warehouse_module.handler(GetAllItemDetails)
def get_item_stock(command: GetAllItemDetails):
stocks = {'pencil': 100, 'pen': 80}
stocks = {"pencil": 100, "pen": 80}
return stocks


app = Application()
app.include_submodule(pricing_module)
app.include_submodule(warehouse_module)


@app.compose(GetAllItemDetails)
def compose_item_details(pricing, warehouse):
assert pricing == {'pencil': 1, 'pen': 2}
assert warehouse == {'pencil': 100, 'pen': 80}
assert pricing == {"pencil": 1, "pen": 2}
assert warehouse == {"pencil": 100, "pen": 80}

details = [dict(item_id=x, price=pricing[x], stock=warehouse[x]) for x in pricing.keys()]
details = [
dict(item_id=x, price=pricing[x], stock=warehouse[x]) for x in pricing.keys()
]
return details


assert app.execute(GetAllItemDetails()) == [
{'item_id': 'pencil', 'price': 1, 'stock': 100},
{'item_id': 'pen', 'price': 2, 'stock': 80}
{"item_id": "pencil", "price": 1, "stock": 100},
{"item_id": "pen", "price": 2, "stock": 80},
]
27 changes: 14 additions & 13 deletions lato/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lato.message import Event, Message
from lato.transaction_context import (
ComposerFunction,
MessageHandler,
MiddlewareFunction,
OnEnterTransactionContextCallback,
OnExitTransactionContextCallback,
Expand Down Expand Up @@ -147,14 +148,14 @@ async def execute_async(self, message: Message) -> Any:
:raises: ValueError: If no handlers are found for the message.
"""
async with self.transaction_context() as ctx:
result = await ctx.execute(message)
result = await ctx.execute_async(message)
return result

def emit(self, event: Event) -> dict[Callable, Any]:
def emit(self, event: Event) -> dict[MessageHandler, Any]:
"""Deprecated. Use `publish()` instead."""
return self.publish(event)

def publish(self, event: Event) -> dict[Callable, Any]:
def publish(self, event: Event) -> dict[MessageHandler, Any]:
"""
Publish an event by calling all handlers for that event.
Expand All @@ -165,7 +166,7 @@ def publish(self, event: Event) -> dict[Callable, Any]:
result = ctx.publish(event)
return result

async def publish_async(self, event: Event) -> dict[Callable, Any]:
async def publish_async(self, event: Event) -> dict[MessageHandler, Any]:
"""
Asynchronously publish an event by calling all handlers for that event.
Expand Down Expand Up @@ -251,35 +252,35 @@ def transaction_middleware(self, middleware_func):
Decorator for registering a middleware function to be called when executing a function in a transaction context
:param middleware_func:
:return: the decorated function
**Example:**
>>> from typing import Callable
>>> from lato import Application, TransactionContext
>>>
>>> app = Application()
>>> @app.transaction_middleware
... def middleware1(ctx: TransactionContext, call_next: Callable):
... ...
"""
self._transaction_middlewares.insert(0, middleware_func)
self._transaction_middlewares.append(middleware_func)
return middleware_func

def compose(self, alias):
"""
Decorator for composing results of handlers identified by an alias.
**Example:**
>>> from lato import Application, Command, TransactionContext
>>> class SomeCommand(Command):
... pass
>>>
>>> app = Application()
>>> @app.compose(SomeCommand)
... def middleware1(**kwargs):
... ...
Expand Down
88 changes: 63 additions & 25 deletions lato/transaction_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from lato.message import Message
from lato.types import HandlerAlias
from lato.utils import maybe_await

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -175,6 +176,20 @@ async def end_async(self, exception: Optional[Exception] = None):
else:
log.debug("Ended transaction")

def is_async_context_manager(self) -> bool:
"""
Determine if the transaction context requires `async with` context manager.
This method checks if either the `on_enter` or `on_exit` callbacks
associated with the transaction context are asynchronous.
:return: True if `on_enter` or `on_exit` are asynchronous callbacks, False otherwise.
"""
return any(
[asyncio.iscoroutinefunction(self._on_enter_transaction_context)]
+ [asyncio.iscoroutinefunction(self._on_exit_transaction_context)]
)

def iterate_handlers_for(self, alias: str):
yield from self._handlers_iterator(alias)

Expand All @@ -196,64 +211,87 @@ async def __aexit__(self, exc_type=None, exc_val=None, exc_tb=None):
if asyncio.iscoroutine(result):
await result

def _wrap_with_middlewares(self, handler_func):
p = handler_func
for middleware in self._middlewares:
if not asyncio.iscoroutinefunction(
middleware
) and asyncio.iscoroutinefunction(handler_func):
raise ValueError(
"Cannot use synchronous middleware with async handler",
middleware,
handler_func,
)

p = partial(middleware, self, p)
return p

def call(self, func: Callable, *func_args: Any, **func_kwargs: Any) -> Any:
"""Call a function with the arguments and keyword arguments.
Missing arguments will be resolved with the dependency provider.
If func is coroutine, or any of the middleware functions is coroutine, TypeError will be raised.
:param func: The function to call.
:param func_args: Positional arguments to pass to the function.
:param func_kwargs: Keyword arguments to pass to the function.
:return: The result of the function call.
"""
if asyncio.iscoroutinefunction(func):
raise TypeError(
f"Using async function ({func}) with {self.__class__.__name__}.call() is not allowed. Use call_async() instead."
)

self.dependency_provider.update(ctx=as_type(self, TransactionContext))

resolved_kwargs = self.dependency_provider.resolve_func_params(
func, func_args, func_kwargs
)
self.resolved_kwargs.update(resolved_kwargs)
p = partial(func, **resolved_kwargs)
wrapped_handler = self._wrap_with_middlewares(p)
result = wrapped_handler()
return result

call_next = partial(func, **resolved_kwargs)

for m in self._middlewares[::-1]:
# middleware is async, which is not allowed
if asyncio.iscoroutinefunction(m):
raise TypeError(
f"Using async middleware ({m}) with {self.__class__.__name__}.call() is not allowed. Use call_async() instead."
)

call_next = partial(m, self, call_next)

return call_next()

async def call_async(
self, func: Callable[..., Awaitable[Any]], *func_args: Any, **func_kwargs: Any
) -> Any:
"""Call an async function with the arguments and keyword arguments.
Missing arguments will be resolved with the dependency provider.
Edge cases:
- middlewares and func are sync - this will behave like call()
- middleware is sync, and call_next is async - will raise TypeError, as middleware will not be able to wait for call_next()
:param func: The function to call.
:param func_args: Positional arguments to pass to the function.
:param func_kwargs: Keyword arguments to pass to the function.
:return: The result of the function call.
"""

self.dependency_provider.update(ctx=as_type(self, TransactionContext))

resolved_kwargs = self.dependency_provider.resolve_func_params(
func, func_args, func_kwargs
)
self.resolved_kwargs.update(resolved_kwargs)
p = partial(func, **resolved_kwargs)
wrapped_handler = self._wrap_with_middlewares(p)
result = wrapped_handler()
if asyncio.iscoroutine(result):
result = await result
return result

call_next = partial(func, **resolved_kwargs)

for m in self._middlewares[::-1]:
if asyncio.iscoroutinefunction(m) and not asyncio.iscoroutinefunction(
call_next
):
# async middleware is expecting an awaitable, so we need convert call_next to async
call_next = partial(maybe_await, call_next)

if not asyncio.iscoroutinefunction(m) and asyncio.iscoroutinefunction(
call_next
):
# middleware is not able to retrieve call_next, as call_next is awaitable
raise TypeError(
f"Using sync middleware ({m}) with async call_next ({call_next}) is not allowed."
)
call_next = partial(m, self, call_next)

if asyncio.iscoroutinefunction(call_next):
return await call_next()
else:
return call_next()

def execute(self, message: Message) -> tuple[Any, ...]:
"""Executes all handlers bound to the message. Returns a tuple of handlers' return values.
Expand Down
8 changes: 8 additions & 0 deletions lato/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import re
from collections import OrderedDict
from typing import TypeVar
Expand Down Expand Up @@ -29,3 +30,10 @@ def string_to_kwarg_name(string):
valid_string = "_" + valid_string

return valid_string


async def maybe_await(func, *args, **kwargs):
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "lato"
version = "0.11.1"
version = "0.12.0"
description = "Lato is a Python microframework designed for building modular monoliths and loosely coupled applications."
authors = ["Przemysław Górecki <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_application_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,6 @@ async def async_foo():
await asyncio.sleep(0.001)
return 1

with pytest.raises(ValueError):
with pytest.raises(TypeError):
# cannot use synchronous middleware with async handler
await app.call_async("async_foo")
Loading

0 comments on commit 689e475

Please sign in to comment.