From f99cbbd42a58e72060308288c822fc910fff99e4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 21:21:49 +0100 Subject: [PATCH] Add support for HTTP(S) proxies. Fix #364. --- docs/project/changelog.rst | 4 +- docs/reference/features.rst | 3 +- docs/topics/proxies.rst | 21 +++- src/websockets/asyncio/client.py | 154 ++++++++++++++++++++++- src/websockets/sync/client.py | 204 ++++++++++++++++++++++++++++++- tests/asyncio/test_client.py | 200 +++++++++++++++++++++++++++++- tests/proxy.py | 38 +++++- tests/sync/test_client.py | 185 ++++++++++++++++++++++++++++ tests/utils.py | 7 +- 9 files changed, 796 insertions(+), 20 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bfbfa793..7bb94b34 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,12 +35,12 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: Client connections use SOCKS proxies automatically. +.. admonition:: Client connections use SOCKS and HTTP proxies automatically. :class: important If a proxy is configured in the operating system or with an environment variable, websockets uses it automatically when connecting to a server. - This feature requires installing the third-party library `python-socks`_. + SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index eaecd02a..93b083d2 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -166,12 +166,11 @@ Client | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | | (`#784`_) | | | | | +------------------------------------+--------+--------+--------+--------+ - | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ -.. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst index fd3ae78b..63a3c3b1 100644 --- a/docs/topics/proxies.rst +++ b/docs/topics/proxies.rst @@ -18,7 +18,7 @@ Then, it looks for a proxy in the following locations: ``ws://`` connections respectively. They allow configuring a specific proxy for WebSocket connections. 2. A SOCKS proxy configured in the operating system. -3. An HTTP proxy configured in the operating system or in the ``https_proxy`` +3. An HTTPS proxy configured in the operating system or in the ``https_proxy`` environment variable, for both ``wss://`` and ``ws://`` connections. 4. An HTTP proxy configured in the operating system or in the ``http_proxy`` environment variable, only for ``ws://`` connections. @@ -30,6 +30,9 @@ most common, for `historical reasons`_, and recommended. .. _historical reasons: https://unix.stackexchange.com/questions/212894/ +websockets authenticates automatically when the address of the proxy includes +credentials e.g. ``http://user:password@proxy:8080/``. + .. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. :class: tip @@ -64,3 +67,19 @@ SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) but does not support other authentication methods such as GSSAPI (:rfc:`1961`). + +HTTP proxies +------------ + +When the address of the proxy starts with ``https://``, websockets secures the +connection to the proxy with TLS. + +When the address of the server starts with ``wss://``, websockets secures the +connection from the proxy to the server with TLS. + +These two options are compatible. TLS-in-TLS is supported. + +The documentation of :func:`~asyncio.client.connect` describes how to configure +TLS from websockets to the proxy and from the proxy to the server. + +websockets supports proxy authentication with Basic Auth. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index a3fcab03..348c5c31 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,20 +4,29 @@ import logging import os import socket +import ssl as ssl_module import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff -from ..datastructures import HeadersLike -from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( + InvalidMessage, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout @@ -266,6 +275,16 @@ class connect: :meth:`~asyncio.loop.create_connection` method) to create a suitable client socket and customize it. + When using a proxy: + + * Prefix keyword arguments with ``proxy_`` for configuring TLS between the + client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``, + ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``. + * Use the standard keyword arguments for configuring TLS between the proxy + and the WebSocket server: ``ssl``, ``server_hostname``, + ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``. + * Other keyword arguments are used only for connecting to the proxy. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. InvalidProxy: If ``proxy`` isn't a valid proxy. @@ -397,6 +416,47 @@ def factory() -> ClientConnection: sock=sock, **kwargs, ) + elif proxy_parsed.scheme[:4] == "http": + # Split keyword arguments between the proxy and the server. + all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {} + for key, value in all_kwargs.items(): + if key.startswith("ssl") or key == "server_hostname": + kwargs[key] = value + elif key.startswith("proxy_"): + proxy_kwargs[key[6:]] = value + else: + proxy_kwargs[key] = value + # Validate the proxy_ssl argument. + if proxy_parsed.scheme == "https": + proxy_kwargs.setdefault("ssl", True) + if proxy_kwargs.get("ssl") is None: + raise ValueError( + "proxy_ssl=None is incompatible with an https:// proxy" + ) + else: + if proxy_kwargs.get("ssl") is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + transport = await connect_http_proxy( + proxy_parsed, + ws_uri, + **proxy_kwargs, + ) + # Initialize WebSocket connection via the proxy. + connection = factory() + transport.set_protocol(connection) + ssl = kwargs.pop("ssl", None) + if ssl is True: + ssl = ssl_module.create_default_context() + if ssl is not None: + new_transport = await loop.start_tls( + transport, connection, ssl, **kwargs + ) + assert new_transport is not None # help mypy + transport = new_transport + connection.connection_made(transport) else: raise AssertionError("unsupported proxy") else: @@ -655,3 +715,89 @@ async def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") + + +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +class HTTPProxyConnection(asyncio.Protocol): + def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): + self.ws_uri = ws_uri + self.proxy = proxy + + self.reader = StreamReader() + self.parser = Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + include_body=False, + ) + + loop = asyncio.get_running_loop() + self.response: asyncio.Future[Response] = loop.create_future() + + def run_parser(self) -> None: + try: + next(self.parser) + except StopIteration as exc: + response = exc.value + if 200 <= response.status_code < 300: + self.response.set_result(response) + else: + self.response.set_exception(InvalidProxyStatus(response)) + except Exception as exc: + proxy_exc = InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) + proxy_exc.__cause__ = exc + self.response.set_exception(proxy_exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + self.run_parser() + + def eof_received(self) -> None: + self.reader.feed_eof() + self.run_parser() + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.feed_eof() + if exc is not None: + self.response.set_exception(exc) + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, +) -> asyncio.Transport: + transport, protocol = await asyncio.get_running_loop().create_connection( + lambda: HTTPProxyConnection(ws_uri, proxy), + proxy.host, + proxy.port, + **kwargs, + ) + + try: + # This raises exceptions if the connection to the proxy fails. + await protocol.response + except Exception: + transport.close() + raise + + return transport diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 722def31..432036cc 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -5,16 +5,17 @@ import threading import warnings from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol -from ..datastructures import HeadersLike -from ..exceptions import ProxyError +from ..datastructures import Headers, HeadersLike +from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection @@ -141,6 +142,8 @@ def connect( additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -195,6 +198,9 @@ def connect( to :obj:`None` to disable the proxy or to the address of a proxy to override the system configuration. See the :doc:`proxy docs <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -295,6 +301,21 @@ def connect( # python_socks is consistent across implementations. local_addr=kwargs.pop("source_address", None), ) + elif proxy_parsed.scheme[:4] == "http": + # Validate the proxy_ssl argument. + if proxy_parsed.scheme != "https" and proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + sock = connect_http_proxy( + proxy_parsed, + ws_uri, + deadline, + ssl=proxy_ssl, + server_hostname=proxy_server_hostname, + **kwargs, + ) else: raise AssertionError("unsupported proxy") else: @@ -318,7 +339,12 @@ def connect( if server_hostname is None: server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) - sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + if proxy_ssl is None: + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + else: + sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname) + # Let's pretend that sock is a socket, even though it isn't. + sock = cast(socket.socket, sock_2) sock.settimeout(None) # Initialize WebSocket protocol @@ -444,3 +470,171 @@ def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") + + +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + include_body=False, + ) + try: + while True: + sock.settimeout(deadline.timeout()) + data = sock.recv(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except socket.timeout: + raise TimeoutError("timed out while connecting to HTTP proxy") + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + finally: + sock.settimeout(None) + + +def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + *, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> socket.socket: + # Connect socket + + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((proxy.host, proxy.port), **kwargs) + + # Initialize TLS wrapper and perform TLS handshake + + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + sock.settimeout(deadline.timeout()) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Send CONNECT request to the proxy and read response. + + sock.sendall(prepare_connect_request(proxy, ws_uri)) + try: + read_connect_response(sock, deadline) + except Exception: + sock.close() + raise + + return sock + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., T]) + + +class SSLSSLSocket: + """ + Socket-like object providing TLS-in-TLS. + + Only methods that are used by websockets are implemented. + + """ + + recv_bufsize = 65536 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl_module.SSLContext, + server_hostname: str | None = None, + ) -> None: + self.incoming = ssl_module.MemoryBIO() + self.outgoing = ssl_module.MemoryBIO() + self.ssl_socket = sock + self.ssl_object = ssl_context.wrap_bio( + self.incoming, + self.outgoing, + server_hostname=server_hostname, + ) + self.run_io(self.ssl_object.do_handshake) + + def run_io(self, func: Callable[..., T], *args: Any) -> T: + while True: + want_read = False + want_write = False + try: + result = func(*args) + except ssl_module.SSLWantReadError: + want_read = True + except ssl_module.SSLWantWriteError: # pragma: no cover + want_write = True + + # Write outgoing data in all cases. + data = self.outgoing.read() + if data: + self.ssl_socket.sendall(data) + + # Read incoming data and retry on SSLWantReadError. + if want_read: + data = self.ssl_socket.recv(self.recv_bufsize) + if data: + self.incoming.write(data) + else: + self.incoming.write_eof() + continue + # Retry after writing outgoing data on SSLWantWriteError. + if want_write: # pragma: no cover + continue + # Return result if no error happened. + return result + + def recv(self, buflen: int) -> bytes: + try: + return self.run_io(self.ssl_object.read, buflen) + except ssl_module.SSLEOFError: + return b"" # always ignore ragged EOFs + + def send(self, data: bytes) -> int: + return self.run_io(self.ssl_object.write, data) + + def sendall(self, data: bytes) -> None: + # adapted from ssl_module.SSLSocket.sendall() + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + count += self.send(byte_view[count:]) + + # recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the + # flags argument aren't implemented because websockets doesn't need them. + + def __getattr__(self, name: str) -> Any: + return getattr(self.ssl_socket, name) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 8db76710..f3a63688 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -15,6 +15,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -664,6 +665,183 @@ async def test_ignore_proxy_with_existing_socket(self): # Use a non-existing domain to ensure we connect to sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"https_proxy": "http://hello:iloveyou@localhost:58080"} + ): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + async def test_http_proxy_http_error(self): + """Client receives an error when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + # TODO: figure out how to put a specific error message. + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -736,7 +914,7 @@ async def test_ssl_without_secure_uri(self): ) async def test_secure_uri_without_ssl(self): - """Client rejects no ssl when URI is secure.""" + """Client rejects ssl=None when URI is secure.""" with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( @@ -744,6 +922,26 @@ async def test_secure_uri_without_ssl(self): "ssl=None is incompatible with a wss:// URI", ) + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with patch_environ({"https_proxy": "http://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", proxy_ssl=True) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_https_proxy_without_ssl(self): + """Client rejects proxy_ssl=None when proxy is HTTPS.""" + with patch_environ({"https_proxy": "https://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", proxy_ssl=None) + self.assertEqual( + str(raised.exception), + "proxy_ssl=None is incompatible with an https:// proxy", + ) + async def test_unsupported_proxy(self): """Client rejects unsupported proxy.""" with self.assertRaises(InvalidProxy) as raised: diff --git a/tests/proxy.py b/tests/proxy.py index 804f2e9d..f3db4201 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,15 +1,20 @@ import asyncio +import pathlib +import ssl import threading import warnings try: # Ignore deprecation warnings raised by mitmproxy dependencies at import time. + warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + from mitmproxy import ctx from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig + from mitmproxy.http import Response from mitmproxy.master import Master - from mitmproxy.options import Options + from mitmproxy.options import CONF_BASENAME, CONF_DIR, Options except ImportError: pass @@ -30,6 +35,31 @@ def reset_flows(self): self.flows = [] +class FailHttpConnect: + def load(self, loader): + loader.add_option( + name="break_http_connect", + typespec=bool, + default=False, + help="Respond to HTTP CONNECT requests with a 999 status code.", + ) + loader.add_option( + name="close_http_connect", + typespec=bool, + default=False, + help="Do not respond to HTTP CONNECT requests.", + ) + + def http_connect(self, flow): + if ctx.options.break_http_connect: + # mitmproxy can send a response with a status code not between 100 + # and 599, while websockets treats it as a protocol error. + # This is used for testing HTTP parsing errors. + flow.response = Response.make(999, "not a valid HTTP response") + elif ctx.options.close_http_connect: + flow.kill() + + class ProxyMixin: """ Run mitmproxy in a background thread. @@ -60,6 +90,7 @@ async def run_proxy(cls): next_layer.NextLayer(), tlsconfig.TlsConfig(), RecordFlows(on_running=cls.proxy_ready.set), + FailHttpConnect(), ) task = loop.create_task(cls.proxy_master.run()) @@ -84,6 +115,11 @@ def setUpClass(cls): cls.proxy_thread.start() cls.proxy_ready.wait() + certificate = pathlib.Path(CONF_DIR) / f"{CONF_BASENAME}-ca-cert.pem" + certificate = certificate.expanduser() + cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + cls.proxy_context.load_verify_locations(bytes(certificate)) + def assertNumFlows(self, num_flows): record_flows = self.proxy_master.addons.get("recordflows") self.assertEqual(len(record_flows.get_flows()), num_flows) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index eefdbe33..08b429a5 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -12,6 +12,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -404,6 +405,180 @@ def test_ignore_proxy_with_existing_socket(self): # Use a non-existing domain to ensure we connect to sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"https_proxy": "http://hello:iloveyou@localhost:58080"} + ): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + def test_http_proxy_http_error(self): + """Client receives an HTTP error when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out while connecting to HTTP proxy", + ) + + def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + def test_https_proxy_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + self.assertNumFlows(1) + + def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + self.assertNumFlows(0) + + def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when server certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -452,6 +627,16 @@ def test_ssl_without_secure_uri(self): "ssl argument is incompatible with a ws:// URI", ) + def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with patch_environ({"https_proxy": "http://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + connect("ws://localhost/", proxy_ssl=True) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: diff --git a/tests/utils.py b/tests/utils.py index f68a447b..38938134 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,14 +20,13 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CERTIFICATE = pathlib.Path(__file__).with_name("test_localhost.pem") CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - +CLIENT_CONTEXT.load_verify_locations(bytes(CERTIFICATE)) SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) +SERVER_CONTEXT.load_cert_chain(bytes(CERTIFICATE)) # Work around https://github.com/openssl/openssl/issues/7967