Skip to content

Commit

Permalink
NEW: support of async hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ss18 committed Jul 10, 2024
1 parent f135a52 commit cd3c8d1
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 44 deletions.
32 changes: 31 additions & 1 deletion tests/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fqn_decorators import get_fqn

from tests.conftest import go
from time_execution import GeneratorHookReturnType, settings, time_execution
from time_execution import GeneratorHookReturnType, settings, time_execution, time_execution_async
from time_execution.backends.base import BaseMetricsBackend


Expand All @@ -29,10 +29,40 @@ def local_hook(**kwargs):
return dict(local_hook_key="local hook value")


async def async_local_hook(**kwargs):
return dict(async_local_hook="async_local_hook value")


async def async_global_hook(**kwargs):
return dict(async_global_hook="async_global_hook value")


def global_hook(**kwargs):
return dict(global_hook_key="global hook value")


class TestAsyncHooks:
pytestmark = pytest.mark.asyncio

async def test_async_hooks(self):
with settings(backends=[CollectorBackend()], hooks=[global_hook, async_global_hook]):
collector = settings.backends[0]

@time_execution_async(extra_hooks=[local_hook, async_local_hook])
async def func_local_hook(*args, **kwargs):
return True

await func_local_hook()

assert len(collector.metrics) == 1
metadata = collector.metrics[0][func_local_hook.get_fqn()]
assert metadata["local_hook_key"] == "local hook value"
assert metadata["global_hook_key"] == "global hook value"
assert metadata["async_local_hook"] == "async_local_hook value"
assert metadata["async_global_hook"] == "async_global_hook value"
collector.clean()


class TestTimeExecution:
def test_custom_hook(self):
with settings(backends=[CollectorBackend()], hooks=[global_hook]):
Expand Down
6 changes: 4 additions & 2 deletions time_execution/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def time_execution(


def time_execution(__wrapped=None, get_fqn: Callable[[Any], str] = fqn_decorators.get_fqn, **kwargs):
from time_execution.timed import Timed # work around the circular dependency
from time_execution.timed import Timed, TimedAsync # work around the circular dependency

def wrap(__wrapped: _F) -> _F:
fqn = get_fqn(__wrapped)
Expand All @@ -62,7 +62,9 @@ def wrapper(*call_args, **call_kwargs):

@wraps(__wrapped)
async def wrapper(*call_args, **call_kwargs):
with Timed(wrapped=__wrapped, call_args=call_args, call_kwargs=call_kwargs, fqn=fqn, **kwargs) as timed:
async with TimedAsync(
wrapped=__wrapped, call_args=call_args, call_kwargs=call_kwargs, fqn=fqn, **kwargs
) as timed:
timed.result = await __wrapped(*call_args, **call_kwargs)
return timed.result

Expand Down
133 changes: 92 additions & 41 deletions time_execution/timed.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from collections.abc import Iterable
from contextlib import AbstractContextManager
from inspect import isgenerator, isgeneratorfunction
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from inspect import iscoroutinefunction, isgenerator, isgeneratorfunction
from socket import gethostname
from timeit import default_timer
from types import TracebackType
Expand All @@ -13,10 +13,9 @@
SHORT_HOSTNAME = gethostname()


class Timed(AbstractContextManager):
class Base:
"""
Both the sync and async decorators require the same logic around the wrapped function.
This context manager encapsulates the shared behaviour to avoid duplicating the code.
Base class for context managers encapsulates the shared behaviour to avoid duplicating the code.
"""

__slots__ = (
Expand Down Expand Up @@ -59,19 +58,14 @@ def __init__(
for hook in hooks
)

def __enter__(self) -> Timed:
def enter(self) -> Any:
self._start_time = default_timer()
for hook in self._hooks:
if isgenerator(hook):
hook.send(None) # start a generator hook
return self

def __exit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_val: Optional[BaseException],
__exc_tb: Optional[TracebackType],
) -> None:
def get_metric(self) -> Dict[str, Any]:
duration_millis = round(default_timer() - self._start_time, 3) * 1000.0

metric = {settings.duration_field: duration_millis, "hostname": SHORT_HOSTNAME, "name": self._fqn}
Expand All @@ -80,38 +74,95 @@ def __exit__(
if origin:
metric["origin"] = origin

# Apply the registered hooks, and collect the metadata they might
# return to be stored with the metrics.
metadata = self._apply_hooks(
response=self.result,
exception=__exc_val,
metric=metric,
)
return metric

def apply_hook(
self,
hook: Any,
exception: Optional[BaseException],
metric: Dict[str, Any],
metadata: Dict[str, Any],
) -> None:
if not isgenerator(hook):
hook_result = cast(Hook, hook)(
response=self.result,
exception=exception,
metric=metric,
func=self._wrapped,
func_args=self._call_args,
func_kwargs=self._call_kwargs,
)
else:
# Generator hook: send the results and obtain custom metadata.
try:
hook.send((self.result, exception, metric))
except StopIteration as e:
hook_result = e.value
else:
raise RuntimeError("generator hook did not stop")
if hook_result:
metadata.update(hook_result)


class Timed(AbstractContextManager, Base):

def __enter__(self) -> Timed:
return self.enter()

def __exit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_val: Optional[BaseException],
__exc_tb: Optional[TracebackType],
) -> None:

metadata: Dict[str, Any] = dict()
metric: Dict[str, Any] = self.get_metric()

for hook in self._hooks:
self.apply_hook(hook=hook, exception=__exc_val, metric=metric, metadata=metadata)

metric.update(metadata)
write_metric(**metric) # type: ignore[arg-type]

def _apply_hooks(self, response, exception, metric) -> Dict:

class TimedAsync(AbstractAsyncContextManager, Base):

async def __aenter__(self) -> Timed:
return self.enter()

async def __aexit__(
self,
__exc_type: Optional[Type[BaseException]],
__exc_val: Optional[BaseException],
__exc_tb: Optional[TracebackType],
) -> None:

metadata: Dict[str, Any] = dict()
metric: Dict[str, Any] = self.get_metric()

for hook in self._hooks:
if not isgenerator(hook):
# Simple exit hook, call it directly.
hook_result = cast(Hook, hook)(
response=response,
exception=exception,
metric=metric,
func=self._wrapped,
func_args=self._call_args,
func_kwargs=self._call_kwargs,
)
else:
# Generator hook: send the results and obtain custom metadata.
try:
hook.send((response, exception, metric))
except StopIteration as e:
hook_result = e.value
else:
raise RuntimeError("generator hook did not stop")
if hook_result:
metadata.update(hook_result)
return metadata
await self._apply_hook(hook=hook, exception=__exc_val, metric=metric, metadata=metadata)

metric.update(metadata)
write_metric(**metric) # type: ignore[arg-type]

async def _apply_hook(
self,
hook: Any,
exception: Optional[BaseException],
metric: Dict[str, Any],
metadata: Dict[str, Any],
) -> None:
if iscoroutinefunction(hook):
hook_result = await hook(
response=self.result,
exception=exception,
metric=metric,
func=self._wrapped,
func_args=self._call_args,
func_kwargs=self._call_kwargs,
)
metadata.update(hook_result)
else:
self.apply_hook(hook=hook, exception=exception, metric=metric, metadata=metadata)

0 comments on commit cd3c8d1

Please sign in to comment.