diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fd186a5fc..108b7c9c0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,7 +52,7 @@ Backwards-incompatible changes async def handler(request, path): ... - You should switch to the recommended pattern since 10.1:: + You should switch to the pattern recommended since version 10.1:: async def handler(request): path = request.path # only if handler() uses the path argument @@ -61,6 +61,16 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 11.0 introduces a new :mod:`asyncio` implementation. + :class: important + + This new implementation is intended to be a drop-in replacement for the + current implementation. It will become the default in a future release. + Please try it and report any issue that you encounter! + + See :func:`websockets.asyncio.client.connect` and + :func:`websockets.asyncio.server.serve` for details. + * Validated compatibility with Python 3.12. 12.0 @@ -175,7 +185,8 @@ New features It is particularly suited to client applications that establish only one connection. It may be used for servers handling few connections. - See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. + See :func:`websockets.sync.client.connect` and + :func:`websockets.sync.server.serve` for details. * Added ``open_timeout`` to :func:`~server.serve`. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 0b80f087a..2486ac564 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -26,6 +26,18 @@ clients concurrently. asyncio/server asyncio/client +:mod:`asyncio` (new) +-------------------- + +This is a rewrite of the :mod:`asyncio` implementation. It will become the +default in the future. + +.. toctree:: + :titlesonly: + + new-asyncio/server + new-asyncio/client + :mod:`threading` ---------------- diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst new file mode 100644 index 000000000..552d83b2f --- /dev/null +++ b/docs/reference/new-asyncio/client.rst @@ -0,0 +1,53 @@ +Client (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: unix_connect + :async: + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst new file mode 100644 index 000000000..ba23552dc --- /dev/null +++ b/docs/reference/new-asyncio/common.rst @@ -0,0 +1,43 @@ +:orphan: + +Both sides (:mod:`asyncio` - new) +================================= + +.. automodule:: websockets.asyncio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst new file mode 100644 index 000000000..f3446fb80 --- /dev/null +++ b/docs/reference/new-asyncio/server.rst @@ -0,0 +1,72 @@ +Server (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.server + +Creating a server +----------------- + +.. autofunction:: serve + :async: + +.. autofunction:: unix_serve + :async: + +Running a server +---------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/pyproject.toml b/pyproject.toml index 2367849ca..de8acd6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", - "*/websockets/legacy/async_timeout.py", - "*/websockets/legacy/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/asyncio/__init__.py b/src/websockets/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/asyncio/async_timeout.py similarity index 100% rename from src/websockets/legacy/async_timeout.py rename to src/websockets/asyncio/async_timeout.py diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py new file mode 100644 index 000000000..040d68ece --- /dev/null +++ b/src/websockets/asyncio/client.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import asyncio +from types import TracebackType +from typing import Any, Generator, Sequence + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .compatibility import TimeoutError, asyncio_timeout +from .connection import Connection + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ClientProtocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.response_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + # May raise CancelledError if open_timeout is exceeded. + await self.response_rcvd + + if self.response is None: + raise ConnectionError("connection closed during handshake") + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.response_rcvd.done(): + self.response_rcvd.set_result(None) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + async with websockets.asyncio.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. + When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS + context is created with :func:`~ssl.create_default_context`. + + * You can set ``server_hostname`` to override the host name from ``uri`` in + the TLS handshake. + + * You can set ``host`` and ``port`` to connect to a different host and port + from those found in ``uri``. This only changes the destination of the TCP + connection. The host name from ``uri`` is still used in the TLS handshake + for secure connections and in the ``Host`` header. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_connection` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_connection` method) to create a suitable + client socket and customize it. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to loop.create_connection + **kwargs: Any, + ) -> None: + + wsuri = parse_uri(uri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + + def factory() -> ClientConnection: + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_connection = loop.create_unix_connection(factory, **kwargs) + else: + if kwargs.get("sock") is None: + kwargs.setdefault("host", wsuri.host) + kwargs.setdefault("port", wsuri.port) + self._create_connection = loop.create_connection(factory, **kwargs) + + self._handshake_args = ( + additional_headers, + user_agent_header, + ) + + self._open_timeout = open_timeout + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, ClientConnection]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> ClientConnection: + try: + async with asyncio_timeout(self._open_timeout): + _transport, self.connection = await self._create_connection + try: + await self.connection.handshake(*self._handshake_args) + except (Exception, asyncio.CancelledError): + self.connection.transport.close() + raise + else: + return self.connection + except TimeoutError: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during handshake") from None + + # ... = yield from connect(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> connect: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` argument is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py new file mode 100644 index 000000000..e17000069 --- /dev/null +++ b/src/websockets/asyncio/compatibility.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import sys + + +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"] + + +if sys.version_info[:2] >= (3, 11): + TimeoutError = TimeoutError + aiter = aiter + anext = anext + from asyncio import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) + +else: # Python < 3.11 + from asyncio import TimeoutError + + def aiter(async_iterable): + return type(async_iterable).__aiter__(async_iterable) + + async def anext(async_iterator): + return await type(async_iterator).__anext__(async_iterator) + + from .async_timeout import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py new file mode 100644 index 000000000..152c6789e --- /dev/null +++ b/src/websockets/asyncio/connection.py @@ -0,0 +1,909 @@ +from __future__ import annotations + +import asyncio +import collections +import contextlib +import logging +import random +import struct +import uuid +from types import TracebackType +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Iterable, + Mapping, + cast, +) + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(asyncio.Protocol): + """ + :mod:`asyncio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.asyncio.client.ClientConnection` or + :class:`~websockets.asyncio.server.ServerConnection`. + + """ + + def __init__( + self, + protocol: Protocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Event loop running this connection. + self.loop = asyncio.get_running_loop() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages: Assembler # initialized in connection_made + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.fragmented_send_waiter: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + + # Adapted from asyncio.FlowControlMixin + self.paused: bool = False + self.drain_waiters: collections.deque[asyncio.Future[None]] = ( + collections.deque() + ) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.transport.get_extra_info("peername") + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return a bytestring (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get(decode) + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise :exc:`RuntimeError`, making the + connection unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(decode): + yield frame + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If the connection busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.fragmented_send_waiter is not None: + await asyncio.shield(self.fragmented_send_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + else: + raise TypeError("data must be bytes, str, iterable, or async iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.fragmented_send_waiter is not None: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Data | None = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pong_waiters: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pong_waiters: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # The event loop's default clock is time.monotonic(). Its resolution + # is a bit low on Windows (~16ms). We cannot use time.perf_counter() + # because it doesn't count time elapsed while the process sleeps. + ping_timestamp = self.loop.time() + self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.protocol.send_ping(data) + return pong_waiter + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pong_waiters: + return + + pong_timestamp = self.loop.time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + ping_ids.append(ping_id) + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pong_waiters. + for ping_id in ping_ids: + del self.pong_waiters[ping_id] + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.protocol.state is CLOSED + exc = self.protocol.close_exc + + for pong_waiter, _ping_timestamp in self.pong_waiters.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + self.pong_waiters.clear() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = self.loop.time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + self.send_data() + await self.drain() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + try: + async with asyncio_timeout_at(self.close_deadline): + await asyncio.shield(self.connection_lost_waiter) + except TimeoutError: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_transport() + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + Raises: + OSError: When a socket operations fails. + + """ + for data in self.protocol.data_to_send(): + if data: + self.transport.write(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # OSError is plausible. uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + def close_transport(self) -> None: + """ + Close transport and message assembler. + + """ + self.transport.close() + self.recv_messages.close() + + # asyncio.Protocol methods + + # Connection callbacks + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.recv_messages = Assembler( + pause=self.transport.pause_reading, + resume=self.transport.resume_reading, + ) + + def connection_lost(self, exc: Exception | None) -> None: + self.protocol.receive_eof() # receive_eof is idempotent + self.recv_messages.close() + self.set_recv_exc(exc) + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + self.abort_pings() + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: # pragma: no cover + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + # Flow control callbacks + + def pause_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert not self.paused + self.paused = True + + def resume_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert self.paused + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + waiter.set_result(None) + + async def drain(self) -> None: # pragma: no cover + # We don't check if the connection is closed because we call drain() + # immediately after write() and write() would fail in that case. + + # Adapted from asyncio.streams.StreamWriter + # Yield to the event loop so that connection_lost() may be called. + if self.transport.is_closing(): + await asyncio.sleep(0) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: + waiter = self.loop.create_future() + self.drain_waiters.append(waiter) + try: + await waiter + finally: + self.drain_waiters.remove(waiter) + + # Streaming protocol callbacks + + def data_received(self, data: bytes) -> None: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the transport. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + self.set_recv_exc(exc) + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + def eof_received(self) -> None: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it shouldn't raise errors. + self.send_data() + + # The WebSocket protocol has its own closing handshake: endpoints close + # the TCP or TLS connection after sending and receiving a close frame. + # As a consequence, they never need to write after receiving EOF, so + # there's no reason to keep the transport open by returning True. + # Besides, that doesn't work on TLS connections. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py new file mode 100644 index 000000000..bc33df8d7 --- /dev/null +++ b/src/websockets/asyncio/messages.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +from typing import ( + Any, + AsyncIterator, + Callable, + Generic, + Iterable, + TypeVar, +) + +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class SimpleQueue(Generic[T]): + """ + Simplified version of :class:`asyncio.Queue`. + + Provides only the subset of functionality needed by :class:`Assembler`. + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_running_loop() + self.get_waiter: asyncio.Future[None] | None = None + self.queue: collections.deque[T] = collections.deque() + + def __len__(self) -> int: + return len(self.queue) + + def put(self, item: T) -> None: + """Put an item into the queue without waiting.""" + self.queue.append(item) + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_result(None) + + async def get(self) -> T: + """Remove and return an item from the queue, waiting if necessary.""" + if not self.queue: + if self.get_waiter is not None: + raise RuntimeError("get is already running") + self.get_waiter = self.loop.create_future() + try: + await self.get_waiter + finally: + self.get_waiter.cancel() + self.get_waiter = None + return self.queue.popleft() + + def reset(self, items: Iterable[T]) -> None: + """Put back items into an empty, idle queue.""" + assert self.get_waiter is None, "cannot reset() while get() is running" + assert not self.queue, "cannot reset() while queue isn't empty" + self.queue.extend(items) + + def abort(self) -> None: + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_exception(EOFError("stream of frames ended")) + # Clear the queue to avoid storing unnecessary data in memory. + self.queue.clear() + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + # coverage reports incorrectly: "line NN didn't jump to the function exit" + def __init__( # pragma: no cover + self, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming messages. Each item is a queue of frames. + self.frames: SimpleQueue[Frame] = SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + self.high = 16 + self.low = 4 + self.paused = False + self.pause = pause + self.resume = resume + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + # We cannot handle asyncio.CancelledError because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise RuntimeError. + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + def get_limits(self) -> tuple[int, int]: + """Return low and high water marks for flow control.""" + return self.low, self.high + + def set_limits(self, low: int = 4, high: int = 16) -> None: + """Configure low and high water marks for flow control.""" + self.low, self.high = low, high + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Check for "> high" to support high = 0 + if len(self.frames) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Check for "<= low" to support low = 0 + if len(self.frames) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.frames.abort() diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py new file mode 100644 index 000000000..aa175f775 --- /dev/null +++ b/src/websockets/asyncio/server.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import asyncio +import http +import logging +import socket +import sys +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + Iterable, + Sequence, +) + +from websockets.frames import CloseCode + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import asyncio_timeout +from .connection import Connection + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + server: Server that manages this connection. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ServerProtocol, + server: WebSocketServer, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.server = server + self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + # May raise CancelledError if open_timeout is exceeded. + await self.request_rcvd + + if self.request is None: + raise ConnectionError("connection closed during handshake") + + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + self.server.start_connection_handler(self) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.request_rcvd.done(): + self.request_rcvd.set_result(None) + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + + def wrap(self, server: asyncio.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller that can handle exceptions, + it attempts to log relevant ones. + + It guarantees that the TCP connection is closed before exiting. + + """ + try: + # On failure, handshake() closes the transport, raises an + # exception, and logs it. + async with asyncio_timeout(self.open_timeout): + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + + try: + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(CloseCode.INTERNAL_ERROR) + else: + await connection.close() + + except Exception: + # Don't leak connections on errors. + connection.transport.abort() + + finally: + # Registration is tied to the lifecycle of conn_handler() because + # the server waits for connection handlers to terminate, even if + # all connections are already closed. + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + """ + Register a connection with this server. + + """ + # The connection must be registered in self.handlers immediately. + # If it was registered in conn_handler(), a race condition could + # happen when closing the server after scheduling conn_handler() + # but before it starts executing. + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(connection.close(1001)) + for connection in self.handlers + if connection.protocol.state is not CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.handlers: + await asyncio.wait(self.handlers.values()) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~ServerConnection.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~ServerConnection.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: # pragma: no cover + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +# This is spelled in lower case because it's exposed as a callable in the API. +class serve: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This coroutine returns a :class:`WebSocketServer` whose API mirrors + :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + ensure that the server will be closed:: + + def handler(websocket): + ... + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with websockets.asyncio.server.serve(handler, host, port): + await stop + + Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests + and cancel it to stop the server:: + + server = await websockets.asyncio.server.serve(handler, host, port) + await server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_server` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_server` method) to create a suitable server + socket and customize it. + + * You can set ``start_serving`` to ``False`` to start accepting connections + only after you call :meth:`~WebSocketServer.start_serving()` or + :meth:`~WebSocketServer.serve_forever()`. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + # WebSocket + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Other keyword arguments are passed to loop.create_server + **kwargs: Any, + ) -> None: + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + self.server = WebSocketServer( + handler, + process_request=process_request, + process_response=process_response, + server_header=server_header, + open_timeout=open_timeout, + logger=logger, + ) + + if kwargs.get("ssl") is not None: + kwargs.setdefault("ssl_handshake_timeout", open_timeout) + if sys.version_info[:2] >= (3, 11): # pragma: no branch + kwargs.setdefault("ssl_shutdown_timeout", close_timeout) + + def factory() -> ServerConnection: + """ + Create an asyncio protocol for managing a WebSocket connection. + + """ + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + self.server, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_server = loop.create_unix_server(factory, **kwargs) + else: + # mypy cannot tell that kwargs must provide sock when port is None. + self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + + # async with serve(...) as ...: ... + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.server.close() + await self.server.wait_closed() + + # ... = await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server + self.server.wrap(server) + return self.server + + # ... = yield from serve(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_serve( + handler: Callable[[ServerConnection], Awaitable[None]], + path: str | None = None, + **kwargs: Any, +) -> Awaitable[WebSocketServer]: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b0e15b543..d1d8d5608 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -16,6 +16,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( InvalidHandshake, @@ -40,7 +41,6 @@ from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .compatibility import asyncio_timeout from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py deleted file mode 100644 index 6bd01e70d..000000000 --- a/src/websockets/legacy/compatibility.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -import sys - - -__all__ = ["asyncio_timeout"] - - -if sys.version_info[:2] >= (3, 11): - from asyncio import timeout as asyncio_timeout # noqa: F401 -else: - from .async_timeout import timeout as asyncio_timeout # noqa: F401 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c28bdcf48..120ff8e73 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,6 +23,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -49,7 +50,6 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import asyncio_timeout from .framing import Frame diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index f4442fecc..208ffa780 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -21,6 +21,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -42,7 +43,6 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 33d8299e2..2bcb3aa0e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -373,6 +373,7 @@ def send(self, message: Data | Iterable[Data]) -> None: except RuntimeError: # We didn't start sending a fragmented message. + # The connection is still usable. raise except Exception: @@ -756,7 +757,7 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ assert self.protocol_mutex.locked() - if self.recv_exc is None: + if self.recv_exc is None: # pragma: no branch self.recv_exc = exc def close_socket(self) -> None: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 6cbff2595..ff90345ac 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -85,7 +85,7 @@ def get(self, timeout: float | None = None) -> Data: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") self.get_in_progress = True @@ -144,7 +144,7 @@ def get_iter(self) -> Iterator[Data]: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") chunks = self.chunks self.chunks = [] diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py new file mode 100644 index 000000000..e5826add7 --- /dev/null +++ b/tests/asyncio/client.py @@ -0,0 +1,33 @@ +import contextlib + +from websockets.asyncio.client import * +from websockets.asyncio.server import WebSocketServer + +from .server import get_server_host_port + + +__all__ = [ + "run_client", + "run_unix_client", +] + + +@contextlib.asynccontextmanager +async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl" in kwargs + protocol = "wss" if secure else "ws" + host, port = get_server_host_port(wsuri_or_server) + wsuri = f"{protocol}://{host}:{port}{resource_name}" + async with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.asynccontextmanager +async def run_unix_client(path, **kwargs): + async with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py new file mode 100644 index 000000000..ad1c121bf --- /dev/null +++ b/tests/asyncio/connection.py @@ -0,0 +1,115 @@ +import asyncio +import contextlib + +from websockets.asyncio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def connection_made(self, transport): + super().connection_made(InterceptingTransport(transport)) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write is None + self.transport.delay_write = delay + try: + yield + finally: + self.transport.delay_write = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write_eof is None + self.transport.delay_write_eof = delay + try: + yield + finally: + self.transport.delay_write_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write + self.transport.drop_write = True + try: + yield + finally: + self.transport.drop_write = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write_eof + self.transport.drop_write_eof = True + try: + yield + finally: + self.transport.drop_write_eof = False + + +class InterceptingTransport: + """ + Transport wrapper that intercepts calls to ``write()`` and ``write_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + Since ``write()`` and ``write_eof()`` are not coroutines, this effect is + achieved by scheduling writes at a later time, after the methods return. + This can easily result in out-of-order writes, which is unrealistic. + + """ + + def __init__(self, transport): + self.loop = asyncio.get_running_loop() + self.transport = transport + self.delay_write = None + self.delay_write_eof = None + self.drop_write = False + self.drop_write_eof = False + + def __getattr__(self, name): + return getattr(self.transport, name) + + def write(self, data): + if not self.drop_write: + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + else: + self.transport.write(data) + + def write_eof(self): + if not self.drop_write_eof: + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + else: + self.transport.write_eof() diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py new file mode 100644 index 000000000..0fe20dc65 --- /dev/null +++ b/tests/asyncio/server.py @@ -0,0 +1,50 @@ +import asyncio +import contextlib +import socket + +from websockets.asyncio.server import * + + +def get_server_host_port(server): + for sock in server.sockets: + if sock.family == socket.AF_INET: # pragma: no branch + return sock.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +async def eval_shell(ws): + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) + + +async def crash(ws): + raise RuntimeError + + +async def do_nothing(ws): + pass + + +async def keep_running(ws): + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + + +@contextlib.asynccontextmanager +async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + async with serve(handler, host, port, **kwargs) as server: + yield server + + +@contextlib.asynccontextmanager +async def run_unix_server(path, handler=eval_shell, **kwargs): + async with unix_serve(handler, path, **kwargs) as server: + yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py new file mode 100644 index 000000000..aab65cd2e --- /dev/null +++ b/tests/asyncio/test_client.py @@ -0,0 +1,306 @@ +import asyncio +import socket +import ssl +import unittest + +from websockets.asyncio.client import * +from websockets.asyncio.compatibility import TimeoutError +from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.extensions.permessage_deflate import PerMessageDeflate + +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from .client import run_client, run_unix_client +from .server import do_nothing, get_server_host_port, run_server, run_unix_server + + +class ClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with run_server() as server: + with socket.create_connection(get_server_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + async with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with run_client( + server, create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server( + do_nothing, process_response=remove_accept_header + ) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = asyncio.get_running_loop().create_future() + + async def stall_connection(self, request): + await gate + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaises(TimeoutError) as raised: + async with run_client(server, open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + finally: + gate.set_result(None) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_transport() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(ConnectionError) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) + + +class SecureClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate isn't trusted system-wide. + async with run_client(server, secure=True): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # This hostname isn't included in the test certificate. + async with run_client( + server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_secure_uri_without_ssl(self): + """Client rejects no ssl when URI is secure.""" + with self.assertRaises(TypeError) as raised: + await connect("wss://localhost/", ssl=None) + self.assertEqual( + str(raised.exception), + "ssl=None is incompatible with a wss:// URI", + ) + + async def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_connect() + self.assertEqual( + str(raised.exception), + "no path and sock were specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py new file mode 100644 index 000000000..2efd4e96d --- /dev/null +++ b/tests/asyncio/test_connection.py @@ -0,0 +1,974 @@ +import asyncio +import contextlib +import logging +import socket +import unittest +import uuid +from unittest.mock import patch + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout +from websockets.asyncio.connection import * +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection +from .utils import alist + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + loop = asyncio.get_running_loop() + socket_, remote_socket = socket.socketpair() + self.transport, self.connection = await loop.create_connection( + lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), + sock=socket_, + ) + self.remote_transport, self.remote_connection = await loop.create_connection( + lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), + sock=remote_socket, + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + # Run the event loop twice for consistency with assertFrameSent. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + aiterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(aiterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnnectionClosedError after an error.""" + aiterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(aiterator) + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_encoded_text(self): + """recv receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_decoded_binary(self): + """recv receives an UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_cancellation_before_receiving(self): + """recv can be cancelled before receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv cannot be cancelled after receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the complete message. + gate.set_result(None) + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_encoded_text(self): + """recv_streaming receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_decoded_binary(self): + """recv_streaming receives a UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be cancelled before receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be cancelled after receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + gate.set_result(None) + # Running recv_streaming again fails. + with self.assertRaises(RuntimeError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_while_send_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an Iterable. + self.connection.pause_writing() + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_while_send_async_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + self.connection.pause_writing() + + async def fragments(): + yield "⏳" + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + gate.set_result(None) + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.drop_eof_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + async def test_close_does_not_wait_for_recv(self): + # The asyncio implementation has a buffer for incoming messages. Closing + # the connection discards buffered messages. This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. + await self.remote_connection.send("😀") + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(MS) + await self.connection.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await recv_task + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + + asyncio.create_task(self.connection.close()) + await asyncio.sleep(MS) + + gate.set_result(None) + + with self.assertRaises(ConnectionClosedError) as raised: + await send_task + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + wait_closed_task = asyncio.create_task(self.connection.wait_closed()) + await asyncio.sleep(0) # let the event loop start wait_closed_task + self.assertFalse(wait_closed_task.done()) + await self.connection.close() + self.assertTrue(wait_closed_task.done()) + + # Test ping. + + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("this") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + async with asyncio_timeout(MS): + await pong_waiter + + await self.connection.ping("idem") # doesn't raise an exception + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) + ) + async def test_local_address(self, get_extra_info): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + get_extra_info.assert_called_with("sockname") + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) + ) + async def test_remote_address(self, get_extra_info): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + get_extra_info.assert_called_with("peername") + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + async def test_writing_in_data_received_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + async def test_unexpected_failure_in_data_received(self, events_received): + """Unexpected internal error in data_received() is correctly reported.""" + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py new file mode 100644 index 000000000..c8a2d7cd5 --- /dev/null +++ b/tests/asyncio/test_messages.py @@ -0,0 +1,471 @@ +import asyncio +import unittest +import unittest.mock + +from websockets.asyncio.compatibility import aiter, anext +from websockets.asyncio.messages import * +from websockets.asyncio.messages import SimpleQueue +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame + +from .utils import alist + + +class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.queue = SimpleQueue() + + async def test_len(self): + """__len__ returns queue length.""" + self.assertEqual(len(self.queue), 0) + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + await self.queue.get() + self.assertEqual(len(self.queue), 0) + + async def test_put_then_get(self): + """get returns an item that is already put.""" + self.queue.put(42) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_get_then_put(self): + """get returns an item when it is put.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.put(42) + item = await getter_task + self.assertEqual(item, 42) + + async def test_get_concurrently(self): + """get cannot be called concurrently with itself.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + with self.assertRaises(RuntimeError): + await self.queue.get() + getter_task.cancel() + + async def test_reset(self): + """reset sets the content of the queue.""" + self.queue.reset([42]) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_abort(self): + """abort throws an exception in get.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.abort() + with self.assertRaises(EOFError): + await getter_task + + async def test_abort_clears_queue(self): + """abort clears buffered data from the queue.""" + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + self.queue.abort() + self.assertEqual(len(self.queue), 0) + + +class AssemblerTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(pause=self.pause, resume=self.resume) + self.assembler.set_limits(low=1, high=2) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await getter_task + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await getter_task + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + # Test getting and setting limits + + async def test_get_limits(self): + """get_limits returns low and high water marks.""" + low, high = self.assembler.get_limits() + self.assertEqual(low, 1) + self.assertEqual(high, 2) + + async def test_set_limits(self): + """set_limits changes low and high water marks.""" + self.assembler.set_limits(low=2, high=4) + low, high = self.assembler.get_limits() + self.assertEqual(low, 2) + self.assertEqual(high, 4) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py new file mode 100644 index 000000000..2e59f49b1 --- /dev/null +++ b/tests/asyncio/test_server.py @@ -0,0 +1,525 @@ +import asyncio +import dataclasses +import http +import logging +import socket +import unittest + +from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout +from websockets.asyncio.server import * +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( + EvalShellMixin, + crash, + do_nothing, + eval_shell, + get_server_host_port, + keep_running, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server(do_nothing) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server(crash) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) + + async def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket.create_server(("localhost", 0)) as sock: + async with run_server(sock=sock, host=None, port=None): + uri = "ws://{}:{}/".format(*sock.getsockname()) + async with run_client(uri) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + async with run_client(server, subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request(self): + """Server runs async process_request before processing the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_abort_handshake(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response(self): + """Server runs async process_response after processing the handshake.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_override_response(self): + """Server runs process_response and overrides the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_async_process_response_override_response(self): + """Server runs async process_response and overrides the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with run_client(server) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while ws.server.is_serving(): + await asyncio.sleep(0) + + async with run_server(process_request=process_request) as server: + asyncio.get_running_loop().call_later(MS, server.close) + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close(close_connections=False) + + # Server cannot receive new connections. + await asyncio.sleep(0) + self.assertFalse(server.sockets) + + # The server waits for the client to close the connection. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + # Once the client closes the connection, the server terminates. + await client.close() + async with asyncio_timeout(MS): + await server.wait_closed() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server(keep_running) as server: + async with run_client(server) as client: + # Delay termination of connection handler. + await client.send(str(2 * MS)) + + server.close() + + # The server waits for the connection handler to terminate. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + async with asyncio_timeout(2 * MS): + await server.wait_closed() + + +SSL_OBJECT = "ws.transport.get_extra_info('ssl_object')" + + +class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "path was not specified, and no sock specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): + async def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py new file mode 100644 index 000000000..a611bfc4b --- /dev/null +++ b/tests/asyncio/utils.py @@ -0,0 +1,5 @@ +async def alist(async_iterable): + items = [] + async for item in async_iterable: + items.append(item) + return items diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 0c3f22156..b5c5d726a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -14,6 +14,7 @@ import urllib.request import warnings +from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -30,7 +31,6 @@ from websockets.frames import CloseCode from websockets.http import USER_AGENT from websockets.legacy.client import * -from websockets.legacy.compatibility import asyncio_timeout from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index bc4a44e8c..83686c3d3 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -52,8 +52,9 @@ def get_mapping(src_dir="src"): os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) if "legacy" not in os.path.dirname(src_file) - if os.path.basename(src_file) != "__init__.py" + and os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" + and os.path.basename(src_file) != "async_timeout.py" and os.path.basename(src_file) != "compatibility.py" ] test_files = [ @@ -102,7 +103,8 @@ def get_ignored_files(src_dir="src"): "*/websockets/typing.py", # We don't test compatibility modules with previous versions of Python # or websockets (import locations). - "*/websockets/*/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. diff --git a/tests/sync/connection.py b/tests/sync/connection.py index 89d4909ee..9c8bacea0 100644 --- a/tests/sync/connection.py +++ b/tests/sync/connection.py @@ -8,7 +8,7 @@ class InterceptingConnection(Connection): """ Connection subclass that can intercept outgoing packets. - By interfacing with this connection, you can simulate network conditions + By interfacing with this connection, we simulate network conditions affecting what the component being tested receives during a test. """ @@ -80,7 +80,7 @@ def drop_eof_sent(self): class InterceptingSocket: """ - Socket wrapper that intercepts calls to sendall and shutdown. + Socket wrapper that intercepts calls to ``sendall()`` and ``shutdown()``. This is coupled to the implementation, which relies on these two methods. diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 953c8c253..88cbcd669 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -246,13 +246,15 @@ def test_recv_streaming_connection_closed_ok(self): """recv_streaming raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_during_recv(self): """recv_streaming raises RuntimeError when called concurrently with recv.""" @@ -260,7 +262,8 @@ def test_recv_streaming_during_recv(self): recv_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot call recv_streaming while another thread " @@ -278,7 +281,8 @@ def test_recv_streaming_during_recv_streaming(self): recv_streaming_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another thread " @@ -374,7 +378,7 @@ def test_send_empty_iterable(self): """send does nothing when called with an empty iterable.""" self.connection.send([]) self.connection.close() - self.assertEqual(list(iter(self.remote_connection)), []) + self.assertEqual(list(self.remote_connection), []) def test_send_mixed_iterable(self): """send raises TypeError when called with an iterable of inconsistent types.""" @@ -437,7 +441,7 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" - with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() with self.assertRaises(ConnectionClosedError) as raised: @@ -464,6 +468,10 @@ def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) def test_close_waits_for_recv(self): + # The sync implementation doesn't have a buffer for incoming messsages. + # It requires reading incoming frames until the close frame is reached. + # This behavior — close() blocks until recv() is called — is less than + # ideal and inconsistent with the asyncio implementation. self.remote_connection.send("😀") close_thread = threading.Thread(target=self.connection.close) @@ -547,6 +555,25 @@ def closer(): close_thread.join() + def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + + def closer(): + time.sleep(MS) + self.connection.close() + + close_thread = threading.Thread(target=closer) + close_thread.start() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + close_thread.join() + def test_close_during_send(self): """close fails the connection when called concurrently with send.""" close_gate = threading.Event() @@ -599,42 +626,45 @@ def test_ping_explicit_binary(self): self.connection.ping(b"ping") self.assertFrameSent(Frame(Opcode.PING, b"ping")) - def test_ping_duplicate_payload(self): - """ping rejects the same payload until receiving the pong.""" - with self.remote_connection.protocol_mutex: # block response to ping - pong_waiter = self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: - self.connection.ping("idem") - self.assertEqual( - str(raised.exception), - "already waiting for a pong with the same data", - ) - self.assertTrue(pong_waiter.wait(MS)) - self.connection.ping("idem") # doesn't raise an exception - def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") - self.assertFalse(pong_waiter.wait(MS)) self.remote_connection.pong("this") self.assertTrue(pong_waiter.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.remote_connection.pong("that") self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong with the same payload as a later ping.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + self.remote_connection.pong("idem") + self.assertTrue(pong_waiter.wait(MS)) + + self.connection.ping("idem") # doesn't raise an exception + # Test pong. def test_pong(self):