Skip to content

Commit

Permalink
fix mypy errors, make lato compatibile with 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
pgorecki committed Mar 4, 2024
1 parent 45a6598 commit c4d7a63
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
os: [Ubuntu, MacOS, Windows]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
include:
- os: Ubuntu
python-version: pypy-3.8
Expand Down
2 changes: 1 addition & 1 deletion lato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

def add_stderr_logger(
level: int = logging.DEBUG,
) -> logging.StreamHandler[typing.TextIO]:
) -> logging.StreamHandler:
"""
Helper for quickly adding a StreamHandler to the logger. Useful for
debugging.
Expand Down
16 changes: 8 additions & 8 deletions lato/application.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from typing import Any

from typing import Any, Optional, Union, List
from lato.types import DependencyIdentifier
from lato.application_module import ApplicationModule
from lato.dependency_provider import BasicDependencyProvider, DependencyProvider
from lato.message import Command, Event, Message
Expand All @@ -24,7 +24,7 @@ class Application(ApplicationModule):
def __init__(
self,
name=__name__,
dependency_provider: DependencyProvider | None = None,
dependency_provider: Optional[DependencyProvider] = None,
**kwargs,
):
"""Initialize the application instance.
Expand All @@ -42,10 +42,10 @@ def __init__(
self._transaction_context_factory = None
self._on_enter_transaction_context = lambda ctx: None
self._on_exit_transaction_context = lambda ctx, exception=None: None
self._transaction_middlewares = []
self._composers: dict[str | Command, Callable] = {}
self._transaction_middlewares: List[Callable] = []
self._composers: dict[Union[Message, str], Callable] = {}

def get_dependency(self, identifier: str | type) -> Any:
def get_dependency(self, identifier: DependencyIdentifier) -> Any:
"""Gets a dependency from the dependency provider. Dependencies can be resolved either by name or by type.
:param identifier: A string or a type representing the dependency.
Expand All @@ -54,10 +54,10 @@ def get_dependency(self, identifier: str | type) -> Any:
"""
return self.dependency_provider.get_dependency(identifier)

def __getitem__(self, identifier: str | type) -> Any:
def __getitem__(self, identifier: DependencyIdentifier) -> Any:
return self.get_dependency(identifier)

def call(self, func: Callable | str, *args, **kwargs):
def call(self, func: Union[Callable, str], *args, **kwargs):
"""Invokes a function with `args` and `kwargs` within the :class:`TransactionContext`.
If `func` is a string, then it is an alias, and the corresponding handler for the alias is retrieved.
Any missing arguments are provided by the dependency provider of a transaction context,
Expand Down
14 changes: 7 additions & 7 deletions lato/application_module.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
from collections import defaultdict
from collections.abc import Callable

from lato.message import Message
from lato.types import HandlerAlias
from lato.utils import OrderedSet

log = logging.getLogger(__name__)
Expand All @@ -28,7 +28,7 @@ def include_submodule(self, a_module: "ApplicationModule"):
), f"Can only include {ApplicationModule} instances, got {a_module}"
self._submodules.add(a_module)

def handler(self, alias: type[Message] | str):
def handler(self, alias: HandlerAlias):
"""
Decorator for registering a handler. Handler can be aliased by a name or by a message type.
Expand Down Expand Up @@ -65,12 +65,12 @@ def handler(self, alias: type[Message] | str):
command handler called
"""
try:
is_message = issubclass(alias, Message)
except TypeError:
is_message = False
if isinstance(alias, type):
is_message_type = issubclass(alias, Message)
else:
is_message_type = False

if callable(alias) and not is_message:
if callable(alias) and not is_message_type:
# decorator was called without any argument
func = alias
alias = func.__name__
Expand Down
2 changes: 1 addition & 1 deletion lato/compositon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from operator import add, or_
from typing import Any, Optional

from mergedeep import Strategy, merge
from mergedeep import Strategy, merge # type: ignore

additive_merge = partial(merge, strategy=Strategy.TYPESAFE_ADDITIVE)

Expand Down
13 changes: 7 additions & 6 deletions lato/dependency_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable
from typing import Any

from lato.types import DependencyIdentifier
from lato.utils import OrderedDict


Expand Down Expand Up @@ -47,7 +48,7 @@ class DependencyProvider(ABC):
allow_types = True

@abstractmethod
def has_dependency(self, identifier: str | type) -> bool:
def has_dependency(self, identifier: DependencyIdentifier) -> bool:
"""
Check if a dependency with the given identifier exists.
Expand All @@ -57,7 +58,7 @@ def has_dependency(self, identifier: str | type) -> bool:
raise NotImplementedError()

@abstractmethod
def register_dependency(self, identifier: str | type, dependency: Any):
def register_dependency(self, identifier: DependencyIdentifier, dependency: Any):
"""
Register a dependency with a given identifier (name or type).
Expand All @@ -66,7 +67,7 @@ def register_dependency(self, identifier: str | type, dependency: Any):
"""
raise NotImplementedError()

def get_dependency(self, identifier: str | type) -> Any:
def get_dependency(self, identifier: DependencyIdentifier) -> Any:
"""
Retrieve a dependency using its identifier (name or type).
Expand Down Expand Up @@ -205,7 +206,7 @@ def __init__(self, *args, **kwargs):
self._dependencies = {}
self.update(*args, **kwargs)

def register_dependency(self, identifier: str | type, dependency: Any):
def register_dependency(self, identifier: DependencyIdentifier, dependency: Any):
"""
Register a dependency with a given identifier (name or type).
Expand All @@ -217,7 +218,7 @@ def register_dependency(self, identifier: str | type, dependency: Any):

self._dependencies[identifier] = dependency

def has_dependency(self, identifier: str | type) -> bool:
def has_dependency(self, identifier: DependencyIdentifier) -> bool:
"""
Check if a dependency with the given identifier exists.
Expand All @@ -226,7 +227,7 @@ def has_dependency(self, identifier: str | type) -> bool:
"""
return identifier in self._dependencies

def get_dependency(self, identifier: str | type) -> Any:
def get_dependency(self, identifier: DependencyIdentifier) -> Any:
"""
Retrieve a dependency using its identifier (name or type).
Expand Down
15 changes: 8 additions & 7 deletions lato/testing.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
import contextlib
from typing import Iterator

from lato import Application


@contextlib.contextmanager
def override_app(application: Application, *args, **kwargs) -> Application:
def override_app(application: Application, *args, **kwargs) -> Iterator[Application]:
original_dependency_provider = application.dependency_provider
overriden_dependency_provider = original_dependency_provider.copy(*args, **kwargs)
overridden_dependency_provider = original_dependency_provider.copy(*args, **kwargs)

application.dependency_provider = overriden_dependency_provider
application.dependency_provider = overridden_dependency_provider
yield application

application.dependency_provider = original_dependency_provider


@contextlib.contextmanager
def override_ctx(application: Application, *args, **kwargs) -> Application:
def override_ctx(application: Application, *args, **kwargs) -> Iterator[Application]:
original_transaction_context = application.transaction_context

def overriden_transaction_context(**dependencies):
ctx = original_transaction_context(**dependencies)
ctx.dependency_provider = ctx.dependency_provider.copy(*args, **kwargs)
return ctx

application.transaction_context = overriden_transaction_context
application.transaction_context = overriden_transaction_context # type: ignore
yield application

application.transaction_context = original_transaction_context
application.transaction_context = original_transaction_context # type: ignore


@contextlib.contextmanager
def override(application: Application, *args, **kwargs) -> Application:
def override(application: Application, *args, **kwargs) -> Iterator[Application]:
with override_app(application, **kwargs) as overridden1:
with override_ctx(overridden1, **kwargs) as overridden2:
yield overridden2
24 changes: 12 additions & 12 deletions lato/transaction_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from collections import OrderedDict
from collections.abc import Callable, Iterator
from functools import partial
from typing import Any, NewType, Optional
from typing import Any, NewType, Optional, Union

from lato.types import HandlerAlias
from lato.compositon import compose
from lato.dependency_provider import (
BasicDependencyProvider,
Expand All @@ -12,7 +13,6 @@
)
from lato.message import Message

Alias = NewType("Alias", Any)

log = logging.getLogger(__name__)

Expand All @@ -38,7 +38,7 @@ class TransactionContext:
dependency_provider_factory = BasicDependencyProvider

def __init__(
self, dependency_provider: DependencyProvider | None = None, *args, **kwargs
self, dependency_provider: Optional[DependencyProvider] = None, *args, **kwargs
):
"""Initialize the transaction context instance.
Expand All @@ -56,8 +56,8 @@ def __init__(
self._on_enter_transaction_context = lambda ctx: None
self._on_exit_transaction_context = lambda ctx, exception=None: None
self._middlewares: list[Callable] = []
self._composers: dict[str | Message, Callable] = {}
self._handlers_iterator: Iterator = lambda alias: iter([])
self._composers: dict[HandlerAlias, Callable] = {}
self._handlers_iterator: Callable[[HandlerAlias], Iterator[Callable]] = lambda alias: iter([])

def configure(
self,
Expand Down Expand Up @@ -96,7 +96,7 @@ def begin(self):
"""Should be used to start a transaction"""
self._on_enter_transaction_context(self)

def end(self, exception: Exception = None):
def end(self, exception: Optional[Exception] = None):
"""Ends the transaction context by calling `on_exit_transaction_context` callback,
optionally passing an exception.
Expand Down Expand Up @@ -163,26 +163,26 @@ def execute(self, message: Message) -> tuple[Any, ...]:
composed_result = self._compose_results(message, values)
return composed_result

def emit(self, message: str | Message, *args, **kwargs) -> dict[Callable, Any]:
def emit(self, message: Union[str, Message], *args, **kwargs) -> dict[Callable, Any]:
# TODO: mark as obsolete
return self.publish(message, *args, **kwargs)

def publish(self, message: str | Message, *args, **kwargs) -> dict[Callable, Any]:
def publish(self, message: Union[str, Message], *args, **kwargs) -> dict[Callable, Any]:
"""
Publish a message by calling all handlers for that message.
:param message: The message object to publish, or a string.
:param message: The message object to publish, or an alias of a handler to call.
:param args: Positional arguments to pass to the handlers.
:param kwargs: Keyword arguments to pass to the handlers.
:return: A dictionary mapping handlers to their results.
"""
alias = type(message) if isinstance(message, Message) else message
message_type = type(message) if isinstance(message, Message) else message

if isinstance(message, Message):
args = (message, *args)

all_results = OrderedDict()
for handler in self._handlers_iterator(alias):
for handler in self._handlers_iterator(message_type): # type: ignore
self.set_dependency("message", message)
# FIXME: push and pop current action instead of setting it
self.current_handler = handler
Expand Down Expand Up @@ -214,4 +214,4 @@ def _compose_results(self, message: Message, results: tuple[Any, ...]) -> Any:
@property
def current_action(self) -> tuple[Message, Callable]:
"""Returns current message and handler being executed"""
return self.get_dependency("message"), self.current_handler
return self.get_dependency("message"), self.current_handler # type: ignore
5 changes: 5 additions & 0 deletions lato/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Union
from lato.message import Message

HandlerAlias = Union[type[Message], str]
DependencyIdentifier = Union[type, str]
26 changes: 14 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ license = "MIT"
"Bug Tracker" = "https://github.com/pgorecki/lato/issues"

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9"
pytest = "^7.4.2"
pydantic = "^2.4.2"
mergedeep = "^1.3.4"
Expand Down

0 comments on commit c4d7a63

Please sign in to comment.