Skip to content

Commit

Permalink
Add automatic reconnection to the new asyncio implementation.
Browse files Browse the repository at this point in the history
Missing tests for now.

Fix #1480.
  • Loading branch information
aaugustin committed Aug 28, 2024
1 parent 6ffb6b0 commit 5240361
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 23 deletions.
33 changes: 20 additions & 13 deletions docs/howto/upgrade.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,6 @@ Following redirects
The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP
redirects yet.

Automatic reconnection
......................

The new implementation of :func:`~asyncio.client.connect` doesn't provide
automatic reconnection yet.

In other words, the following pattern isn't supported::

from websockets.asyncio.client import connect

async for websocket in connect(...): # this doesn't work yet
...

.. _Update import paths:

Import paths
Expand Down Expand Up @@ -185,6 +172,26 @@ it simpler.
``process_response`` replaces ``extra_headers`` and provides more flexibility.
See process_request_, select_subprotocol_, and process_response_ below.

Customizing automatic reconnection
..................................

On the client side, if you're reconnecting automatically with ``async for ... in
connect(...)``, the behavior when a connection attempt fails was enhanced and
made configurable.

The original implementation retried on any error. The new implementation uses an
heuristic to determine whether an error is retryable or fatal. By default, only
network errors and servers errors are considered retryable. You can customize
this behavior with the ``process_exception`` argument of
:func:`~asyncio.client.connect`.

See :func:`~asyncio.client.process_exception` for more information.

Here's how to revert to the behavior of the original implementation::

async for ... in connect(..., process_exception=lambda exc: exc):
...

Tracking open connections
.........................

Expand Down
8 changes: 6 additions & 2 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ Backwards-incompatible changes
New features
............

* Made the set of active connections available in the :attr:`Server.connections
<asyncio.server.Server.connections>` property.
* Added support for reconnecting automatically by using
:func:`~asyncio.client.connect` as an asynchronous iterator to the new
:mod:`asyncio` implementation.

* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
implementations of servers.

* Made the set of active connections available in the :attr:`Server.connections
<asyncio.server.Server.connections>` property.

.. _13.0:

13.0
Expand Down
2 changes: 2 additions & 0 deletions docs/reference/asyncio/client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Opening a connection
.. autofunction:: unix_connect
:async:

.. autofunction:: process_exception

Using a connection
------------------

Expand Down
123 changes: 117 additions & 6 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import asyncio
import functools
import logging
from types import TracebackType
from typing import Any, Generator, Sequence
from typing import Any, AsyncIterator, Callable, Generator, Sequence

from ..client import ClientProtocol
from ..client import ClientProtocol, backoff
from ..datastructures import HeadersLike
from ..exceptions import InvalidStatus
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
Expand Down Expand Up @@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None:
self.response_rcvd.set_result(None)


def process_exception(exc: Exception) -> Exception | None:
"""
Determine whether an error is retryable or fatal.
When reconnecting automatically with ``async for ... in connect(...)``, if a
connection attempt fails, :func:`process_exception` is called to determine
whether to retry connecting or to raise the exception.
This function defines the default behavior, which is to retry on:
* :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors;
* :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
502, 503, or 504: server or proxy errors.
All other exceptions are considered fatal.
You can change this behavior with the ``process_exception`` argument of
:func:`connect`.
Return :obj:`None` if the exception is retryable i.e. when the error could
be transient and trying to reconnect with the same parameters could succeed.
The exception will be logged at the ``INFO`` level.
Return an exception, either ``exc`` or a new exception, if the exception is
fatal i.e. when trying to reconnect will most likely produce the same error.
That exception will be raised, breaking out of the retry loop.
"""
if isinstance(exc, (OSError, asyncio.TimeoutError)):
return None
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
500, # Internal Server Error
502, # Bad Gateway
503, # Service Unavailable
504, # Gateway Timeout
]:
return None
return exc


