Skip to content

Commit

Permalink
[DOP-11371] Fix @slot decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Dec 8, 2023
1 parent 4092134 commit 649b2f7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 27 deletions.
33 changes: 25 additions & 8 deletions onetl/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import Callable, Generator, Generic, TypeVar
from typing import Callable, Generator, Generic, TypeVar, overload

from typing_extensions import ParamSpec, Protocol, runtime_checkable
from typing_extensions import Protocol, runtime_checkable

from onetl.log import NOTICE

logger = logging.getLogger(__name__)

T = TypeVar("T")
P = ParamSpec("P")


class HookPriority(int, Enum):
Expand All @@ -36,7 +35,7 @@ class HookPriority(int, Enum):


@dataclass # noqa: WPS338
class Hook(Generic[P, T]): # noqa: WPS338
class Hook(Generic[T]): # noqa: WPS338
"""
Hook representation.
Expand Down Expand Up @@ -70,7 +69,7 @@ def some_func(*args, **kwargs):
hook = Hook(callback=some_func, enabled=True, priority=HookPriority.FIRST)
"""

callback: Callable[P, T]
callback: Callable[..., T]
enabled: bool = True
priority: HookPriority = HookPriority.NORMAL

Expand Down Expand Up @@ -198,7 +197,7 @@ def hook_disabled():
)
self.enabled = True

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T | ContextDecorator:
def __call__(self, *args, **kwargs) -> T | ContextDecorator:
"""
Calls the original callback with passed args.
Expand Down Expand Up @@ -361,7 +360,25 @@ def process_result(self, result: T) -> T | None:
return None


def hook(inp: Callable[P, T] | None = None, enabled: bool = True, priority: HookPriority = HookPriority.NORMAL):
@overload
def hook(
inp: Callable[..., T],
enabled: bool,
priority: HookPriority,
) -> Hook[T]:
...


@overload
def hook(
inp: None,
enabled: bool,
priority: HookPriority,
) -> Callable[[Callable[..., T]], Hook[T]]:
...


def hook(inp=None, enabled=True, priority=HookPriority.NORMAL):
"""
Initialize hook from callable/context manager.
Expand Down Expand Up @@ -423,7 +440,7 @@ def process_result(self, result):
...
"""

def inner_wrapper(callback: Callable[P, T]): # noqa: WPS430
def inner_wrapper(callback: Callable[..., T]) -> Hook[T]: # noqa: WPS430
if isinstance(callback, Hook):
raise TypeError("@hook decorator can be applied only once")

Expand Down
11 changes: 5 additions & 6 deletions onetl/hooks/slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial, wraps
from typing import Any, Callable, ContextManager, TypeVar

from typing_extensions import ParamSpec, Protocol
from typing_extensions import Protocol

from onetl.exception import SignatureError
from onetl.hooks.hook import CanProcessResult, Hook, HookPriority
Expand All @@ -17,13 +17,12 @@
from onetl.hooks.method_inheritance_stack import MethodInheritanceStack
from onetl.log import NOTICE

logger = logging.getLogger(__name__)
Method = TypeVar("Method", bound=Callable[..., Any])

P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)


def _unwrap_method(method: Callable[P, T]) -> Callable[P, T]:
def _unwrap_method(method: Method) -> Method:
"""Unwrap @classmethod and @staticmethod to get original function"""
return getattr(method, "__func__", method)

Expand Down Expand Up @@ -624,7 +623,7 @@ def bind(self):
...


def slot(method: Callable[P, T]) -> Callable[P, T]:
def slot(method: Method) -> Method:
"""
Decorator which enables hooks functionality on a specific class method.
Expand Down
18 changes: 5 additions & 13 deletions onetl/impl/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,22 @@

from __future__ import annotations

from typing import ClassVar
from typing import Any

from pydantic import BaseModel as PydanticBaseModel


class BaseModel(PydanticBaseModel):
_forward_refs_updated: ClassVar[bool] = False

class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True
extra = "forbid"
underscore_attrs_are_private = True

def __init__(self, **kwargs):
if not self._forward_refs_updated:
# if pydantic cannot detect type hints (referenced class is not imported yet),
# it wraps annotation with ForwardRef(...), which should be resolved before creating the instance.
# so using a small hack to detect all those refs and update them
# when first object instance is being created
refs = self._forward_refs()
self.__class__.update_forward_refs(**refs)
self.__class__._forward_refs_updated = True # noqa: WPS437
super().__init__(**kwargs)
@classmethod
def __try_update_forward_refs__(cls, **localns: Any) -> None:
refs = cls._forward_refs()
cls.update_forward_refs(**refs, **localns)

@classmethod
def _forward_refs(cls) -> dict[str, type]:
Expand Down

0 comments on commit 649b2f7

Please sign in to comment.