Skip to content

Commit

Permalink
[🟩]Fix typing issues, rewrite gateway code
Browse files Browse the repository at this point in the history
  • Loading branch information
WizzyGeek committed Mar 23, 2022
1 parent 58a3425 commit 56b2870
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 139 deletions.
25 changes: 7 additions & 18 deletions cordy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class Client:

def __init__(
self, token: StrOrToken, *,
intents: Intents = None, sharder_cls: type[BaseSharder[Shard]] = Sharder,
num_shards: int = None, shard_ids: Sequence[int] = None
intents: Intents | None = None, sharder_cls: type[BaseSharder[Shard]] = Sharder,
num_shards: int | None = None, shard_ids: Sequence[int] | None = None
):
self.intents = intents or Intents.default()
self.token = token if isinstance(token, Token) else Token(token, bot=True)
Expand All @@ -80,17 +80,7 @@ def __init__(
self.publisher.add(self.emitter)
self.http = HTTPSession(self.token)

if shard_ids is not None and num_shards is None:
raise ValueError("Must provide num_shards if shard ids are provided.")

if shard_ids is None and num_shards is not None:
shard_ids = list(range(num_shards))

self.shard_ids = set(shard_ids) if shard_ids else None
self.num_shards = num_shards

# May manipulate client attributes
self.sharder = sharder_cls(self)
self.sharder = sharder_cls(self, set(shard_ids) if shard_ids else None, num_shards)
self._closed_cb: Callable | None = None
self._closed: bool = False

Expand All @@ -99,7 +89,7 @@ def shards(self) -> list[Shard]:
"list[:class:`~cordy.gateway.Shard`] : All the shards under this client"
return self.sharder.shards

def listen(self, name: str = None) -> Callable[[CoroFn], CoroFn]:
def listen(self, name: str | None = None) -> Callable[[CoroFn], CoroFn]:
"""This method is used as a decorator.
Add the decorated function as a listener for the specified event
Expand All @@ -119,7 +109,7 @@ def deco(func: CoroFn):
return func
return deco

def add_listener(self, func: CoroFn, name: str = None) -> None:
def add_listener(self, func: CoroFn, name: str | None = None) -> None:
"""Add a listener for the given event.
Parameters
Expand All @@ -132,7 +122,7 @@ def add_listener(self, func: CoroFn, name: str = None) -> None:
"""
return self.publisher.subscribe(func, name = name or func.__name__.lower())

def remove_listener(self, func: CoroFn, name: str = None) -> None:
def remove_listener(self, func: CoroFn, name: str | None = None) -> None:
"""Remove a registered listener.
If the listener or event is not found, then does nothing.
Expand All @@ -146,7 +136,7 @@ def remove_listener(self, func: CoroFn, name: str = None) -> None:
"""
return self.publisher.unsubscribe(func, name)

async def wait_for(self, name: str, timeout: int = None, check: CheckFn = None) -> tuple[Any, ...]:
async def wait_for(self, name: str, timeout: int | None = None, check: CheckFn | None = None) -> tuple[Any, ...]:
"""Wait for an event to occur.
Parameters
Expand Down Expand Up @@ -191,7 +181,6 @@ async def connect(self) -> None:
raise ValueError("Can't connect with a closed client.")

await self.setup()
await self.sharder.create_shards()
await self.sharder.launch_shards()

async def disconnect(self, *, code: int = 4000, message: str = "") -> None:
Expand Down
28 changes: 18 additions & 10 deletions cordy/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import types
from collections.abc import Coroutine, Generator
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING, Callable, Protocol, TypeVar, cast, overload
from typing import TYPE_CHECKING, Callable, Protocol, TypeVar, overload

if TYPE_CHECKING:
EV = TypeVar("EV", contravariant=True)
from typing import Any

class Observer(Protocol[EV]):
def __call__(self, event: EV) -> None:
def __call__(self, event: EV) -> None | Any:
...

__all__ = (
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(self, name, /, *args) -> None:
self.name = _clean_event(name)
self.args = args

async def run(self, coro: CoroFn, err_hdlr: Callable[[Exception], Coroutine] = None) -> None:
async def run(self, coro: CoroFn, err_hdlr: Callable[[Exception], Coroutine] | None = None) -> None:
try:
await coro(*self.args)
except asyncio.CancelledError:
Expand Down Expand Up @@ -134,7 +135,7 @@ class Publisher:
listeners: dict[str, set[CoroFn]]
emitters: dict[Emitter, Generator[None, None, None]]

def __init__(self, error_hdlr: Callable[[Exception], Coroutine] = None) -> None:
def __init__(self, error_hdlr: Callable[[Exception], Coroutine] | None = None) -> None:
self.waiters = dict()
self.listeners = dict()
self.emitters = dict()
Expand All @@ -152,6 +153,8 @@ async def _notify(self, event: Event) -> None:
else:
fut.set_result(*event.args)

return None

def _notifier(self):
event = yield

Expand All @@ -168,10 +171,14 @@ def subscribe(self, *, name: str) -> Callable[[CoroFn], CoroFn]:
...

@overload
def subscribe(self, func: CoroFn, *, name: str = None, ) -> None:
def subscribe(self, func: CoroFn, *, name: str) -> None:
...

@overload
def subscribe(self, func: None, *, name: str) -> Callable[[CoroFn], CoroFn]:
...

def subscribe(self, func: CoroFn = None, *, name: str = None) -> Callable[[CoroFn], CoroFn] | None:
def subscribe(self, func: CoroFn | None = None, *, name: str | None = None) -> Callable[[CoroFn], CoroFn] | None:
def decorator(fn: CoroFn) -> CoroFn:
if not iscoroutinefunction(fn):
raise TypeError(f"Expected a coroutine function got {type(fn)}.")
Expand All @@ -187,16 +194,17 @@ def decorator(fn: CoroFn) -> CoroFn:
return decorator
else:
decorator(func)
return None

def unsubscribe(self, listener: CoroFn, name: str = None) -> None:
def unsubscribe(self, listener: CoroFn, name: str | None = None) -> None:
ev_listeners = self.listeners.get(_clean_event(name or listener.__name__.lower()))
try:
if ev_listeners:
ev_listeners.remove(listener)
except KeyError:
return

async def wait_for(self, name: str, timeout: int = None, check: CheckFn = None) -> tuple:
async def wait_for(self, name: str, timeout: int | None = None, check: CheckFn | None = None) -> tuple:
name = _clean_event(name)
ev_waiters = self.waiters.get(name)
if ev_waiters is None:
Expand Down Expand Up @@ -238,7 +246,7 @@ def remove(self, emitter: Emitter) -> None:
# subscribing for particular event
# Cons - Narrow use case, overhead reduction is low.
class SourcedPublisher(Publisher, Emitter):
def __init__(self, error_hdlr: Callable[[Exception], Coroutine] = None) -> None:
def __init__(self, error_hdlr: Callable[[Exception], Coroutine] | None = None) -> None:
super().__init__(error_hdlr=error_hdlr)
Emitter.__init__(self)

Expand All @@ -250,7 +258,7 @@ def emit(self, event: Event) -> None:
asyncio.create_task(self._notify(event))

class FilteredPublisher(SourcedPublisher, Filter):
def __init__(self, filter_fn: FilterFn, source: Emitter, error_hdlr: Callable[[Exception], Coroutine] = None) -> None:
def __init__(self, filter_fn: FilterFn, source: Emitter, error_hdlr: Callable[[Exception], Coroutine] | None = None) -> None:
super().__init__(error_hdlr=error_hdlr)
Filter.__init__(self, filter_fn, source)

Expand Down
Loading

0 comments on commit 56b2870

Please sign in to comment.