Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async Firehose Client: block on make message handler call, add on error callback #157

Merged
merged 11 commits into from
Oct 27, 2023
163 changes: 84 additions & 79 deletions atproto/firehose/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
_MAX_MESSAGE_SIZE_BYTES = 1024 * 1024 * 5 # 5MB

OnMessageCallback = t.Callable[['MessageFrame'], t.Generator[t.Any, None, t.Any]]
AsyncOnMessageCallback = t.Callable[['MessageFrame'], t.Coroutine[t.Any, t.Any, t.Any]]
AsyncOnMessageCallback = t.Callable[['MessageFrame'], t.Coroutine[t.Any, t.Any, None]]

OnCallbackErrorCallback = t.Callable[[BaseException], None]
AsyncOnCallbackErrorCallback = t.Callable[[BaseException], t.Coroutine[t.Any, t.Any, None]]


def _build_websocket_uri(
Expand All @@ -56,10 +57,6 @@ def _handle_frame_decoding_error(exception: Exception) -> None:
raise exception


def _print_exception(exception: BaseException) -> None:
traceback.print_exception(type(exception), exception, exception.__traceback__)


def _handle_websocket_error_or_stop(exception: Exception) -> bool:
"""Returns if the connection should be properly being closed or reraise exception."""
if isinstance(exception, (ConnectionClosedOK,)):
Expand All @@ -72,6 +69,15 @@ def _handle_websocket_error_or_stop(exception: Exception) -> bool:
raise FirehoseError from exception


def _get_message_frame_from_bytes_or_raise(data: bytes) -> MessageFrame:
frame = Frame.from_bytes(data)
if isinstance(frame, ErrorFrame):
raise FirehoseError(XrpcError(frame.body.error, frame.body.message))
if isinstance(frame, MessageFrame):
return frame
raise FirehoseDecodingError('Unknown frame type')


class _WebsocketClientBase:
def __init__(
self,
Expand All @@ -86,9 +92,6 @@ def __init__(
self._reconnect_no = 0
self._max_reconnect_delay_sec = 64

self._on_message_callback: t.Optional[t.Union[OnMessageCallback, AsyncOnMessageCallback]] = None
self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None

def update_params(self, params: t.Union[ParamsModelBase, t.Dict[str, t.Any]]) -> None:
"""Update params.

Expand All @@ -110,6 +113,9 @@ def _get_client(self):
return connect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES)

def _get_async_client(self):
# FIXME(DXsmiley): I've noticed that the close operation often takes the entire timeout for some reason
# By default, this is 10 seconds, which is pretty long.
# Maybe shorten it?
return aconnect(self._websocket_uri, max_size=_MAX_MESSAGE_SIZE_BYTES)

def _get_reconnection_delay(self) -> int:
Expand All @@ -118,55 +124,23 @@ def _get_reconnection_delay(self) -> int:

return min(base_sec, self._max_reconnect_delay_sec) + rand_sec

def _process_raw_frame(self, data: bytes) -> None:
frame = Frame.from_bytes(data)
if isinstance(frame, ErrorFrame):
raise FirehoseError(XrpcError(frame.body.error, frame.body.message))
if isinstance(frame, MessageFrame):
self._process_message_frame(frame)
else:
raise FirehoseDecodingError('Unknown frame type')

def start(
self,
on_message_callback: t.Union[OnMessageCallback, AsyncOnMessageCallback],
on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None,
) -> None:
"""Subscribe to Firehose and start client.

Args:
on_message_callback: Callback that will be called on the new Firehose message.
on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception.

Returns:
:obj:`None`
"""
self._on_message_callback = on_message_callback
self._on_callback_error_callback = on_callback_error_callback

def stop(self):
"""Unsubscribe and stop the Firehose client.

Returns:
:obj:`None`
"""
raise NotImplementedError

def _process_message_frame(self, frame: 'MessageFrame') -> None:
raise NotImplementedError


class _WebsocketClient(_WebsocketClientBase):
def __init__(
self, method: str, base_uri: t.Optional[str] = None, params: t.Optional[t.Dict[str, t.Any]] = None
) -> None:
super().__init__(method, base_uri, params)

# TODO(DXsmiley): Not sure if this should be a Lock or not, the async is using an Event now
self._stop_lock = threading.Lock()
DXsmiley marked this conversation as resolved.
Show resolved Hide resolved

self._on_message_callback: t.Optional[OnMessageCallback] = None
self._on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None

def _process_message_frame(self, frame: 'MessageFrame') -> None:
try:
self._on_message_callback(frame)
if self._on_message_callback is not None:
self._on_message_callback(frame)
except Exception as e: # noqa: BLE001
if self._on_callback_error_callback:
try:
Expand All @@ -176,8 +150,22 @@ def _process_message_frame(self, frame: 'MessageFrame') -> None:
else:
traceback.print_exc()

def start(self, *args, **kwargs):
super().start(*args, **kwargs)
def start(
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
self,
on_message_callback: OnMessageCallback,
on_callback_error_callback: t.Optional[OnCallbackErrorCallback] = None,
) -> None:
"""Subscribe to Firehose and start client.

Args:
on_message_callback: Callback that will be called on the new Firehose message.
on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception.

Returns:
:obj:`None`
"""
self._on_message_callback = on_message_callback
self._on_callback_error_callback = on_callback_error_callback

while not self._stop_lock.locked():
try:
Expand All @@ -194,7 +182,8 @@ def start(self, *args, **kwargs):
continue

try:
self._process_raw_frame(raw_frame)
frame = _get_message_frame_from_bytes_or_raise(raw_frame)
self._process_message_frame(frame)
except Exception as e: # noqa: BLE001
_handle_frame_decoding_error(e)
except Exception as e: # noqa: BLE001
Expand All @@ -207,7 +196,12 @@ def start(self, *args, **kwargs):
if self._stop_lock.locked():
self._stop_lock.release()

def stop(self):
def stop(self) -> None:
"""Unsubscribe and stop the Firehose client.

Returns:
:obj:`None`
"""
if not self._stop_lock.locked():
self._stop_lock.acquire()

Expand All @@ -218,64 +212,75 @@ def __init__(
) -> None:
super().__init__(method, base_uri, params)

self._loop = asyncio.get_event_loop()
self._on_message_tasks: t.Set[asyncio.Task] = set()

self._stop_lock = asyncio.Lock()
self._stop_event = asyncio.Event()

def _on_message_callback_done(self, task: asyncio.Task) -> None:
self._on_message_tasks.discard(task)
self._on_message_callback: t.Optional[AsyncOnMessageCallback] = None
self._on_callback_error_callback: t.Optional[AsyncOnCallbackErrorCallback] = None

exception = task.exception()
if exception:
if not self._on_callback_error_callback:
_print_exception(exception)
return

try:
self._on_callback_error_callback(exception)
except: # noqa
async def _process_message_frame(self, frame: 'MessageFrame') -> None:
try:
if self._on_message_callback is not None:
await self._on_message_callback(frame)
except Exception as e: # noqa: BLE001
if self._on_callback_error_callback:
try:
await self._on_callback_error_callback(e)
except: # noqa
traceback.print_exc()
else:
traceback.print_exc()

def _process_message_frame(self, frame: 'MessageFrame') -> None:
task: asyncio.Task = self._loop.create_task(self._on_message_callback(frame))
self._on_message_tasks.add(task)
task.add_done_callback(self._on_message_callback_done)
async def start(
self,
on_message_callback: AsyncOnMessageCallback,
on_callback_error_callback: t.Optional[AsyncOnCallbackErrorCallback] = None,
) -> None:
"""Subscribe to Firehose and start client.

async def start(self, *args, **kwargs):
super().start(*args, **kwargs)
Args:
on_message_callback: Callback that will be called on the new Firehose message.
on_callback_error_callback: Callback that will be called if the `on_message_callback` raised an exception.

while not self._stop_lock.locked():
Returns:
:obj:`None`
"""
self._on_message_callback = on_message_callback
self._on_callback_error_callback = on_callback_error_callback

while not self._stop_event.is_set():
try:
if self._reconnect_no != 0:
await asyncio.sleep(self._get_reconnection_delay())

async with self._get_async_client() as client:
self._reconnect_no = 0

while not self._stop_lock.locked():
while not self._stop_event.is_set():
raw_frame = await client.recv()
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(raw_frame, str):
# skip text frames (should not be occurred)
continue

try:
self._process_raw_frame(raw_frame)
frame = _get_message_frame_from_bytes_or_raise(raw_frame)
await self._process_message_frame(frame)
except Exception as e: # noqa: BLE001
_handle_frame_decoding_error(e)

except Exception as e: # noqa: BLE001
self._reconnect_no += 1

should_stop = _handle_websocket_error_or_stop(e)
if should_stop:
break

if self._stop_lock.locked():
self._stop_lock.release()
async def stop(self) -> None:
"""Unsubscribe and stop the Firehose client.

async def stop(self):
if not self._stop_lock.locked():
await self._stop_lock.acquire()
Returns:
:obj:`None`
"""
self._stop_event.set()


FirehoseClient = _WebsocketClient
Expand Down