diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 0b80f087..e8cd6be6 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -26,6 +26,18 @@ clients concurrently. asyncio/server asyncio/client +:mod:`asyncio` (new) +-------------------- + +This is a rewrite of the :mod:`asyncio` implementation. It will become the +default implementation. + +.. toctree:: + :titlesonly: + + new-asyncio/server + new-asyncio/client + :mod:`threading` ---------------- diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst new file mode 100644 index 00000000..b622a8b1 --- /dev/null +++ b/docs/reference/new-asyncio/client.rst @@ -0,0 +1,51 @@ +Client (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: unix_connect + :async: + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst new file mode 100644 index 00000000..86ae4a6d --- /dev/null +++ b/docs/reference/new-asyncio/common.rst @@ -0,0 +1,41 @@ +:orphan: + +Both sides (:mod:`asyncio` - new) +================================= + +.. automodule:: websockets.asyncio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst new file mode 100644 index 00000000..25fa3a4b --- /dev/null +++ b/docs/reference/new-asyncio/server.rst @@ -0,0 +1,62 @@ +Server (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.server + +.. Creating a server +.. ----------------- + +.. .. autofunction:: serve +.. :async: + +.. .. autofunction:: unix_serve +.. :async: + +.. Running a server +.. ---------------- + +.. .. autoclass:: WebSocketServer + +.. .. automethod:: serve_forever + +.. .. automethod:: shutdown + +.. .. automethod:: fileno + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py new file mode 100644 index 00000000..0706b619 --- /dev/null +++ b/src/websockets/asyncio/client.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +import asyncio +from types import TracebackType +from typing import Any, Generator, Sequence + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .compatibility import TimeoutError, asyncio_timeout +from .connection import Connection + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + + """ + + def __init__( + self, + protocol: ClientProtocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.response_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + # May raise CancelledError if open_timeout is exceeded. + await self.response_rcvd + + if self.response is None: + raise ConnectionError("connection closed during handshake") + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.response_rcvd.done(): + self.response_rcvd.set_result(None) + + +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as a context manager:: + + async with websockets.asyncio.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. + When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS + context is created with :func:`~ssl.create_default_context`. + + * You can set ``server_hostname`` to override the host name from ``uri`` in + the TLS handshake. + + * You can set ``host`` and ``port`` to connect to a different host and port + from those found in ``uri``. This only changes the destination of the TCP + connection. The host name from ``uri`` is still used in the TLS handshake + for secure connections and in the ``Host`` header. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_connection` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_connection` method) to create a suitable + client socket and customize it. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to loop.create_connection + **kwargs: Any, + ) -> None: + + wsuri = parse_uri(uri) + + if wsuri.secure: + if kwargs.get("ssl") is None: + kwargs["ssl"] = True + kwargs.setdefault("server_hostname", wsuri.host) + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + + def factory() -> ClientConnection: + # This is a protocol in websockets. + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_connection = loop.create_unix_connection(factory, **kwargs) + else: + if kwargs.get("sock") is None: + kwargs.setdefault("host", wsuri.host) + kwargs.setdefault("port", wsuri.port) + self._create_connection = loop.create_connection(factory, **kwargs) + + self._handshake_args = ( + additional_headers, + user_agent_header, + ) + + self._open_timeout = open_timeout + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, ClientConnection]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> ClientConnection: + try: + async with asyncio_timeout(self._open_timeout): + _transport, self.connection = await self._create_connection + try: + await self.connection.handshake(*self._handshake_args) + except (Exception, asyncio.CancelledError): + self.connection.transport.close() + raise + else: + return self.connection + except TimeoutError: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during handshake") from None + + # ... = yield from connect(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> connect: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` argument is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 390f00ac..e1700006 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,14 +3,17 @@ import sys -__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"] if sys.version_info[:2] >= (3, 11): TimeoutError = TimeoutError aiter = aiter anext = anext - from asyncio import timeout as asyncio_timeout + from asyncio import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) else: # Python < 3.11 from asyncio import TimeoutError @@ -21,4 +24,7 @@ def aiter(async_iterable): async def anext(async_iterator): return await type(async_iterator).__anext__(async_iterator) - from .async_timeout import timeout as asyncio_timeout + from .async_timeout import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py new file mode 100644 index 00000000..26fb7789 --- /dev/null +++ b/src/websockets/asyncio/connection.py @@ -0,0 +1,873 @@ +from __future__ import annotations + +import asyncio +import collections +import contextlib +import logging +import random +import struct +import uuid +from types import TracebackType +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Iterable, + Mapping, + cast, +) + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(asyncio.Protocol): + """ + :mod:`asyncio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.asyncio.client.ClientConnection` or + :class:`~websockets.asyncio.server.ServerConnection`. + + """ + + def __init__( + self, + protocol: Protocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Event loop running this connection. + self.loop = asyncio.get_running_loop() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages: Assembler # initialized in connection_made + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.fragmented_send_waiter: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + + # Adapted from asyncio.FlowControlMixin + self.paused: bool = False + self.drain_waiters: collections.deque[asyncio.Future[None]] = ( + collections.deque() + ) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.transport.get_extra_info("peername") + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get() + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, it will be impossible to read new + messages: future calls to :meth:`recv` or :meth:`recv_streaming` will + raise :exc:`RuntimeError`. This makes the connection unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(): + yield frame + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If the connection busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.fragmented_send_waiter is not None: + await asyncio.shield(self.fragmented_send_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + self.protocol.send_text(message.encode("utf-8")) + + elif isinstance(message, BytesLike): + async with self.send_context(): + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + else: + raise TypeError("data must be bytes, str, or iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.fragmented_send_waiter is not None: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def ping(self, data: Data | None = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pong_waiters: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pong_waiters: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # The event loop's default clock is time.monotonic(). Its resolution + # is a bit low on Windows (~16ms). We cannot use time.perf_counter() + # because it doesn't count time elapsed while the process sleeps. + ping_timestamp = self.loop.time() + self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.protocol.send_ping(data) + return pong_waiter + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pong_waiters: + return + + pong_timestamp = self.loop.time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + ping_ids.append(ping_id) + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pong_waiters. + for ping_id in ping_ids: + del self.pong_waiters[ping_id] + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.protocol.state is CLOSED + exc = self.protocol.close_exc + + for pong_waiter, _ping_timestamp in self.pong_waiters.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + self.pong_waiters.clear() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async async with self.send_context(): + self.protocol.send_text(message.encode("utf-8")) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = self.loop.time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + self.send_data() + await self.drain() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + try: + async with asyncio_timeout_at(self.close_deadline): + await asyncio.shield(self.connection_lost_waiter) + except TimeoutError: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_transport() + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + Raises: + OSError: When a socket operations fails. + + """ + for data in self.protocol.data_to_send(): + if data: + self.transport.write(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # OSError is plausible. uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + def close_transport(self) -> None: + """ + Close transport and message assembler. + + """ + self.transport.close() + self.recv_messages.close() + + # asyncio.Protocol methods + + # Connection callbacks + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.recv_messages = Assembler( + pause=self.transport.pause_reading, + resume=self.transport.resume_reading, + ) + + def connection_lost(self, exc: Exception | None) -> None: + self.protocol.receive_eof() # receive_eof is idempotent + self.recv_messages.close() + self.set_recv_exc(exc) + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + self.abort_pings() + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: # pragma: no cover + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + # Flow control callbacks + + def pause_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert not self.paused + self.paused = True + + def resume_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert self.paused + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + waiter.set_result(None) + + async def drain(self) -> None: # pragma: no cover + # We don't check if the connection is closed because we call drain() + # immediately after write() and write() would fail in that case. + + # Adapted from asyncio.streams.StreamWriter + # Yield to the event loop so that connection_lost() may be called. + if self.transport.is_closing(): + await asyncio.sleep(0) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: + waiter = self.loop.create_future() + self.drain_waiters.append(waiter) + try: + await waiter + finally: + self.drain_waiters.remove(waiter) + + # Streaming protocol callbacks + + def data_received(self, data: bytes) -> None: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the transport. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + self.set_recv_exc(exc) + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + def eof_received(self) -> None: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it shouldn't raise errors. + self.send_data() + + # The WebSocket protocol has its own closing handshake: endpoints close + # the TCP or TLS connection after sending and receiving a close frame. + # As a consequence, they never need to write after receiving EOF, so + # there's no reason to keep the transport open by returning True. + # Besides, that doesn't work on TLS connections. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py new file mode 100644 index 00000000..68d132df --- /dev/null +++ b/src/websockets/asyncio/server.py @@ -0,0 +1,771 @@ +from __future__ import annotations + +import asyncio +import http +import logging +import socket +import sys +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + Iterable, + Sequence, +) + +from websockets.frames import CloseCode + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import asyncio_timeout +from .connection import Connection + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + server: :class:`WebSocketServer` that created this connection. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ServerProtocol, + server: WebSocketServer, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.server = server + self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + # May raise CancelledError if open_timeout is exceeded. + await self.request_rcvd + + if self.request is None: + raise ConnectionError("connection closed during handshake") + + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + self.response = self.protocol.accept(self.request) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + self.server.start_connection_handler(self) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.request_rcvd.done(): + self.request_rcvd.set_result(None) + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.connections: dict[ServerConnection, asyncio.Task[None]] = {} + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] + + def wrap(self, server: asyncio.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + # Initialized here because we need a reference to the event loop. + # This should be moved back to __init__ when dropping Python < 3.10. + self.closed_waiter = server.get_loop().create_future() + + async def conn_handler(self, connection: ServerConnection) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller that can handle exceptions, + it attempts to log relevant ones. + + It guarantees that the TCP connection is closed before exiting. + + """ + try: + # On failure, handshake() closes the transport, raises an + # exception, and logs it. + async with asyncio_timeout(self.open_timeout): + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + + try: + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(CloseCode.INTERNAL_ERROR) + else: + await connection.close() + + except Exception: + # Don't leak connections on errors. + connection.transport.abort() + + finally: + # Registration is tied to the lifecycle of conn_handler() because + # the server waits for connection handlers to terminate, even if + # all connections are already closed. + del self.connections[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + """ + Register a connection with this server. + + """ + # The connection must be registered in self.connections immediately. + # If it was registered in conn_handler(), a race condition would be + # possible when closing the server after scheduling conn_handler() + # but before it starts executing. + self.connections[connection] = self.loop.create_task( + self.conn_handler(connection) + ) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(connection.close(1001)) + for connection in self.connections + if connection.protocol.state is not CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.connections: + await asyncio.wait(self.connections.values()) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: # pragma: no cover + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +class serve: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler``. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This coroutine returns a :class:`WebSocketServer` whose API mirrors + :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + ensure that it will be closed:: + + def handler(websocket): + ... + + # set this future to exit the server stop = + asyncio.get_running_loop().create_future() + + async with websockets.asyncio.server.serve(handler, host, port): + await stop + + Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests + and cancel it to stop the server:: + + server = await websockets.asyncio.server.serve(handler, host, port) + await server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_server` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_server` method) to create a suitable client + socket and customize it. + + * You can set ``start_serving`` to ``False`` to start accepting connections + only after you call :meth:`~WebSocketServer.start_serving()` or + :meth:`~WebSocketServer.serve_forever()`. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + # WebSocket + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Other keyword arguments are passed to loop.create_server + **kwargs: Any, + ) -> None: + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + self.server = WebSocketServer( + handler, + process_request=process_request, + process_response=process_response, + server_header=server_header, + open_timeout=open_timeout, + logger=logger, + ) + + if kwargs.get("ssl") is not None: + kwargs.setdefault("ssl_handshake_timeout", open_timeout) + if sys.version_info[:2] >= (3, 11): # pragma: no branch + kwargs.setdefault("ssl_shutdown_timeout", close_timeout) + + def factory() -> ServerConnection: + """ + Create an asyncio protocol for managing a WebSocket connection. + + """ + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # This is a protocol in websockets. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + self.server, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_server = loop.create_unix_server(factory, **kwargs) + else: + # mypy cannot tell that kwargs must provide sock when port is None. + self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + + # async with serve(...) as ...: ... + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.server.close() + await self.server.wait_closed() + + # ... = await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server + self.server.wrap(server) + return self.server + + # ... = yield from serve(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_serve( + handler: Callable[[ServerConnection], Awaitable[None]], + path: str | None = None, + **kwargs: Any, +) -> Awaitable[WebSocketServer]: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py new file mode 100644 index 00000000..e5826add --- /dev/null +++ b/tests/asyncio/client.py @@ -0,0 +1,33 @@ +import contextlib + +from websockets.asyncio.client import * +from websockets.asyncio.server import WebSocketServer + +from .server import get_server_host_port + + +__all__ = [ + "run_client", + "run_unix_client", +] + + +@contextlib.asynccontextmanager +async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl" in kwargs + protocol = "wss" if secure else "ws" + host, port = get_server_host_port(wsuri_or_server) + wsuri = f"{protocol}://{host}:{port}{resource_name}" + async with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.asynccontextmanager +async def run_unix_client(path, **kwargs): + async with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py new file mode 100644 index 00000000..4ca82d1c --- /dev/null +++ b/tests/asyncio/connection.py @@ -0,0 +1,111 @@ +import asyncio +import contextlib + +from websockets.asyncio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def connection_made(self, transport): + super().connection_made(InterceptingTransport(transport)) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write is None + self.transport.delay_write = delay + try: + yield + finally: + self.transport.delay_write = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write_eof is None + self.transport.delay_write_eof = delay + try: + yield + finally: + self.transport.delay_write_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write + self.transport.drop_write = True + try: + yield + finally: + self.transport.drop_write = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write_eof + self.transport.drop_write_eof = True + try: + yield + finally: + self.transport.drop_write_eof = False + + +class InterceptingTransport: + """ + Transport wrapper that intercepts calls to write and write_eof. + + This is coupled to the implementation, which relies on these two methods. + + """ + + def __init__(self, transport): + self.loop = asyncio.get_running_loop() + self.transport = transport + self.delay_write = None + self.delay_write_eof = None + self.drop_write = False + self.drop_write_eof = False + + def __getattr__(self, name): + return getattr(self.transport, name) + + def write(self, data): + if not self.drop_write: + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + else: + self.transport.write(data) + + def write_eof(self): + if not self.drop_write_eof: + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + else: + self.transport.write_eof() diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py new file mode 100644 index 00000000..1b9917e5 --- /dev/null +++ b/tests/asyncio/server.py @@ -0,0 +1,50 @@ +import asyncio +import contextlib +import socket + +from websockets.asyncio.server import * + + +def get_server_host_port(server): + for sock in server.sockets: + if sock.family == socket.AF_INET: # pragma: no branch + return sock.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +async def crash(ws): + raise RuntimeError + + +async def do_nothing(ws): + pass + + +async def eval_shell(ws): + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + + +async def keep_running(ws): + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) + + +@contextlib.asynccontextmanager +async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + async with serve(handler, host, port, **kwargs) as server: + yield server + + +@contextlib.asynccontextmanager +async def run_unix_server(path, handler=eval_shell, **kwargs): + async with unix_serve(handler, path, **kwargs) as server: + yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py new file mode 100644 index 00000000..c5b3c212 --- /dev/null +++ b/tests/asyncio/test_client.py @@ -0,0 +1,306 @@ +import asyncio +import socket +import ssl +import unittest + +from websockets.asyncio.client import * +from websockets.asyncio.compatibility import TimeoutError +from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.extensions.permessage_deflate import PerMessageDeflate + +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from .client import run_client, run_unix_client +from .server import do_nothing, get_server_host_port, run_server, run_unix_server + + +class ClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with run_server() as server: + with socket.create_connection(get_server_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + async with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with run_client( + server, create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server( + do_nothing, process_response=remove_accept_header + ) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = asyncio.get_running_loop().create_future() + + async def stall_connection(self, request): + await gate + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaises(TimeoutError) as raised: + async with run_client(server, open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + finally: + gate.set_result(None) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_transport() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(ConnectionError) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) + + +class SecureClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate isn't trusted system-wide. + async with run_client(server, secure=True): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # This hostname isn't included in the test certificate. + async with run_client( + server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_secure_uri_without_ssl(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_connect() + self.assertEqual( + str(raised.exception), + "no path and sock were specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py new file mode 100644 index 00000000..9d1707e8 --- /dev/null +++ b/tests/asyncio/test_connection.py @@ -0,0 +1,938 @@ +import asyncio +import contextlib +import logging +import socket +import unittest +import uuid +from unittest.mock import patch + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout +from websockets.asyncio.connection import * +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection +from .utils import alist + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + loop = asyncio.get_running_loop() + socket_, remote_socket = socket.socketpair() + self.transport, self.connection = await loop.create_connection( + lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), + sock=socket_, + ) + self.remote_transport, self.remote_connection = await loop.create_connection( + lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), + sock=remote_socket, + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + # Run the event loop twice for consistency with assertFrameSent. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + aiterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(aiterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnnectionClosedError after an error.""" + aiterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(aiterator) + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_cancellation_before_receiving(self): + """recv can be cancelled before receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv cannot be cancelled after receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the complete message. + gate.set_result(None) + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be cancelled before receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be cancelled after receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + gate.set_result(None) + # Running recv_streaming again fails. + with self.assertRaises(RuntimeError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_while_send_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an Iterable. + self.connection.pause_writing() + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_while_send_async_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + self.connection.pause_writing() + + async def fragments(): + yield "⏳" + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + gate.set_result(None) + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.drop_eof_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + async def test_close_does_not_wait_for_recv(self): + # The asyncio implementation has a buffer for incoming messages. Closing + # the connection discards buffered messages. This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. + await self.remote_connection.send("😀") + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(MS) + await self.connection.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await recv_task + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + + asyncio.create_task(self.connection.close()) + await asyncio.sleep(MS) + + gate.set_result(None) + + with self.assertRaises(ConnectionClosedError) as raised: + await send_task + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test ping. + + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("this") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + async with asyncio_timeout(MS): + await pong_waiter + + await self.connection.ping("idem") # doesn't raise an exception + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) + ) + async def test_local_address(self, get_extra_info): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + get_extra_info.assert_called_with("sockname") + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) + ) + async def test_remote_address(self, get_extra_info): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + get_extra_info.assert_called_with("peername") + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + async def test_writing_in_data_received_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + async def test_unexpected_failure_in_data_received(self, events_received): + """Unexpected internal error in data_received() is correctly reported.""" + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py new file mode 100644 index 00000000..f0fdea35 --- /dev/null +++ b/tests/asyncio/test_server.py @@ -0,0 +1,510 @@ +import asyncio +import dataclasses +import http +import logging +import socket +import unittest + +from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout +from websockets.asyncio.server import * +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( + EvalShellMixin, + crash, + do_nothing, + eval_shell, + get_server_host_port, + keep_running, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server(do_nothing) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server(crash) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) + + async def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket.create_server(("localhost", 0)) as sock: + async with run_server(sock=sock, host=None, port=None): + uri = "ws://{}:{}/".format(*sock.getsockname()) + async with run_client(uri) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + async with run_client(server, subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request(self): + """Server runs async process_request before processing the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_abort_handshake(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response(self): + """Server runs async process_response after processing the handshake.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_override_response(self): + """Server runs process_response and overrides the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_async_process_response_override_response(self): + """Server runs async process_response and overrides the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with run_client(server) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close connections.""" + async with run_server() as server: + async with run_client(server) as client: + server.close(close_connections=False) + + # Server cannot receive new connections. + await asyncio.sleep(0) + self.assertFalse(server.sockets) + + # The server waits for the client to close the connection. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + # Once the client closes the connection, the server terminates. + await client.close() + async with asyncio_timeout(MS): + await server.wait_closed() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server(keep_running) as server: + async with run_client(server) as client: + # Delay termination of connection handler. + await client.send(str(2 * MS)) + + server.close() + + # The server waits for the connection handler to terminate. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + async with asyncio_timeout(2 * MS): + await server.wait_closed() + + +SSL_OBJECT = "ws.transport.get_extra_info('ssl_object')" + + +class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "path was not specified, and no sock specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): + async def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + + +# async def test_fileno(self): +# """WebSocketServer provides a fileno attribute.""" +# async with run_server() as server: +# self.assertIsInstance(server.fileno(), int) + +# async def test_shutdown(self): +# """WebSocketServer provides a shutdown method.""" +# async with run_server() as server: +# server.shutdown() +# # Check that the server socket is closed. +# with self.assertRaises(OSError): +# server.socket.accept()