# This is spelled in lower case because it's exposed as a callable in the API.
class connect:
"""
Expand All @@ -138,6 +181,21 @@ class connect:
The connection is closed automatically when exiting the context.
:func:`connect` can be used as an infinite asynchronous iterator to
reconnect automatically on errors::
async for websocket in connect(...):
try:
...
except websockets.ConnectionClosed:
continue
If the connection fails with a transient error, it is retried with
exponential backoff. If it fails with a fatal error, the exception is
raised, breaking out of the loop.
The connection is closed automatically after each iteration of the loop.
Args:
uri: URI of the WebSocket server.
origin: Value of the ``Origin`` header, for servers that require it.
Expand All @@ -153,6 +211,9 @@ class connect:
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
process_exception: When reconnecting automatically, tell whether an
error is transient or fatal. The default behavior is defined by
:func:`process_exception`. Refer to its documentation for details.
open_timeout: Timeout for opening the connection in seconds.
:obj:`None` disables the timeout.
ping_interval: Interval between keepalive pings in seconds.
Expand Down Expand Up @@ -219,6 +280,7 @@ def __init__(
additional_headers: HeadersLike | None = None,
user_agent_header: str | None = USER_AGENT,
compression: str | None = "deflate",
process_exception: Callable[[Exception], Exception | None] = process_exception,
# Timeouts
open_timeout: float | None = 10,
ping_interval: float | None = 20,
Expand Down Expand Up @@ -281,19 +343,26 @@ def factory() -> ClientConnection:

loop = asyncio.get_running_loop()
if kwargs.pop("unix", False):
self.create_connection = loop.create_unix_connection(factory, **kwargs)
self.create_connection = functools.partial(
loop.create_unix_connection, factory, **kwargs
)
else:
if kwargs.get("sock") is None:
kwargs.setdefault("host", wsuri.host)
kwargs.setdefault("port", wsuri.port)
self.create_connection = loop.create_connection(factory, **kwargs)
self.create_connection = functools.partial(
loop.create_connection, factory, **kwargs
)

self.handshake_args = (
additional_headers,
user_agent_header,
)

self.process_exception = process_exception
self.open_timeout = open_timeout
if logger is None:
logger = logging.getLogger("websockets.client")
self.logger = logger

# ... = await connect(...)

Expand All @@ -304,7 +373,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
async def __await_impl__(self) -> ClientConnection:
try:
async with asyncio_timeout(self.open_timeout):
_transport, self.connection = await self.create_connection
_transport, self.connection = await self.create_connection()
try:
await self.connection.handshake(*self.handshake_args)
except (Exception, asyncio.CancelledError):
Expand Down Expand Up @@ -333,6 +402,48 @@ async def __aexit__(
) -> None:
await self.connection.close()

# async for ... in connect(...):

async def __aiter__(self) -> AsyncIterator[ClientConnection]:
delays: Generator[float, None, None] | None = None
while True:
try:
async with self as protocol:
yield protocol
except Exception as exc:
# Determine whether the exception is retryable or fatal.
# The API of process_exception is "return an exception or None";
# "raise an exception" is also supported because it's a frequent
# mistake. It isn't documented in order to keep the API simple.
try:
new_exc = self.process_exception(exc)
except Exception as raised_exc:
new_exc = raised_exc

# The connection failed with a fatal error.
# Raise the exception and exit the loop.
if new_exc is exc:
raise
if new_exc is not None:
raise new_exc from exc

# The connection failed with a retryable error.
# Start or continue backoff and reconnect.
if delays is None:
delays = backoff()
delay = next(delays)
self.logger.info(
"! connect failed; reconnecting in %.1f seconds",
delay,
exc_info=True,
)
await asyncio.sleep(delay)
continue

else:
# The connection succeeded. Reset backoff.
delays = None


def unix_connect(
path: str | None = None,
Expand Down
Loading

0 comments on commit 5240361

Please sign in to comment.