From 56b2870f559b6a60b27985d5e78ccb521d0defe2 Mon Sep 17 00:00:00 2001 From: WizzyGeek <51919967+WizzyGeek@users.noreply.github.com> Date: Wed, 23 Mar 2022 22:14:02 +0530 Subject: [PATCH] =?UTF-8?q?[=F0=9F=9F=A9]Fix=20typing=20issues,=20rewrite?= =?UTF-8?q?=20gateway=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cordy/client.py | 25 ++-- cordy/events.py | 28 +++-- cordy/gateway.py | 266 ++++++++++++++++++++++++---------------- cordy/http/ratelimit.py | 2 +- cordy/models/flags.py | 2 +- cordy/runner.py | 2 +- cordy/util.py | 2 +- mypy.ini | 2 +- 8 files changed, 190 insertions(+), 139 deletions(-) diff --git a/cordy/client.py b/cordy/client.py index 1627835..a1bb234 100644 --- a/cordy/client.py +++ b/cordy/client.py @@ -69,8 +69,8 @@ class Client: def __init__( self, token: StrOrToken, *, - intents: Intents = None, sharder_cls: type[BaseSharder[Shard]] = Sharder, - num_shards: int = None, shard_ids: Sequence[int] = None + intents: Intents | None = None, sharder_cls: type[BaseSharder[Shard]] = Sharder, + num_shards: int | None = None, shard_ids: Sequence[int] | None = None ): self.intents = intents or Intents.default() self.token = token if isinstance(token, Token) else Token(token, bot=True) @@ -80,17 +80,7 @@ def __init__( self.publisher.add(self.emitter) self.http = HTTPSession(self.token) - if shard_ids is not None and num_shards is None: - raise ValueError("Must provide num_shards if shard ids are provided.") - - if shard_ids is None and num_shards is not None: - shard_ids = list(range(num_shards)) - - self.shard_ids = set(shard_ids) if shard_ids else None - self.num_shards = num_shards - - # May manipulate client attributes - self.sharder = sharder_cls(self) + self.sharder = sharder_cls(self, set(shard_ids) if shard_ids else None, num_shards) self._closed_cb: Callable | None = None self._closed: bool = False @@ -99,7 +89,7 @@ def shards(self) -> list[Shard]: "list[:class:`~cordy.gateway.Shard`] : All the shards under this client" return self.sharder.shards - def listen(self, name: str = None) -> Callable[[CoroFn], CoroFn]: + def listen(self, name: str | None = None) -> Callable[[CoroFn], CoroFn]: """This method is used as a decorator. Add the decorated function as a listener for the specified event @@ -119,7 +109,7 @@ def deco(func: CoroFn): return func return deco - def add_listener(self, func: CoroFn, name: str = None) -> None: + def add_listener(self, func: CoroFn, name: str | None = None) -> None: """Add a listener for the given event. Parameters @@ -132,7 +122,7 @@ def add_listener(self, func: CoroFn, name: str = None) -> None: """ return self.publisher.subscribe(func, name = name or func.__name__.lower()) - def remove_listener(self, func: CoroFn, name: str = None) -> None: + def remove_listener(self, func: CoroFn, name: str | None = None) -> None: """Remove a registered listener. If the listener or event is not found, then does nothing. @@ -146,7 +136,7 @@ def remove_listener(self, func: CoroFn, name: str = None) -> None: """ return self.publisher.unsubscribe(func, name) - async def wait_for(self, name: str, timeout: int = None, check: CheckFn = None) -> tuple[Any, ...]: + async def wait_for(self, name: str, timeout: int | None = None, check: CheckFn | None = None) -> tuple[Any, ...]: """Wait for an event to occur. Parameters @@ -191,7 +181,6 @@ async def connect(self) -> None: raise ValueError("Can't connect with a closed client.") await self.setup() - await self.sharder.create_shards() await self.sharder.launch_shards() async def disconnect(self, *, code: int = 4000, message: str = "") -> None: diff --git a/cordy/events.py b/cordy/events.py index 69af73d..aa9ce93 100644 --- a/cordy/events.py +++ b/cordy/events.py @@ -4,13 +4,14 @@ import types from collections.abc import Coroutine, Generator from inspect import iscoroutinefunction -from typing import TYPE_CHECKING, Callable, Protocol, TypeVar, cast, overload +from typing import TYPE_CHECKING, Callable, Protocol, TypeVar, overload if TYPE_CHECKING: EV = TypeVar("EV", contravariant=True) + from typing import Any class Observer(Protocol[EV]): - def __call__(self, event: EV) -> None: + def __call__(self, event: EV) -> None | Any: ... __all__ = ( @@ -58,7 +59,7 @@ def __init__(self, name, /, *args) -> None: self.name = _clean_event(name) self.args = args - async def run(self, coro: CoroFn, err_hdlr: Callable[[Exception], Coroutine] = None) -> None: + async def run(self, coro: CoroFn, err_hdlr: Callable[[Exception], Coroutine] | None = None) -> None: try: await coro(*self.args) except asyncio.CancelledError: @@ -134,7 +135,7 @@ class Publisher: listeners: dict[str, set[CoroFn]] emitters: dict[Emitter, Generator[None, None, None]] - def __init__(self, error_hdlr: Callable[[Exception], Coroutine] = None) -> None: + def __init__(self, error_hdlr: Callable[[Exception], Coroutine] | None = None) -> None: self.waiters = dict() self.listeners = dict() self.emitters = dict() @@ -152,6 +153,8 @@ async def _notify(self, event: Event) -> None: else: fut.set_result(*event.args) + return None + def _notifier(self): event = yield @@ -168,10 +171,14 @@ def subscribe(self, *, name: str) -> Callable[[CoroFn], CoroFn]: ... @overload - def subscribe(self, func: CoroFn, *, name: str = None, ) -> None: + def subscribe(self, func: CoroFn, *, name: str) -> None: + ... + + @overload + def subscribe(self, func: None, *, name: str) -> Callable[[CoroFn], CoroFn]: ... - def subscribe(self, func: CoroFn = None, *, name: str = None) -> Callable[[CoroFn], CoroFn] | None: + def subscribe(self, func: CoroFn | None = None, *, name: str | None = None) -> Callable[[CoroFn], CoroFn] | None: def decorator(fn: CoroFn) -> CoroFn: if not iscoroutinefunction(fn): raise TypeError(f"Expected a coroutine function got {type(fn)}.") @@ -187,8 +194,9 @@ def decorator(fn: CoroFn) -> CoroFn: return decorator else: decorator(func) + return None - def unsubscribe(self, listener: CoroFn, name: str = None) -> None: + def unsubscribe(self, listener: CoroFn, name: str | None = None) -> None: ev_listeners = self.listeners.get(_clean_event(name or listener.__name__.lower())) try: if ev_listeners: @@ -196,7 +204,7 @@ def unsubscribe(self, listener: CoroFn, name: str = None) -> None: except KeyError: return - async def wait_for(self, name: str, timeout: int = None, check: CheckFn = None) -> tuple: + async def wait_for(self, name: str, timeout: int | None = None, check: CheckFn | None = None) -> tuple: name = _clean_event(name) ev_waiters = self.waiters.get(name) if ev_waiters is None: @@ -238,7 +246,7 @@ def remove(self, emitter: Emitter) -> None: # subscribing for particular event # Cons - Narrow use case, overhead reduction is low. class SourcedPublisher(Publisher, Emitter): - def __init__(self, error_hdlr: Callable[[Exception], Coroutine] = None) -> None: + def __init__(self, error_hdlr: Callable[[Exception], Coroutine] | None = None) -> None: super().__init__(error_hdlr=error_hdlr) Emitter.__init__(self) @@ -250,7 +258,7 @@ def emit(self, event: Event) -> None: asyncio.create_task(self._notify(event)) class FilteredPublisher(SourcedPublisher, Filter): - def __init__(self, filter_fn: FilterFn, source: Emitter, error_hdlr: Callable[[Exception], Coroutine] = None) -> None: + def __init__(self, filter_fn: FilterFn, source: Emitter, error_hdlr: Callable[[Exception], Coroutine] | None = None) -> None: super().__init__(error_hdlr=error_hdlr) Filter.__init__(self, filter_fn, source) diff --git a/cordy/gateway.py b/cordy/gateway.py index 15e44eb..a2ecf31 100644 --- a/cordy/gateway.py +++ b/cordy/gateway.py @@ -4,17 +4,17 @@ import logging import zlib from enum import IntEnum -from math import ceil, log10 as _log +from math import ceil +from math import log10 as _log from sys import platform from time import perf_counter -from typing import (TYPE_CHECKING, ClassVar, Protocol, TypeVar, +from typing import (TYPE_CHECKING, Callable, ClassVar, Protocol, TypeVar, runtime_checkable) import aiohttp +import uprate as up from aiohttp import WSMsgType -from aiohttp.client_ws import ClientWebSocketResponse from yarl import URL -import uprate as up from cordy.events import Event, SourcedPublisher @@ -24,11 +24,14 @@ from asyncio.futures import Future from asyncio.locks import Lock from asyncio.tasks import Task + from typing import Any from .client import Client from .types import Dispatch, Payload from .util import Msg + method_map: dict[int, Callable[[GateWay, Payload], Any]] + logger = logging.getLogger(__name__) # TODO: ETF, Pyrlang/Term, discord/erlpack @@ -105,7 +108,6 @@ def get_enum(cls, val: int) -> OpCodes | None: else: return ret - class Inflator: # for zlib """A Callable which decompresses incoming zlib data. @@ -118,24 +120,26 @@ class Inflator: # for zlib stream : :class:`bool` Whether the the data is a part of zlib-stream """ + buf: bytearray + stream: bool + def __init__(self, stream: bool = False, **opt): self.buf = bytearray() self.stream = stream self.decomp = zlib.decompressobj(**opt) def __call__(self, data: bytes) -> str | None: - if not self.stream: - try: - return zlib.decompress(data).decode("utf-8") - except zlib.error: - self.buf.extend(data) + if self.stream: + self.buf.extend(data) - if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': - return None + if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': + return None - msg = self.decomp.decompress(self.buf).decode("utf-8") - self.buf = bytearray() - return msg + msg = self.decomp.decompress(self.buf).decode("utf-8") + self.buf = bytearray() + return msg + else: + return zlib.decompress(data).decode("utf-8") # TODO: Consider Async initialisation to remove checks class GateWay: @@ -149,22 +153,47 @@ class GateWay: "$device": "cordy" } - ws: ClientWebSocketResponse | None - _url: URL | None - _listener: Task[None] | None - _beater: Task[None] | None - _reconnect: bool - _seq: int | None - _interval: float | None - _ack_fut: Future | None - - def __init__(self, client: Client, *, shard: Shard, inflator: Inflator = None, compression: bool = False) -> None: + if TYPE_CHECKING: + from asyncio import AbstractEventLoop + + from aiohttp.client_ws import ClientWebSocketResponse + from uprate import Bucket + + from cordy.auth import Token + from cordy.http import HTTPSession + from cordy.models import Intents + + ws: ClientWebSocketResponse + session: HTTPSession + token: Token + intents: Intents + inflator: Inflator + client: Client + shard_id: int + emitter: SourcedPublisher + shard: Shard + loop: AbstractEventLoop + + _tracker: LatencyTracker + + _ack_fut: Future | None + _url: URL | None + _listener: Task[None] | None + _beater: Task[None] | None + _reconnect: bool + _seq: int | None + _interval: float | None + _ratelimit: Bucket[None] + _session_id: str + + @classmethod + async def make_gateway(cls, client: Client, *, shard: Shard, inflator: Inflator | None = None, compression: bool = False) -> GateWay: + self = cls() self.session = client.http self.token = client.token self.intents = client.intents self.inflator = inflator or Inflator() self.inflator.stream = compression - self.ws = None self.client = client self.shard_id = shard.shard_id self.emitter = shard.emitter @@ -172,20 +201,24 @@ def __init__(self, client: Client, *, shard: Shard, inflator: Inflator = None, c self.loop = asyncio.get_event_loop() self._closed = False - self._url = None + self._url = client.sharder._url or None self._seq = None self._session_id = "" self._interval = 0.0 - self._ack_fut = None self._tracker = LatencyTracker() + + self._listener = None self._beater = None + self._ack_fut = None + self._resume = True - self._listener = None + # Whether listener should attempt reconnect after interal/externally caused disconnect self._reconnect = True - self._reconnecting = False - self._compression = compression - self._ratelimit = up.Bucket[None](120 / up.Minutes(1)) + self._ratelimit = up.Bucket[None](120 / up.Minutes(1)) # type: ignore + + await self.connect() # self.ws + return self @property def resumable(self) -> bool: @@ -197,23 +230,23 @@ def closed(self) -> bool: @property def disconnected(self) -> bool: - return self.ws is not None and self.ws.closed + return self.ws.closed @property def connected(self) -> bool: return not self.disconnected - async def connect(self, url: URL = None) -> None: + async def connect(self, url: URL | None = None) -> None: if self._closed: raise ValueError("GateWay instance already closed") - if self.ws and not self.ws.closed: + if self.ws.closed: await self.disconnect(message=b"Reconnecting") url = url or self._url or await self.session.get_gateway() url %= {"v": 9, "encoding": "json"} - if self._compression: + if self.inflator.stream: url %= {"compress": "zlib-stream"} try: @@ -233,24 +266,22 @@ async def connect(self, url: URL = None) -> None: self._url = url self._reconnect = True - self._listener is not None and self._listener.cancel() + if self._listener is not None: self._listener.cancel() + self._listener = self.loop.create_task(self.listen()) async def listen(self): - if self.ws is None: - return - backoff = 1 reconnecting = False zombie_listener = False while not self._closed: if reconnecting: - delay = 2 * backoff * _log(backoff) + delay = 2 * backoff * _log(backoff + 1) logger.info("Disconnected, reconnecting again in %s", delay) await asyncio.sleep(delay) - backoff = (backoff + 1) % 700 # limit sleep to around 2000 sec + backoff = ((backoff) % 700) + 1 # limit to around 4000 sec if zombie_listener and not reconnecting: break @@ -282,24 +313,23 @@ async def listen(self): else: break - async def process_message(self, msg: Payload) -> None: - op = OpCodes.get_enum(msg["op"]) # 0,1,7,9,10,11 + method = method_map.get(msg["op"]) # 0,1,7,9,10,11 self._seq = msg.get("s") or self._seq - if op is None: + if method is None: logger.warning("Gateway sent payload with unknown op code %s", msg.get("op")) return logger.debug("Shard %s Received: %s", self.shard_id, msg) - await getattr(self, op.name.lower())(msg) + await method(self, msg) async def start_session(self): if self.resumable: await self.resume() else: - await self._ratelimit.reset() + await self._ratelimit.reset(None) await self.identify() async def close(self): @@ -325,7 +355,7 @@ async def heartbeat(self, _: Msg): }) async def heartbeater(self) -> None: - if self.ws is None or self._interval is None: + if self._interval is None: return while not self.ws.closed: @@ -395,10 +425,8 @@ async def invalid_session(self, msg: Msg) -> None: if not self._listener or self._listener.done(): await self.connect(self._url) - async def reconnect(self, _: Payload = None) -> None: + async def reconnect(self, _: Payload | None = None) -> None: logger.debug("Shard %s Reconnecting...") - self._reconnect = True - self._resume = True await self.disconnect(message=b"Reconnect request") @@ -419,37 +447,43 @@ async def dispatch(self, msg: Dispatch) -> None: self.emitter.emit(Event(event, data, self.shard)) async def disconnect(self, *, code: int = 4000, message: bytes = b"") -> None: - if not self.ws: - return - await self.ws.close(code=code, message=message) self._tracker.reset() - self._beater and self._beater.cancel() + if self._beater: self._beater.cancel() self.emitter.emit(Event("disconnect", self.shard)) async def send(self, data: Msg) -> None: - if self.ws: - logger.debug("Shard %s Sending: %s", self.shard_id, data) - # todo: presence update - async with self._ratelimit.acquire(None): - await self.ws.send_str(util.dumps(data)) + logger.debug("Shard %s Sending: %s", self.shard_id, data) + # todo: presence update + async with self._ratelimit.acquire(None): + await self.ws.send_str(util.dumps(data)) + +method_map = { + i: getattr(GateWay, j.name.lower()) + for i, j in OpCodes._value2member_map_.items() # type: ignore[attr-defined] + if i in {0, 1, 7, 9, 10, 11} +} class Shard: gateway: GateWay emitter: SourcedPublisher shard_id: int + client: Client - def __init__(self, client: Client, shard_id: int = 0) -> None: + @classmethod + async def make_shard(cls, client: Client, shard_id: int = 0) -> Shard: + self = cls() self.shard_id = shard_id self.client = client self.emitter = SourcedPublisher() client.publisher.add(self.emitter) - self.gateway = GateWay(client, shard=self) + self.gateway = await GateWay.make_gateway(client, shard=self) + return self async def connect(self): if self.gateway.closed: - self.gateway = GateWay(self.client, shard=self) + self.gateway = await GateWay.make_gateway(self.client, shard=self) await self.gateway.connect() @@ -475,73 +509,93 @@ class BaseSharder(Protocol[S]): client: Client _url: URL | None shards: list[S] + shard_ids: set[int] | None + num_shards: int | None - def __init__(self, client: Client) -> None: + def __init__(self, client: Client, shard_ids: set[int] | None = None, num_shards: int | None = None) -> None: self.client = client self._url: URL | None = None - self.shards: list[S] = [] - @property - def num_shards(self): - return self.client.num_shards + if shard_ids is not None and num_shards is None: + raise ValueError("Must provide num_shards if shard ids are provided.") - @property - def shard_ids(self): - return self.client.shard_ids + if shard_ids is None and num_shards is not None: + shard_ids = set(range(num_shards)) - async def create_shards(self) -> None: - ... + self.shards = list[S]() + self.shard_ids = shard_ids + self.num_shards = num_shards async def launch_shards(self) -> None: ... class Sharder(BaseSharder[Shard]): - def __init__(self, client: Client) -> None: - self.client = client - self._url = None - self.shards = [] - - async def create_shards(self): - if not self.num_shards: - # don't cache the session limit - # coz delay till launch is unknown - data = await self.client.http.get_gateway_bot() - self.client.num_shards = data["shards"] - self.client.shard_ids = self.shard_ids or set(range(data["shards"])) - - shards = [] - - if self.shard_ids: - for id_ in self.shard_ids: - shards.append(Shard(self.client, id_)) - else: - raise ValueError("Cannot create shards for client without shard ids") + async def launch_shards(self): + data = await self.client.http.get_gateway_bot() # Get the Url once + self.num_shards = self.num_shards or data["shards"] + self.shard_ids = self.shard_ids or set(range(data["shards"])) - self.shards = shards + shards = self.shards = [] # in case we are being called again - async def launch_shards(self) -> None: - if not self.shards: - await self.create_shards() - url = self._url - loop = asyncio.get_event_loop() + self._url = URL(data["url"]) # Update the cached url - # session object can't be cached - data = await self.client.http.get_gateway_bot() limit = data["session_start_limit"] + async def runner(sid: int, lock: Lock): + async with lock: + shards.append(await Shard.make_shard(self.client, sid)) + max_conc: int = limit["max_concurrency"] buckets = [asyncio.Lock() for _ in range(max_conc)] - async def runner(sd: Shard, lock: Lock): - async with lock: - await sd.gateway.connect(url) + await asyncio.wait([asyncio.create_task(runner(sd, buckets[sd % max_conc])) for sd in self.shard_ids]) + + # async def create_shards(self): + # if not self.num_shards: + # # don't cache the session limit + # # coz delay till launch is unknown + # data = await self.client.http.get_gateway_bot() + # self.client.num_shards = data["shards"] + # self.client.shard_ids = self.shard_ids or set(range(data["shards"])) + + # shards = [] + + # if self.shard_ids: + # for id_ in self.shard_ids: + # shards.append(await Shard.make_shard(self.client, id_)) + # else: + # raise ValueError("Cannot create shards for client without shard ids") + + # self.shards = shards + + # async def launch_shards(self) -> None: + # if not self.shards: + # await self.create_shards() + # url = self._url + # loop = asyncio.get_event_loop() + + # # session object can't be cached + # data = await self.client.http.get_gateway_bot() + + # if not url: + # url = URL(data["url"]) + + # limit = data["session_start_limit"] + + # max_conc: int = limit["max_concurrency"] + + # buckets = [asyncio.Lock() for _ in range(max_conc)] + + # async def runner(sd: Shard, lock: Lock): + # async with lock: + # await Shard.make_shard - await asyncio.wait([loop.create_task(runner(sd, buckets[sd.shard_id % max_conc])) for sd in self.shards]) + # await asyncio.wait([loop.create_task(runner(sd, buckets[sd.shard_id % max_conc])) for sd in range(data["shards"])]) class SingleSharder(BaseSharder[Shard]): async def create_shards(self) -> None: - self.shards = [Shard(self.client, 0)] + self.shards = [await Shard.make_shard(self.client, 0)] self.client.num_shards = 1 self.client.shard_ids = {0,} # diff --git a/cordy/http/ratelimit.py b/cordy/http/ratelimit.py index 64f1e64..8277e4f 100644 --- a/cordy/http/ratelimit.py +++ b/cordy/http/ratelimit.py @@ -201,7 +201,7 @@ def __init__(self) -> None: self.grouped_buckets = {} self.grouper = Grouper() - def acquire(self, endp: Endpoint, timeout: float = None) -> BaseLimiter: + def acquire(self, endp: Endpoint, timeout: float | None = None) -> BaseLimiter: a_group = self.grouper.group_map.get(endp.route) if a_group: diff --git a/cordy/models/flags.py b/cordy/models/flags.py index c5d5634..6cd1a86 100644 --- a/cordy/models/flags.py +++ b/cordy/models/flags.py @@ -78,7 +78,7 @@ def from_int(cls, value: int): return cls class FrozenFlags(int): # Read-Only, for user flags - def __new__(cls, data: int = None): + def __new__(cls, data: int | None = None): if data is not None: return super().__new__(cls, data) diff --git a/cordy/runner.py b/cordy/runner.py index 183ff24..4437b10 100644 --- a/cordy/runner.py +++ b/cordy/runner.py @@ -77,7 +77,7 @@ def launch_all(clients: Iterable[Client]): # Proactor Transport require open event loop. DEFAULT_CLOSE = not sys.platform.startswith("win") -def run_loop(coro, *, close: bool = DEFAULT_CLOSE, debug: bool = None): +def run_loop(coro, *, close: bool = DEFAULT_CLOSE, debug: bool | None = None): try: loop = asyncio.get_running_loop() except RuntimeError: diff --git a/cordy/util.py b/cordy/util.py index ddfb47e..ecaf798 100644 --- a/cordy/util.py +++ b/cordy/util.py @@ -23,7 +23,7 @@ loads: Callable[[str], Any] = json.loads # Any allows custom type without cast dumps: Callable[[Json], str] = lambda dat: json.dumps(dat, separators=(',', ':')) -def make_proxy_for(org_cls, /, *, attr: str, proxied_attrs: Iterable[str] = None, proxied_methods: Iterable[str] = None): +def make_proxy_for(org_cls, /, *, attr: str, proxied_attrs: Iterable[str] | None = None, proxied_methods: Iterable[str] | None = None): def deco(cls): def make_encapsulators(name: str): nonlocal attr diff --git a/mypy.ini b/mypy.ini index 02df869..98782a9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,7 +5,7 @@ python_version = 3.9 warn_return_any = true warn_unreachable = true warn_redundant_casts = true -warn_unused_ignores = true +warn_unused_ignores = false warn_unused_configs = true check_untyped_defs = true