From 6e72e5584e8655f8360738367b5809617c0d62c8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 21:51:10 +0100 Subject: [PATCH] Replace patch_environ with unittest.mock.patch.dict. --- tests/asyncio/test_client.py | 228 +++++++++++++++++------------------ tests/sync/test_client.py | 206 +++++++++++++++---------------- tests/test_uri.py | 6 +- tests/utils.py | 16 --- 4 files changed, 219 insertions(+), 237 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index c6ff26ae..0c04eb65 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -2,10 +2,12 @@ import contextlib import http import logging +import os import socket import ssl import sys import unittest +from unittest.mock import patch from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError @@ -24,13 +26,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..proxy import ProxyMixin -from ..utils import ( - CLIENT_CONTEXT, - MS, - SERVER_CONTEXT, - patch_environ, - temp_unix_socket_path, -) +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler @@ -570,46 +566,44 @@ def redirect(connection, request): class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"socks_proxy": "http://hello:iloveyou@localhost:51080"} - ): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with self.assertRaises(ProxyError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -619,14 +613,14 @@ async def test_authenticated_socks_proxy_error(self): self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) @@ -636,7 +630,7 @@ async def test_socks_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -657,14 +651,14 @@ async def test_explicit_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -672,46 +666,44 @@ async def test_ignore_proxy_with_existing_socket(self): class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.version()[:3], "TLS") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) async def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"https_proxy": "http://hello:iloveyou@localhost:58080"} - ): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(ProxyError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -720,14 +712,14 @@ async def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( @@ -736,14 +728,14 @@ async def test_http_proxy_protocol_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( @@ -752,12 +744,12 @@ async def test_http_proxy_connection_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:68080"}) # bad port async def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError): - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) @@ -766,7 +758,7 @@ async def test_http_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -775,67 +767,67 @@ async def test_http_proxy_connection_timeout(self): "timed out during opening handshake", ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args) as server: - async with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect( - get_uri(server), - ssl=CLIENT_CONTEXT, - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.version()[:3], "TLS") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args) as server: - # Pass an argument not prefixed with proxy_ for coverage. - kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} - async with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - proxy_server_hostname="overridden", - **kwargs, - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The proxy certificate isn't trusted. - async with connect("wss://example.com/"): - self.fail("did not raise") + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate is self-signed. - async with connect(get_uri(server), proxy_ssl=self.proxy_context): - self.fail("did not raise") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), @@ -923,9 +915,12 @@ async def test_secure_uri_without_ssl(self): async def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" - with patch_environ({"https_proxy": "http://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - await connect("ws://localhost/", proxy_ssl=True) + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", @@ -933,9 +928,12 @@ async def test_proxy_ssl_without_https_proxy(self): async def test_https_proxy_without_ssl(self): """Client rejects proxy_ssl=None when proxy is HTTPS.""" - with patch_environ({"https_proxy": "https://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - await connect("ws://localhost/", proxy_ssl=None) + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="https://localhost:8080", + proxy_ssl=None, + ) self.assertEqual( str(raised.exception), "proxy_ssl=None is incompatible with an https:// proxy", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 386caf56..62391ce8 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,5 +1,6 @@ import http import logging +import os import socket import socketserver import ssl @@ -7,6 +8,7 @@ import threading import time import unittest +from unittest.mock import patch from websockets.exceptions import ( InvalidHandshake, @@ -26,7 +28,6 @@ MS, SERVER_CONTEXT, DeprecationTestCase, - patch_environ, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server @@ -310,46 +311,44 @@ def test_reject_invalid_server_hostname(self): class SocksProxyClientTests(ProxyMixin, unittest.TestCase): proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"socks_proxy": "http://hello:iloveyou@localhost:51080"} - ): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with self.assertRaises(ProxyError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -359,14 +358,14 @@ def test_authenticated_socks_proxy_error(self): self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) @@ -378,7 +377,7 @@ def test_socks_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -397,14 +396,14 @@ def test_explicit_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"ws_proxy": "http://localhost:58080"}) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with patch_environ({"ws_proxy": "http://localhost:58080"}): - with run_server() as server: - with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. - with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to sock. + with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -412,45 +411,43 @@ def test_ignore_proxy_with_existing_socket(self): class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(client.socket.version()[:3], "TLS") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"https_proxy": "http://hello:iloveyou@localhost:58080"} - ): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(ProxyError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -459,14 +456,14 @@ def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( @@ -475,14 +472,14 @@ def test_http_proxy_protocol_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( @@ -491,12 +488,12 @@ def test_http_proxy_connection_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:68080"}) # bad port def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError): - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) @@ -505,7 +502,7 @@ def test_http_proxy_connection_timeout(self): # Replace the proxy with a TCP server that does't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -514,66 +511,66 @@ def test_http_proxy_connection_timeout(self): "timed out while connecting to HTTP proxy", ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server() as server: - with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect( - get_uri(server), - ssl=CLIENT_CONTEXT, - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(client.socket.version()[:3], "TLS") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server() as server: - # Pass an argument not prefixed with proxy_ for coverage. - kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} - with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - proxy_server_hostname="overridden", - **kwargs, - ) as client: - self.assertEqual(client.socket.server_hostname, "overridden") + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The proxy certificate isn't trusted. - with connect("wss://example.com/"): - self.fail("did not raise") + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when server certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate is self-signed. - with connect(get_uri(server), proxy_ssl=self.proxy_context): - self.fail("did not raise") + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), @@ -629,9 +626,12 @@ def test_ssl_without_secure_uri(self): def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" - with patch_environ({"https_proxy": "http://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - connect("ws://localhost/", proxy_ssl=True) + with self.assertRaises(ValueError) as raised: + connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", diff --git a/tests/test_uri.py b/tests/test_uri.py index 35b51fa5..3ccf2115 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,11 +1,11 @@ +import os import unittest +from unittest.mock import patch from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * from websockets.uri import Proxy, get_proxy, parse_proxy -from .utils import patch_environ - VALID_URIS = [ ( @@ -255,6 +255,6 @@ def test_parse_proxy_user_info(self): def test_get_proxy(self): for environ, uri, proxy in PROXY_ENVS: - with patch_environ(environ): + with patch.dict(os.environ, environ): with self.subTest(environ=environ, uri=uri): self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/utils.py b/tests/utils.py index 38938134..7932aae6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -138,22 +138,6 @@ def assertNoLogs(self, logger=None, level=None): self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) -@contextlib.contextmanager -def patch_environ(environ): - backup = {} - for key, value in environ.items(): - backup[key] = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - for key, value in backup.items(): - if value is None: - del os.environ[key] - else: # pragma: no cover - os.environ[key] = value - - @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: