diff --git a/atproto/firehose/client.py b/atproto/firehose/client.py index 49c6aee3..9cb792da 100644 --- a/atproto/firehose/client.py +++ b/atproto/firehose/client.py @@ -28,9 +28,10 @@ _MAX_MESSAGE_SIZE_BYTES = 1024 * 1024 * 5 # 5MB OnMessageCallback = t.Callable[['MessageFrame'], t.Generator[t.Any, None, t.Any]] -AsyncOnMessageCallback = t.Callable[['MessageFrame'], t.Coroutine[t.Any, t.Any, t.Any]] +AsyncOnMessageCallback = t.Callable[['MessageFrame'], t.Coroutine[t.Any, t.Any, None]] OnCallbackErrorCallback = t.Callable[[BaseException], None] +AsyncOnCallbackErrorCallback = t.Callable[[BaseException], t.Coroutine[t.Any, t.Any, None]] def _build_websocket_uri( @@ -56,10 +57,6 @@ def _handle_frame_decoding_error(exception: Exception) -> None: raise exception -def _print_exception(exception: BaseException) -> None: - traceback.print_exception(type(exception), exception, exception.__traceback__) - - def _handle_websocket_error_or_stop(exception: Exception) -> bool: """Returns if the connection should be properly being closed or reraise exception.""" if isinstance(exception, (ConnectionClosedOK,)): @@ -72,6 +69,15 @@ def _handle_websocket_error_or_stop(exception: Exception) -> bool: raise FirehoseError from exception +def _get_message_frame_from_bytes_or_raise(data: bytes) -> MessageFrame: + frame = Frame.from_bytes(data) + if isinstance(frame, ErrorFrame): + raise FirehoseError(XrpcError(frame.body.error, frame.body.message)) + if isinstance(frame, MessageFrame): + return frame + raise FirehoseDecodingError('Unknown frame type') + + class _WebsocketClientBase: def __init__( self, @@ -86,9 +92,6 @@ def __init__( self._reconnect_no = 0 self._max_reconnect_delay_sec = 64 - self._on_message_callback: t.Optional[t.Union[OnMessageCallback, AsyncOnMessageCallback]] = None - self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None - def update_params(self, params: t.Union[ParamsModelBase, t.Dict[str, t.Any]]) -> None: """Update params. @@ -110,6 +113,9 @@ def _get_client(self): return connect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES) def _get_async_client(self): + # FIXME(DXsmiley): I've noticed that the close operation often takes the entire timeout for some reason + # By default, this is 10 seconds, which is pretty long. + # Maybe shorten it? return aconnect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES) def _get_reconnection_delay(self) -> int: @@ -118,43 +124,6 @@ def _get_reconnection_delay(self) -> int: return min(base_sec, self._max_reconnect_delay_sec) + rand_sec - def _process_raw_frame(self, data: bytes) -> None: - frame = Frame.from_bytes(data) - if isinstance(frame, ErrorFrame): - raise FirehoseError(XrpcError(frame.body.error, frame.body.message)) - if isinstance(frame, MessageFrame): - self._process_message_frame(frame) - else: - raise FirehoseDecodingError('Unknown frame type') - - def start( - self, - on_message_callback: t.Union[OnMessageCallback, AsyncOnMessageCallback], - on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None, - ) -> None: - """Subscribe to Firehose and start client. - - Args: - on_message_callback: Callback that will be called on the new Firehose message. - on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception. - - Returns: - :obj:`None` - """ - self._on_message_callback = on_message_callback - self._on_callback_error_callback = on_callback_error_callback - - def stop(self): - """Unsubscribe and stop the Firehose client. - - Returns: - :obj:`None` - """ - raise NotImplementedError - - def _process_message_frame(self, frame: 'MessageFrame') -> None: - raise NotImplementedError - class _WebsocketClient(_WebsocketClientBase): def __init__( @@ -162,11 +131,16 @@ def __init__( ) -> None: super().__init__(method, base_uri, params) + # TODO(DXsmiley): Not sure if this should be a Lock or not, the async is using an Event now self._stop_lock = threading.Lock() + self._on_message_callback: t.Optional[OnMessageCallback] = None + self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None + def _process_message_frame(self, frame: 'MessageFrame') -> None: try: - self._on_message_callback(frame) + if self._on_message_callback is not None: + self._on_message_callback(frame) except Exception as e: # noqa: BLE001 if self._on_callback_error_callback: try: @@ -176,8 +150,22 @@ def _process_message_frame(self, frame: 'MessageFrame') -> None: else: traceback.print_exc() - def start(self, *args, **kwargs): - super().start(*args, **kwargs) + def start( + self, + on_message_callback: OnMessageCallback, + on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None, + ) -> None: + """Subscribe to Firehose and start client. + + Args: + on_message_callback: Callback that will be called on the new Firehose message. + on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception. + + Returns: + :obj:`None` + """ + self._on_message_callback = on_message_callback + self._on_callback_error_callback = on_callback_error_callback while not self._stop_lock.locked(): try: @@ -194,7 +182,8 @@ def start(self, *args, **kwargs): continue try: - self._process_raw_frame(raw_frame) + frame = _get_message_frame_from_bytes_or_raise(raw_frame) + self._process_message_frame(frame) except Exception as e: # noqa: BLE001 _handle_frame_decoding_error(e) except Exception as e: # noqa: BLE001 @@ -207,7 +196,12 @@ def start(self, *args, **kwargs): if self._stop_lock.locked(): self._stop_lock.release() - def stop(self): + def stop(self) -> None: + """Unsubscribe and stop the Firehose client. + + Returns: + :obj:`None` + """ if not self._stop_lock.locked(): self._stop_lock.acquire() @@ -218,34 +212,42 @@ def __init__( ) -> None: super().__init__(method, base_uri, params) - self._loop = asyncio.get_event_loop() - self._on_message_tasks: t.Set[asyncio.Task] = set() - - self._stop_lock = asyncio.Lock() + self._stop_event = asyncio.Event() - def _on_message_callback_done(self, task: asyncio.Task) -> None: - self._on_message_tasks.discard(task) + self._on_message_callback: t.Optional[AsyncOnMessageCallback] = None + self._on_callback_error_callback: t.Optional[AsyncOnCallbackErrorCallback] = None - exception = task.exception() - if exception: - if not self._on_callback_error_callback: - _print_exception(exception) - return - - try: - self._on_callback_error_callback(exception) - except: # noqa + async def _process_message_frame(self, frame: 'MessageFrame') -> None: + try: + if self._on_message_callback is not None: + await self._on_message_callback(frame) + except Exception as e: # noqa: BLE001 + if self._on_callback_error_callback: + try: + await self._on_callback_error_callback(e) + except: # noqa + traceback.print_exc() + else: traceback.print_exc() - def _process_message_frame(self, frame: 'MessageFrame') -> None: - task: asyncio.Task = self._loop.create_task(self._on_message_callback(frame)) - self._on_message_tasks.add(task) - task.add_done_callback(self._on_message_callback_done) + async def start( + self, + on_message_callback: AsyncOnMessageCallback, + on_callback_error_callback: t.Optional[AsyncOnCallbackErrorCallback] = None, + ) -> None: + """Subscribe to Firehose and start client. - async def start(self, *args, **kwargs): - super().start(*args, **kwargs) + Args: + on_message_callback: Callback that will be called on the new Firehose message. + on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception. - while not self._stop_lock.locked(): + Returns: + :obj:`None` + """ + self._on_message_callback = on_message_callback + self._on_callback_error_callback = on_callback_error_callback + + while not self._stop_event.is_set(): try: if self._reconnect_no != 0: await asyncio.sleep(self._get_reconnection_delay()) @@ -253,16 +255,18 @@ async def start(self, *args, **kwargs): async with self._get_async_client() as client: self._reconnect_no = 0 - while not self._stop_lock.locked(): + while not self._stop_event.is_set(): raw_frame = await client.recv() if isinstance(raw_frame, str): # skip text frames (should not be occurred) continue try: - self._process_raw_frame(raw_frame) + frame = _get_message_frame_from_bytes_or_raise(raw_frame) + await self._process_message_frame(frame) except Exception as e: # noqa: BLE001 _handle_frame_decoding_error(e) + except Exception as e: # noqa: BLE001 self._reconnect_no += 1 @@ -270,12 +274,13 @@ async def start(self, *args, **kwargs): if should_stop: break - if self._stop_lock.locked(): - self._stop_lock.release() + async def stop(self) -> None: + """Unsubscribe and stop the Firehose client. - async def stop(self): - if not self._stop_lock.locked(): - await self._stop_lock.acquire() + Returns: + :obj:`None` + """ + self._stop_event.set() FirehoseClient = _WebsocketClient