From 18088ef45ab5fc37d90e8dad4ac5164130d72c48 Mon Sep 17 00:00:00 2001 From: Pavel Perestoronin Date: Thu, 25 Jan 2024 13:08:25 +0100 Subject: [PATCH] =?UTF-8?q?NEW:=20better=20typing=20for=20hooks=20?= =?UTF-8?q?=F0=9F=A7=91=E2=80=8D=F0=9F=92=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- time_execution/decorator.py | 31 ++++++++++++++++++++++++------- time_execution/timed.py | 11 ++++++----- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/time_execution/decorator.py b/time_execution/decorator.py index 53fd8df..857b718 100755 --- a/time_execution/decorator.py +++ b/time_execution/decorator.py @@ -1,18 +1,20 @@ -""" -Time Execution decorator -""" +"""Time Execution decorator""" + +from __future__ import annotations + from asyncio import iscoroutinefunction +from collections.abc import Iterable from functools import wraps -from typing import Any, Callable, List, Optional, TypeVar, cast +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast import fqn_decorators from pkgsettings import Settings -from typing_extensions import overload +from typing_extensions import Protocol, overload _F = TypeVar("_F", bound=Callable[..., Any]) settings = Settings() -settings.configure(backends=[], hooks=[], duration_field="value") +settings.configure(backends=(), hooks=(), duration_field="value") def write_metric(name: str, **metric: Any) -> None: @@ -29,7 +31,7 @@ def time_execution(__wrapped: _F) -> _F: def time_execution( *, get_fqn: Callable[[Any], str] = fqn_decorators.get_fqn, - extra_hooks: Optional[List] = None, + extra_hooks: Optional[Iterable[Hook]] = None, disable_default_hooks: bool = False, ) -> Callable[[_F], _F]: """ @@ -74,3 +76,18 @@ async def wrapper(*call_args, **call_kwargs): # `time_execution` supports async out of the box. time_execution_async = time_execution + + +class Hook(Protocol): + """Hook callback protocol.""" + + def __call__( + self, + response: Any, + exception: Optional[BaseException], + metric: Dict[str, Any], + func: Callable[..., Any], + func_args: Tuple[Any, ...], + func_kwargs: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + ... diff --git a/time_execution/timed.py b/time_execution/timed.py index 11a6d38..e7a5092 100644 --- a/time_execution/timed.py +++ b/time_execution/timed.py @@ -1,12 +1,13 @@ from __future__ import annotations +from collections.abc import Iterable from contextlib import AbstractContextManager from socket import gethostname from timeit import default_timer from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type -from time_execution import settings, write_metric +from time_execution import Hook, settings, write_metric SHORT_HOSTNAME = gethostname() @@ -35,7 +36,7 @@ def __init__( fqn: str, call_args: Tuple[Any, ...], call_kwargs: Dict[str, Any], - extra_hooks: Optional[List] = None, + extra_hooks: Optional[Iterable[Hook]] = None, disable_default_hooks: bool = False, ) -> None: self.result: Optional[Any] = None @@ -64,9 +65,9 @@ def __exit__( if origin: metric["origin"] = origin - hooks = self._extra_hooks or [] + hooks = self._extra_hooks or () if not self._disable_default_hooks: - hooks = settings.hooks + hooks + hooks = (*settings.hooks, *hooks) # Apply the registered hooks, and collect the metadata they might # return to be stored with the metrics.