diff --git a/docs/conf.py b/docs/conf.py index 2c621bf4..c6b9ac7d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,7 +82,10 @@ assert PythonDomain.object_types["data"].roles == ("data", "obj") PythonDomain.object_types["data"].roles = ("data", "class", "obj") -intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), +} spelling_show_suggestions = True diff --git a/docs/faq/client.rst b/docs/faq/client.rst index d3f62768..c39e588c 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -81,7 +81,7 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: +Use :func:`connect` as an asynchronous iterator:: from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed diff --git a/docs/faq/server.rst b/docs/faq/server.rst index e6a3abe8..d00dcafb 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -116,7 +116,7 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.asyncio.server.broadcast`:: +Then, call :func:`broadcast`:: from websockets.asyncio.server import broadcast @@ -219,6 +219,8 @@ You may route a connection to different handlers depending on the request path:: # No handler for this path; close the connection. return +For more complex routing, you may use :func:`~websockets.asyncio.router.route`. + You may also route the connection based on the first message received from the client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to authenticate the connection before routing it, this is usually more convenient. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1f02a6cd..d7db6167 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -56,6 +56,12 @@ Backwards-incompatible changes See :doc:`keepalive and latency <../topics/keepalive>` for details. +New features +............ + +* Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to + dispatch connections to different handlers depending on the URL. + Improvements ............ diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 49bd6f07..8d8b700f 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -12,6 +12,21 @@ Creating a server .. autofunction:: unix_serve :async: +Routing connections +------------------- + +.. automodule:: websockets.asyncio.router + +.. autofunction:: route + :async: + +.. autofunction:: unix_route + :async: + +.. autoclass:: Router + +.. currentmodule:: websockets.asyncio.server + Running a server ---------------- @@ -89,7 +104,7 @@ Using a connection Broadcast --------- -.. autofunction:: websockets.asyncio.server.broadcast +.. autofunction:: broadcast HTTP Basic Authentication ------------------------- @@ -97,4 +112,4 @@ HTTP Basic Authentication websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. -.. autofunction:: websockets.asyncio.server.basic_auth +.. autofunction:: basic_auth diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 93b083d2..0da966cc 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -127,6 +127,8 @@ Server +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | +------------------------------------+--------+--------+--------+--------+ + | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+ Client ------ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index c3d0e8f2..f6a45a65 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -10,6 +10,31 @@ Creating a server .. autofunction:: unix_serve +Routing connections +------------------- + +.. automodule:: websockets.sync.router + +.. autofunction:: route + +.. autofunction:: unix_route + +.. autoclass:: Router + +.. currentmodule:: websockets.sync.server + +Routing connections +------------------- + +.. autofunction:: route + :async: + +.. autofunction:: unix_route + :async: + +.. autoclass:: Server + + Running a server ---------------- @@ -78,4 +103,4 @@ HTTP Basic Authentication websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. -.. autofunction:: websockets.sync.server.basic_auth +.. autofunction:: basic_auth diff --git a/docs/requirements.txt b/docs/requirements.txt index bcd1d711..77c87f4d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,3 +6,4 @@ sphinx-inline-tabs sphinxcontrib-spelling sphinxcontrib-trio sphinxext-opengraph +werkzeug diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 28a10910..f90aff5b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -12,6 +12,10 @@ "connect", "unix_connect", "ClientConnection", + # .asyncio.router + "route", + "unix_route", + "Router", # .asyncio.server "basic_auth", "broadcast", @@ -79,6 +83,7 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect + from .asyncio.router import Router, route, unix_route from .asyncio.server import ( Server, ServerConnection, @@ -138,6 +143,10 @@ "connect": ".asyncio.client", "unix_connect": ".asyncio.client", "ClientConnection": ".asyncio.client", + # .asyncio.router + "route": ".asyncio.router", + "unix_route": ".asyncio.router", + "Router": ".asyncio.router", # .asyncio.server "basic_auth": ".asyncio.server", "broadcast": ".asyncio.server", diff --git a/src/websockets/asyncio/router.py b/src/websockets/asyncio/router.py new file mode 100644 index 00000000..cd95022c --- /dev/null +++ b/src/websockets/asyncio/router.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Awaitable, Callable, Literal + +from werkzeug.exceptions import NotFound +from werkzeug.routing import Map, RequestRedirect + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + async def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return await connection.handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_:: + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.asyncio.router import route + from werkzeug.routing import Map, Rule + + async def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with route(url_map, ...) as server: + await stop + + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When not + provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling + TLS. Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] = router.route_request + else: + + async def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if isinstance(response, Awaitable): + response = await response + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + +def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.asyncio.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 2e2b7878..ec7fc438 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -9,7 +9,7 @@ import sys from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Callable, cast +from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -87,6 +87,8 @@ def __init__( self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ diff --git a/src/websockets/sync/router.py b/src/websockets/sync/router.py new file mode 100644 index 00000000..33105bf3 --- /dev/null +++ b/src/websockets/sync/router.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Callable, Literal + +from werkzeug.exceptions import NotFound +from werkzeug.routing import Map, RequestRedirect + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return connection.handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_:: + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.sync.router import route + from werkzeug.routing import Map, Rule + + def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + with route(url_map, ...) as server: + server.serve_forever() + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When not + provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling + TLS. Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Response | None, + ] = router.route_request + else: + + def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + +def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.sync.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 10e3b681..efb40a7f 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, cast +from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -82,6 +82,8 @@ def __init__( max_queue=max_queue, ) self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], None] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index acf6500c..b142bcd7 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -1,5 +1,6 @@ import asyncio import socket +import urllib.parse def get_host_port(server): @@ -9,15 +10,16 @@ def get_host_port(server): raise AssertionError("expected at least one IPv4 socket") -def get_uri(server): - secure = server.server._ssl_context is not None # hack +def get_uri(server, secure=None): + if secure is None: + secure = server.server._ssl_context is not None # hack protocol = "wss" if secure else "ws" host, port = get_host_port(server) return f"{protocol}://{host}:{port}" async def handler(ws): - path = ws.request.path + path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. async for expr in ws: diff --git a/tests/asyncio/test_router.py b/tests/asyncio/test_router.py new file mode 100644 index 00000000..1426cc9f --- /dev/null +++ b/tests/asyncio/test_router.py @@ -0,0 +1,198 @@ +import http +import socket +import sys +import unittest +from unittest.mock import patch + +from websockets.asyncio.client import connect, unix_connect +from websockets.asyncio.router import * +from websockets.exceptions import InvalidStatus + +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from .server import EvalShellMixin, get_uri, handler +from .utils import alist + + +try: + from werkzeug.routing import Map, Rule +except ImportError: + pass + + +async def echo(websocket, count): + message = await websocket.recv() + for _ in range(count): + await websocket.send(message) + + +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") +class RouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + # This is a small realistic example of werkzeug's basic URL routing + # features: path matching, parameter extraction, and default values. + + async def test_router_matches_paths_and_extracts_parameters(self): + """Router matches paths and extracts parameters.""" + url_map = Map( + [ + Rule("/echo", defaults={"count": 1}, endpoint=echo), + Rule("/echo/", endpoint=echo), + ] + ) + async with route(url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/echo") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello"]) + + async with connect(get_uri(server) + "/echo/3") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) + + @property # avoids an import-time dependency on werkzeug + def url_map(self): + return Map( + [ + Rule("/", endpoint=handler), + Rule("/r", redirect_to="/"), + ] + ) + + async def test_route_with_query_string(self): + """Router ignores query strings when matching paths.""" + async with route(self.url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/?a=b") as client: + await self.assertEval(client, "ws.request.path", "/?a=b") + + async def test_redirect(self): + """Router redirects connections according to redirect_to.""" + async with route(self.url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/r") as client: + await self.assertEval(client, "ws.request.path", "/") + + async def test_secure_redirect(self): + """Router redirects connections to a wss:// URI when TLS is enabled.""" + async with route(self.url_map, "localhost", 0, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.request.path", "/") + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + async def test_force_secure_redirect(self): + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" + async with route(self.url_map, "localhost", 0, ssl=True) as server: + redirect_uri = get_uri(server, secure=True) + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + redirect_uri + "/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + async def test_force_redirect_server_name(self): + """Router redirects connections to the host declared in server_name.""" + async with route(self.url_map, "localhost", 0, server_name="other") as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://other/", + ) + + async def test_not_found(self): + """Router rejects requests to unknown paths with an HTTP 404 error.""" + async with route(self.url_map, "localhost", 0) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/n"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 404", + ) + + async def test_process_request_function_returning_none(self): + """Router supports a process_request function returning None.""" + + def process_request(ws, request): + ws.process_request_ran = True + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + async with connect(get_uri(server) + "/") as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_coroutine_returning_none(self): + """Router supports a process_request coroutine returning None.""" + + async def process_request(ws, request): + ws.process_request_ran = True + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + async with connect(get_uri(server) + "/") as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_function_returning_response(self): + """Router supports a process_request function returning a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_coroutine_returning_response(self): + """Router supports a process_request coroutine returning a response.""" + + async def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_custom_router_factory(self): + """Router supports a custom router factory.""" + + class MyRouter(Router): + async def handler(self, connection): + connection.my_router_ran = True + return await super().handler(connection) + + async with route( + self.url_map, "localhost", 0, create_router=MyRouter + ) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.my_router_ran", "True") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_router_supports_unix_sockets(self): + """Router supports Unix sockets.""" + url_map = Map([Rule("/echo/", endpoint=echo)]) + with temp_unix_socket_path() as path: + async with unix_route(url_map, path): + async with unix_connect(path, "ws://localhost/echo/3") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) diff --git a/tests/sync/server.py b/tests/sync/server.py index fd7a03d8..cadaa267 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,19 +1,22 @@ import contextlib import ssl import threading +import urllib.parse +from websockets.sync.router import * from websockets.sync.server import * -def get_uri(server): - secure = isinstance(server.socket, ssl.SSLSocket) # hack +def get_uri(server, secure=None): + if secure is None: + secure = isinstance(server.socket, ssl.SSLSocket) # hack protocol = "wss" if secure else "ws" host, port = server.socket.getsockname() return f"{protocol}://{host}:{port}" def handler(ws): - path = ws.request.path + path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. for expr in ws: @@ -34,8 +37,14 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(handler=handler, host="localhost", port=0, **kwargs): - with serve(handler, host, port, **kwargs) as server: +def run_server_or_router( + serve_or_route, + handler_or_url_map, + host="localhost", + port=0, + **kwargs, +): + with serve_or_route(handler_or_url_map, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() @@ -63,9 +72,22 @@ def handler(sock, addr): handler_thread.join() +def run_server(handler=handler, **kwargs): + return run_server_or_router(serve, handler, **kwargs) + + +def run_router(url_map, **kwargs): + return run_server_or_router(route, url_map, **kwargs) + + @contextlib.contextmanager -def run_unix_server(path, handler=handler, **kwargs): - with unix_serve(handler, path, **kwargs) as server: +def run_unix_server_or_router( + path, + unix_serve_or_route, + handler_or_url_map, + **kwargs, +): + with unix_serve_or_route(handler_or_url_map, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: @@ -73,3 +95,11 @@ def run_unix_server(path, handler=handler, **kwargs): finally: server.shutdown() thread.join() + + +def run_unix_server(path, handler=handler, **kwargs): + return run_unix_server_or_router(path, unix_serve, handler, **kwargs) + + +def run_unix_router(path, url_map, **kwargs): + return run_unix_server_or_router(path, unix_route, url_map, **kwargs) diff --git a/tests/sync/test_router.py b/tests/sync/test_router.py new file mode 100644 index 00000000..07274e62 --- /dev/null +++ b/tests/sync/test_router.py @@ -0,0 +1,174 @@ +import http +import socket +import sys +import unittest +from unittest.mock import patch + +from websockets.exceptions import InvalidStatus +from websockets.sync.client import connect, unix_connect +from websockets.sync.router import * + +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from .server import EvalShellMixin, get_uri, handler, run_router, run_unix_router + + +try: + from werkzeug.routing import Map, Rule +except ImportError: + pass + + +def echo(websocket, count): + message = websocket.recv() + for _ in range(count): + websocket.send(message) + + +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") +class RouterTests(EvalShellMixin, unittest.TestCase): + # This is a small realistic example of werkzeug's basic URL routing + # features: path matching, parameter extraction, and default values. + + def test_router_matches_paths_and_extracts_parameters(self): + """Router matches paths and extracts parameters.""" + url_map = Map( + [ + Rule("/echo", defaults={"count": 1}, endpoint=echo), + Rule("/echo/", endpoint=echo), + ] + ) + with run_router(url_map) as server: + with connect(get_uri(server) + "/echo") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello"]) + + with connect(get_uri(server) + "/echo/3") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) + + @property # avoids an import-time dependency on werkzeug + def url_map(self): + return Map( + [ + Rule("/", endpoint=handler), + Rule("/r", redirect_to="/"), + ] + ) + + def test_route_with_query_string(self): + """Router ignores query strings when matching paths.""" + with run_router(self.url_map) as server: + with connect(get_uri(server) + "/?a=b") as client: + self.assertEval(client, "ws.request.path", "/?a=b") + + def test_redirect(self): + """Router redirects connections according to redirect_to.""" + with run_router(self.url_map, server_name="localhost") as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://localhost/", + ) + + def test_secure_redirect(self): + """Router redirects connections to a wss:// URI when TLS is enabled.""" + with run_router( + self.url_map, server_name="localhost", ssl=SERVER_CONTEXT + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "wss://localhost/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + def test_force_secure_redirect(self): + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" + with run_router(self.url_map, ssl=True) as server: + redirect_uri = get_uri(server, secure=True) + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + redirect_uri + "/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + def test_force_redirect_server_name(self): + """Router redirects connections to the host declared in server_name.""" + with run_router(self.url_map, server_name="other") as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://other/", + ) + + def test_not_found(self): + """Router rejects requests to unknown paths with an HTTP 404 error.""" + with run_router(self.url_map) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/n"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 404", + ) + + def test_process_request_returning_none(self): + """Router supports a process_request returning None.""" + + def process_request(ws, request): + ws.process_request_ran = True + + with run_router(self.url_map, process_request=process_request) as server: + with connect(get_uri(server) + "/") as client: + self.assertEval(client, "ws.process_request_ran", "True") + + def test_process_request_returning_response(self): + """Router supports a process_request returning a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + with run_router(self.url_map, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + def test_custom_router_factory(self): + """Router supports a custom router factory.""" + + class MyRouter(Router): + def handler(self, connection): + connection.my_router_ran = True + return super().handler(connection) + + with run_router(self.url_map, create_router=MyRouter) as server: + with connect(get_uri(server)) as client: + self.assertEval(client, "ws.my_router_ran", "True") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + def test_router_supports_unix_sockets(self): + """Router supports Unix sockets.""" + url_map = Map([Rule("/echo/", endpoint=echo)]) + with temp_unix_socket_path() as path: + with run_unix_router(path, url_map): + with unix_connect(path, "ws://localhost/echo/3") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) diff --git a/tests/test_exports.py b/tests/test_exports.py index 88e27e69..34a47066 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -2,6 +2,7 @@ import websockets import websockets.asyncio.client +import websockets.asyncio.router import websockets.asyncio.server import websockets.client import websockets.datastructures @@ -16,6 +17,7 @@ for name in ( [] + websockets.asyncio.client.__all__ + + websockets.asyncio.router.__all__ + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ diff --git a/tox.ini b/tox.ini index 918aeaae..9450e971 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ pass_env = deps = py311,py312,py313,coverage,maxi_cov: mitmproxy py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] + werkzeug [testenv:coverage] commands = @@ -47,3 +48,4 @@ commands = deps = mypy python-socks + werkzeug